이번 튜토리얼에서는 transfer learning(전이학습)을 사용하여 image classification용 CNN을 어떻게 학습하는지를 배운다.
transfer learning의 두 가지 주요 시나리오는 다음과 같다.
- Finetuning the convnet
- 랜덤 초기화 대신, imagenet 데이터셋으로 학습된 모델같이 pretrained network를 사용하여 네트워크를 초기화한다.
- ConvNet as fixed feature extractor
- 마지막 fully connected layer를 제외한 모든 레이어의 weight를 고정시키고, 마지막 fully connected layer를 랜덤 weight로 새롭게 구성하고 그것만 학습시킨다.
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
plt.ion()
Load Data
torch.utils.data
를 사용해서 data를 읽어온다.
개미와 벌을 분류하는 모델을 학습시킬 것이다. 개미와 벌 각각 약 120장의 학습 이미지와 75장의 검증 이미지를 사용하는데, 이는 일반화하기에는 매우 작은 데이터셋이다. 그러나 transfer learning을 사용하기 때문에 괜찮은 성능을 갖는 모델을 학습시킬 수 있다.
# Per-phase preprocessing: random augmentation for training,
# deterministic resize/center-crop for validation.
# Normalize uses the standard ImageNet channel mean/std, matching the
# pretrained ResNet weights used below.
data_transforms = {
'train': transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val' : transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
}
# Expects data/hymenoptera_data/{train,val}/<class_name>/*.jpg on disk;
# ImageFolder derives class labels from the subdirectory names.
data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4, shuffle=True, num_workers=4) for x in ['train', 'val']}
# Dataset sizes are used later to turn summed losses/corrects into per-epoch averages.
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
# Train on the first GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Visualize a few images
data augmentation을 이해하기 위해 학습용 이미지 데이터를 시각화한다.
def imshow(inp, title=None):
    """Display a (C, H, W) image tensor with ImageNet normalization undone.

    Args:
        inp: torch tensor holding one image (channels first).
        title: optional title for the plot.
    """
    # matplotlib expects HWC, so move channels last.
    img = inp.numpy().transpose((1, 2, 0))
    # Invert the Normalize transform so the image displays with natural colors.
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])
    img = np.clip(img * channel_std + channel_mean, 0, 1)
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    # Brief pause lets the GUI event loop actually render the figure.
    plt.pause(0.001)
# Grab a single training batch and show it as one image grid
# to sanity-check the data augmentation pipeline.
inputs, classes = next(iter(dataloaders['train']))
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[class_names[x] for x in classes])
Training the model
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train `model` and return it loaded with the best-validation-accuracy weights.

    Relies on the module-level `dataloaders`, `dataset_sizes`, and `device`.

    Args:
        model: the network to train (moved to `device` by the caller).
        criterion: loss function, e.g. nn.CrossEntropyLoss.
        optimizer: optimizer over the parameters to be updated.
        scheduler: LR scheduler stepped once per epoch after the train phase.
        num_epochs: number of epochs to run.

    Returns:
        The model with the weights of the epoch with highest validation accuracy.
    """
    since = time.time()

    # Snapshot weights so we can restore the best-performing epoch at the end.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch runs a training pass followed by a validation pass.
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Only track gradients during the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # loss.item() is the batch mean; scale by batch size to sum correctly
                # even when the last batch is smaller.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Keep the weights of the best validation epoch.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    # BUG FIX: the original assigned time.time() directly (absolute epoch
    # timestamp), producing output like "Training complete in 26931165m 55s".
    # Subtract the start time to report the actual elapsed duration.
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc : {:4f}'.format(best_acc))

    # Restore the best weights before returning.
    model.load_state_dict(best_model_wts)
    return model
예측 결과 시각화
def visualize_model(model, num_images=6):
    """Plot predictions for up to `num_images` validation images.

    Puts the model in eval mode for inference and restores its previous
    train/eval mode before returning.
    """
    prev_mode = model.training
    model.eval()
    shown = 0
    fig = plt.figure()

    with torch.no_grad():
        for batch_idx, (batch_inputs, batch_labels) in enumerate(dataloaders['val']):
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            # Predicted class = argmax over the logits.
            outputs = model(batch_inputs)
            preds = torch.max(outputs, 1)[1]

            for idx in range(batch_inputs.size(0)):
                shown += 1
                # Lay out the images on a 2-column grid.
                ax = plt.subplot(num_images // 2, 2, shown)
                ax.axis('off')
                ax.set_title('predicted: {}'.format(class_names[preds[idx]]))
                imshow(batch_inputs.cpu().data[idx])

                if shown == num_images:
                    model.train(mode=prev_mode)
                    return
    model.train(mode=prev_mode)
Finetuning the convnet
# Finetuning: start from ImageNet-pretrained ResNet-18 and replace the final
# fully connected layer with a fresh 2-class head (ants vs. bees).
# All parameters remain trainable.
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)
model_ft = model_ft.to(device)
criterion = nn.CrossEntropyLoss()
# Optimize every parameter of the network.
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
# Decay the learning rate by 10x every 7 epochs.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
Train and evaluate
# Train the full network for 25 epochs; returns the best-val-accuracy weights.
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
num_epochs=25)
Epoch 0/24
----------
train Loss: 0.4362 Acc: 0.7787
val Loss: 0.2196 Acc: 0.9020
Epoch 1/24
----------
train Loss: 0.4087 Acc: 0.8443
val Loss: 0.5274 Acc: 0.8301
Epoch 2/24
----------
train Loss: 0.6056 Acc: 0.7418
val Loss: 0.5974 Acc: 0.7712
Epoch 3/24
----------
train Loss: 0.6581 Acc: 0.7582
val Loss: 0.3794 Acc: 0.8497
Epoch 4/24
----------
train Loss: 0.4117 Acc: 0.8238
val Loss: 0.3582 Acc: 0.8758
Epoch 5/24
----------
train Loss: 0.4863 Acc: 0.8279
val Loss: 0.6489 Acc: 0.8366
Epoch 6/24
----------
train Loss: 0.5448 Acc: 0.8402
val Loss: 0.3037 Acc: 0.9150
Epoch 7/24
----------
train Loss: 0.3368 Acc: 0.8484
val Loss: 0.4420 Acc: 0.8431
Epoch 8/24
----------
train Loss: 0.2848 Acc: 0.8811
val Loss: 0.2793 Acc: 0.9085
Epoch 9/24
----------
train Loss: 0.2802 Acc: 0.8852
val Loss: 0.2977 Acc: 0.9085
Epoch 10/24
----------
train Loss: 0.2903 Acc: 0.8566
val Loss: 0.2894 Acc: 0.9020
Epoch 11/24
----------
train Loss: 0.3842 Acc: 0.8484
val Loss: 0.4083 Acc: 0.8627
Epoch 12/24
----------
train Loss: 0.2663 Acc: 0.8770
val Loss: 0.2967 Acc: 0.8954
Epoch 13/24
----------
train Loss: 0.3077 Acc: 0.8689
val Loss: 0.3020 Acc: 0.8954
Epoch 14/24
----------
train Loss: 0.3163 Acc: 0.8566
val Loss: 0.3288 Acc: 0.8954
Epoch 15/24
----------
train Loss: 0.2852 Acc: 0.8689
val Loss: 0.2925 Acc: 0.8889
Epoch 16/24
----------
train Loss: 0.2880 Acc: 0.8770
val Loss: 0.3031 Acc: 0.8954
Epoch 17/24
----------
train Loss: 0.2863 Acc: 0.8689
val Loss: 0.2741 Acc: 0.9020
Epoch 18/24
----------
train Loss: 0.3598 Acc: 0.8484
val Loss: 0.3041 Acc: 0.8954
Epoch 19/24
----------
train Loss: 0.2486 Acc: 0.8934
val Loss: 0.2920 Acc: 0.9020
Epoch 20/24
----------
train Loss: 0.2438 Acc: 0.9016
val Loss: 0.2994 Acc: 0.8954
Epoch 21/24
----------
train Loss: 0.2575 Acc: 0.8811
val Loss: 0.3426 Acc: 0.8824
Epoch 22/24
----------
train Loss: 0.2722 Acc: 0.8648
val Loss: 0.2950 Acc: 0.8889
Epoch 23/24
----------
train Loss: 0.2486 Acc: 0.9016
val Loss: 0.2783 Acc: 0.9085
Epoch 24/24
----------
train Loss: 0.2748 Acc: 0.8770
val Loss: 0.3877 Acc: 0.8758
Training complete in 26931165m 55s
Best val Acc : 0.915033
visualize_model(model_ft)
feature extractor를 고정시켜서 학습
이번에는 convolution layer의 파라미터를 고정시키고 학습을 진행한다.
# Fixed feature extractor: freeze every pretrained layer so only the new
# classification head is trained.
model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters():
param.requires_grad = False
# Parameters of newly constructed modules have requires_grad=True by default,
# so the replacement fc layer below is the only trainable part.
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)
model_conv = model_conv.to(device)
criterion = nn.CrossEntropyLoss()
# Observe that only parameters of final layer are being optimized as
# opposed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)
# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
# Train only the new fc head; frozen layers still run forward but get no updates.
model_conv = train_model(model_conv, criterion, optimizer_conv,
exp_lr_scheduler, num_epochs=25)
Epoch 0/24
----------
train Loss: 0.6198 Acc: 0.6926
val Loss: 0.2351 Acc: 0.9346
Epoch 1/24
----------
train Loss: 0.5749 Acc: 0.7254
val Loss: 0.1985 Acc: 0.9216
Epoch 2/24
----------
train Loss: 0.4445 Acc: 0.8033
val Loss: 0.2443 Acc: 0.9020
Epoch 3/24
----------
train Loss: 0.3505 Acc: 0.8730
val Loss: 0.3841 Acc: 0.8562
Epoch 4/24
----------
train Loss: 0.4937 Acc: 0.7910
val Loss: 0.2694 Acc: 0.9020
Epoch 5/24
----------
train Loss: 0.5425 Acc: 0.7869
val Loss: 0.2807 Acc: 0.9085
Epoch 6/24
----------
train Loss: 0.5045 Acc: 0.7869
val Loss: 0.1828 Acc: 0.9477
Epoch 7/24
----------
train Loss: 0.4584 Acc: 0.7992
val Loss: 0.1709 Acc: 0.9542
Epoch 8/24
----------
train Loss: 0.3065 Acc: 0.8648
val Loss: 0.1867 Acc: 0.9542
Epoch 9/24
----------
train Loss: 0.3182 Acc: 0.8689
val Loss: 0.1881 Acc: 0.9477
Epoch 10/24
----------
train Loss: 0.3183 Acc: 0.8402
val Loss: 0.1840 Acc: 0.9608
Epoch 11/24
----------
train Loss: 0.3105 Acc: 0.8689
val Loss: 0.1691 Acc: 0.9608
Epoch 12/24
----------
train Loss: 0.4099 Acc: 0.8156
val Loss: 0.1992 Acc: 0.9412
Epoch 13/24
----------
train Loss: 0.3504 Acc: 0.8361
val Loss: 0.1898 Acc: 0.9412
Epoch 14/24
----------
train Loss: 0.3777 Acc: 0.8320
val Loss: 0.1912 Acc: 0.9477
Epoch 15/24
----------
train Loss: 0.3123 Acc: 0.8648
val Loss: 0.1818 Acc: 0.9608
Epoch 16/24
----------
train Loss: 0.3374 Acc: 0.8443
val Loss: 0.1955 Acc: 0.9477
Epoch 17/24
----------
train Loss: 0.3657 Acc: 0.8074
val Loss: 0.1858 Acc: 0.9346
Epoch 18/24
----------
train Loss: 0.3461 Acc: 0.8566
val Loss: 0.2153 Acc: 0.9281
Epoch 19/24
----------
train Loss: 0.3296 Acc: 0.8525
val Loss: 0.1992 Acc: 0.9412
Epoch 20/24
----------
train Loss: 0.4027 Acc: 0.8279
val Loss: 0.1844 Acc: 0.9542
Epoch 21/24
----------
train Loss: 0.2992 Acc: 0.8730
val Loss: 0.1966 Acc: 0.9477
Epoch 22/24
----------
train Loss: 0.3217 Acc: 0.8730
val Loss: 0.1856 Acc: 0.9542
Epoch 23/24
----------
train Loss: 0.3621 Acc: 0.8402
val Loss: 0.2070 Acc: 0.9412
Epoch 24/24
----------
train Loss: 0.3389 Acc: 0.8525
val Loss: 0.1876 Acc: 0.9477
Training complete in 26931172m 3s
Best val Acc : 0.960784
# Show sample predictions from the fixed-feature-extractor model,
# then leave interactive mode and keep the figures on screen.
visualize_model(model_conv)
plt.ioff()
plt.show()
Uploaded by Notion2Tistory v1.1.0