😃Classification
torch를 이용한 간단한 classification 모델을 구현해봅시다.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms, datasets
# Colab
# Select GPU if CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Data Download
# CIFAR-10 training split, downloaded to ./data; ToTensor is supplied but the
# custom MyDataset below reads the raw .data/.targets arrays directly.
train_data = datasets.CIFAR10('./data', download= True, train = True, transform = transforms.ToTensor())
class MyDataset(Dataset):
    """CIFAR-10 style tensor dataset.

    Scales uint8 images in [0, 255] to float tensors in [0, 1] and moves the
    channel axis first (NHWC -> NCHW).

    Args:
        data:  array-like of shape (N, H, W, C); defaults to the global
               CIFAR-10 training images.
        label: sequence of N integer class ids; defaults to the global
               CIFAR-10 training targets.
    """

    def __init__(self, data=None, label=None):
        # Resolve defaults lazily instead of binding `train_data.data` /
        # `train_data.targets` at definition time (the original froze a
        # reference to the global dataset when the class was defined).
        if data is None:
            data = train_data.data
        if label is None:
            label = train_data.targets
        # Normalize to [0, 1] and reorder NHWC -> NCHW.
        self.x = torch.tensor(data / 255., dtype=torch.float).permute(0, 3, 1, 2)
        self.y = torch.tensor(label, dtype=torch.long)

    def __len__(self):
        # Number of samples.
        return self.x.shape[0]

    def __getitem__(self, index):
        # Returns (image tensor [C, H, W], scalar long label).
        return self.x[index], self.y[index]
def prepare_loaders(datas=None, labels=None, index_num=30000, batch=128*2):
    """Split (datas, labels) at index_num and wrap both halves in DataLoaders.

    Args:
        datas:     image array; defaults to the global CIFAR-10 train images.
        labels:    label list; defaults to the global CIFAR-10 train targets.
        index_num: split point — first index_num samples train, rest validate.
        batch:     batch size for both loaders.

    Returns:
        (train_loader, valid_loader) — train shuffled, valid in order.
    """
    # Resolve defaults lazily rather than evaluating `train_data.data` at
    # definition time (the original default bound the global at import).
    if datas is None:
        datas = train_data.data
    if labels is None:
        labels = train_data.targets

    train = datas[:index_num]
    valid = datas[index_num:]
    train_label = labels[:index_num]   # e.g. 30000 labels for CIFAR-10 train
    valid_label = labels[index_num:]   # e.g. 20000 labels

    train_ds = MyDataset(data=train, label=train_label)
    valid_ds = MyDataset(data=valid, label=valid_label)

    train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True)
    valid_loader = DataLoader(valid_ds, batch_size=batch, shuffle=False)
    return train_loader, valid_loader
# Build the train/valid loaders from the default (global) CIFAR-10 arrays.
train_loader, valid_loader = prepare_loaders()
# Peek at one batch (images, labels) — debug aid; `data` is otherwise unused here.
data = next(iter(train_loader))
class Model(nn.Module):
    """Small CNN classifier emitting log-probabilities over 10 classes.

    Input  : [bs, 3, 32, 32] float tensor.
    Output : [bs, 10] log-softmax scores (pair with nn.NLLLoss).
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 3x32x32 -> 16x32x32 -> 32x26x26 -> pool -> 32x13x13.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=7, stride=1, padding=0)
        self.pool = nn.MaxPool2d(2, 2)
        flat_features = 32 * 13 * 13  # channels * height * width after pooling
        self.seq = nn.Sequential(
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.LogSoftmax(dim=-1),
        )

    def forward(self, x):
        """Map a [bs, 3, 32, 32] batch to [bs, 10] log-probabilities."""
        features = self.pool(self.conv2(self.conv1(x)))
        flat = features.flatten(start_dim=1)
        return self.seq(flat)
# Instantiate the model on the selected device.
model = Model().to(device)
# NLLLoss pairs with the model's LogSoftmax output.
loss_fn = nn.NLLLoss()
# Adam with its default hyperparameters.
optimizer = torch.optim.Adam(model.parameters())
def train_one_epoch(model = model, dataloader = train_loader, loss_fn = loss_fn, optimizer = optimizer, device = device, epoch = 1):
    """Run one training epoch over `dataloader`.

    Returns:
        (train_epoch_loss, train_acc): sample-weighted mean loss and
        accuracy in percent over everything seen this epoch.
    """
    model.train()
    train_loss, dataset_size, n_correct = 0.0, 0, 0
    train_epoch_loss, train_acc = 0.0, 0.0  # defined even for an empty loader
    bar = tqdm(dataloader, total=len(dataloader))
    for x, y_true in bar:
        x = x.to(device)            # [bs, 3, 32, 32]
        y_true = y_true.to(device)  # [bs]
        bs = x.shape[0]

        y_pred = model(x)           # [bs, 10] log-probabilities
        loss = loss_fn(y_pred, y_true)

        # Backpropagation step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Running sample-weighted mean loss.
        dataset_size += bs
        train_loss += loss.item() * bs
        train_epoch_loss = train_loss / dataset_size

        # Running accuracy via a correct-count. The original appended every
        # y_pred (keeping its autograd graph alive for the whole epoch) and
        # re-concatenated all batches on every iteration (O(n^2) per epoch);
        # this accumulator yields identical numbers without either cost.
        n_correct += (y_true == y_pred.argmax(dim=1)).sum().item()
        train_acc = 100 * n_correct / dataset_size

        bar.set_description(f"Epoch : [{epoch:02d}] | Train Loss : [{train_epoch_loss:.3e}] | Accuracy : [{train_acc:.2f}]")
    return train_epoch_loss, train_acc
@torch.no_grad()
def valid_one_epoch(model = model, dataloader = valid_loader, loss_fn = loss_fn, device = device, epoch = 1):
    """Evaluate the model over `dataloader` with gradients disabled.

    Returns:
        (valid_epoch_loss, val_acc): sample-weighted mean loss and
        accuracy in percent over the whole loader.
    """
    model.eval()
    valid_loss, dataset_size, n_correct = 0.0, 0, 0
    valid_epoch_loss, val_acc = 0.0, 0.0  # defined even for an empty loader
    bar = tqdm(dataloader, total=len(dataloader))
    # NOTE: the @torch.no_grad() decorator already covers the whole function;
    # the original's inner `with torch.no_grad():` block was redundant.
    for x, y_true in bar:
        x = x.to(device)            # [bs, 3, 32, 32]
        y_true = y_true.to(device)  # [bs]
        bs = x.shape[0]

        y_pred = model(x)           # [bs, 10] log-probabilities
        loss = loss_fn(y_pred, y_true)

        # Running sample-weighted mean loss.
        dataset_size += bs
        valid_loss += loss.item() * bs
        valid_epoch_loss = valid_loss / dataset_size

        # Running accuracy via a correct-count instead of concatenating every
        # stored batch per iteration (the original was O(n^2) per epoch).
        n_correct += (y_true == y_pred.argmax(dim=1)).sum().item()
        val_acc = 100 * n_correct / dataset_size

        bar.set_description(f"Epoch : [{epoch:02d}] | Val Loss : [{valid_epoch_loss:.3e}] | Accuracy : [{val_acc:.2f}]")
    return valid_epoch_loss, val_acc
def run_train(model = model, train_loader = train_loader, valid_loader = valid_loader,
              loss_fn = loss_fn, optimizer = optimizer, device = device, n_epochs = 150):
    """Train for up to n_epochs with per-epoch validation.

    Checkpoints the model whenever validation loss improves, stops early when
    it has not improved for `early_stop` epochs, and finally restores the
    best checkpoint.

    Returns:
        (model, result): model holding its best-validation weights, and a
        dict of per-epoch history lists keyed 'Train Loss', 'Valid Loss',
        'Train Accs', 'Valid Accs'.
    """
    ckpt_path = './model_classification.bin'
    train_hs, valid_hs, train_accs, valid_accs = [], [], [], []
    print_iter = 10
    lowest_loss, lowest_epoch = np.inf, np.inf
    early_stop = 20  # patience in epochs; values <= 0 disable early stopping

    for epoch in range(n_epochs):
        train_loss, train_acc = train_one_epoch(model = model, optimizer = optimizer,
                                                dataloader = train_loader,
                                                loss_fn = loss_fn,
                                                device = device, epoch = epoch)
        valid_loss, valid_acc = valid_one_epoch(model = model,
                                                dataloader = valid_loader,
                                                loss_fn = loss_fn,
                                                device = device, epoch = epoch)
        train_hs.append(train_loss)
        train_accs.append(train_acc)
        valid_hs.append(valid_loss)
        valid_accs.append(valid_acc)

        # Periodic monitoring.
        if (epoch + 1) % print_iter == 0:
            print(f"Ep:[{epoch}]|TL:{train_loss:.3e}|VL:{valid_loss:.3e}|LL:{lowest_loss:.3e}")

        # Checkpoint on a new best validation loss; otherwise check patience.
        if valid_loss < lowest_loss:
            lowest_loss = valid_loss
            lowest_epoch = epoch
            torch.save(model.state_dict(), ckpt_path)
        elif early_stop > 0 and lowest_epoch + early_stop < epoch + 1:
            print("Early Stopping..!")
            break

    print()
    print("The Best Validation Loss= %.3e at %d Epoch" % (lowest_loss, lowest_epoch))

    # Restore the best checkpoint; map_location keeps the load valid even if
    # the checkpoint was written from a different device (e.g. GPU -> CPU).
    model.load_state_dict(torch.load(ckpt_path, map_location=device))

    result = {
        "Train Loss": train_hs,
        "Valid Loss": valid_hs,
        "Train Accs": train_accs,
        "Valid Accs": valid_accs,
    }
    return model, result
model, result = run_train()
def plot_loss(res):
    """Plot the train/valid loss history from `res` on a log-scaled y axis."""
    start = 0  # first epoch to display
    plt.figure(figsize=(20, 10))
    plt.title("Train/Valid Loss History", fontsize = 20)
    # Draw both series with the same x range (0..len-1).
    for key in ('Train Loss', 'Valid Loss'):
        series = res[key][start:]
        plt.plot(range(len(series)), series, label = key)
    plt.legend()
    plt.yscale('log')
    plt.grid(True)
    plt.show()
def plot_acc(res):
    """Plot the train/valid accuracy history from `res` (linear y axis)."""
    plt.figure(figsize=(20, 10))
    plt.title("Train/Valid Accuracy History", fontsize = 20)
    # Draw both series with the same x range (0..len-1).
    for key in ('Train Accs', 'Valid Accs'):
        series = res[key]
        plt.plot(range(len(series)), series, label = key)
    plt.legend()
    # plt.yscale('log')
    plt.grid(True)
    plt.show()
# Visualize the recorded training curves.
plot_loss(result)
plot_acc(result)
Last updated