
[Pytorch Tutorials] Image and Video - Transfer Learning for Computer Vision Tutorial

양갱맨 2021. 3. 16. 13:57

In this tutorial, we learn how to train a CNN for image classification using transfer learning.

The two major transfer learning scenarios are as follows:

  • Finetuning the convnet
    • Instead of random initialization, the network is initialized with a pretrained network, such as one trained on the ImageNet dataset, and the whole network is then trained as usual.
  • ConvNet as fixed feature extractor
    • The weights of every layer except the final fully connected layer are frozen; the final fully connected layer is replaced with a new, randomly initialized one, and only that layer is trained (a minimal sketch contrasting the two scenarios follows this list).
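
The sketch below only previews the difference between the two setups, assuming the same torchvision ResNet-18 that is used later in this post:

import torch.nn as nn
from torchvision import models

# Scenario 1: finetuning - start from pretrained weights, keep every parameter trainable
model_ft = models.resnet18(pretrained=True)
model_ft.fc = nn.Linear(model_ft.fc.in_features, 2)      # new 2-class head; everything is trained

# Scenario 2: fixed feature extractor - freeze all pretrained weights
model_conv = models.resnet18(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False                           # frozen
model_conv.fc = nn.Linear(model_conv.fc.in_features, 2)  # new head has requires_grad=True by default
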
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

plt.ion()

Load Data

We use the torchvision and torch.utils.data packages to load the data.

We will train a model to classify ants and bees. There are about 120 training images and 75 validation images for each class. This is a very small dataset to generalize from, but because we use transfer learning, we can still train a model with reasonable performance.
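
The code below expects the hymenoptera_data dataset to be extracted under data/. A minimal download sketch follows; the URL is assumed to be the one used by the official PyTorch tutorial.

import os
import urllib.request
import zipfile

# Assumed dataset URL from the official PyTorch transfer learning tutorial
URL = "https://download.pytorch.org/tutorial/hymenoptera_data.zip"

if not os.path.isdir("data/hymenoptera_data"):
    os.makedirs("data", exist_ok=True)
    zip_path = "data/hymenoptera_data.zip"
    urllib.request.urlretrieve(URL, zip_path)   # download the archive
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall("data")                   # creates data/hymenoptera_data/{train,val}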

data_transforms = {
	# Data augmentation and normalization for training
	'train': transforms.Compose([
			transforms.RandomResizedCrop(224),
			transforms.RandomHorizontalFlip(),
			transforms.ToTensor(),
			transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
	]),
	# Just normalization for validation
	'val' : transforms.Compose([
			transforms.Resize(256),
			transforms.CenterCrop(224),
			transforms.ToTensor(),
			transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
	])
}

data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4, shuffle=True, num_workers=4) for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
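
As a quick sanity check of the loading step, you can print the objects built above; the values in the comments assume the hymenoptera_data split described earlier.

print(class_names)    # ['ants', 'bees'] - ImageFolder sorts class folders alphabetically
print(dataset_sizes)  # e.g. {'train': 244, 'val': 153}
print(device)         # cuda:0 if a GPU is available, otherwise cpu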

Visualize a few images

To understand the data augmentation, we visualize a few training images.

def imshow(inp, title=None):
	"""Display a Tensor as an image after undoing the ImageNet normalization."""
	inp = inp.numpy().transpose((1,2,0))
	mean = np.array([0.485, 0.456, 0.406])
	std = np.array([0.229, 0.224, 0.225])
	inp = std * inp + mean
	inp = np.clip(inp, 0, 1)
	plt.imshow(inp)
	if title is not None:
		plt.title(title)
	plt.pause(0.001)  # pause a bit so that plots are updated

# Get a batch of training data and show it as a grid with class-name titles
inputs, classes = next(iter(dataloaders['train']))
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[class_names[x] for x in classes])

Training the model

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
	"""Run train/val phases for num_epochs, stepping the LR scheduler each epoch and keeping the weights with the best validation accuracy."""
	since = time.time()
	best_model_wts = copy.deepcopy(model.state_dict())
	best_acc = 0.0

	for epoch in range(num_epochs):
		print('Epoch {}/{}'.format(epoch, num_epochs - 1))
		print('-'*10)

		# Each epoch has a training and a validation phase
		for phase in ['train', 'val']:
			if phase == 'train':
				model.train()
			else:
				model.eval()

			running_loss = 0.0
			running_corrects = 0

			for inputs, labels in dataloaders[phase]:
				inputs = inputs.to(device)
				labels = labels.to(device)

				# zero the parameter gradients
				optimizer.zero_grad()

				# forward pass; track gradient history only in the training phase
				with torch.set_grad_enabled(phase == 'train'):
					outputs = model(inputs)
					_, preds = torch.max(outputs, 1)
					loss = criterion(outputs, labels)
					
					# backward + optimize only in the training phase
					if phase == 'train':
						loss.backward()
						optimizer.step()

				running_loss += loss.item() * inputs.size(0)
				running_corrects += torch.sum(preds == labels.data)
			if phase == 'train':
				scheduler.step()

			epoch_loss = running_loss / dataset_sizes[phase]
			epoch_acc = running_corrects.double() / dataset_sizes[phase]

			print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

			# deep copy the model weights whenever validation accuracy improves
			if phase == 'val' and epoch_acc > best_acc:
				best_acc = epoch_acc
				best_model_wts = copy.deepcopy(model.state_dict())

		print()

	time_elapsed = time.time() - since
	print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed//60, time_elapsed % 60))
	print('Best val Acc : {:4f}'.format(best_acc))
	
	# load best model weights
	model.load_state_dict(best_model_wts)
	return model

Visualizing the model predictions

def visualize_model(model, num_images=6):
	"""Show predictions for a few validation images, then restore the model's original train/eval mode."""
	was_training = model.training
	model.eval()
	images_so_far = 0
	fig = plt.figure()

	with torch.no_grad():
		for i, (inputs, labels) in enumerate(dataloaders['val']):
			inputs = inputs.to(device)
			labels = labels.to(device)

			outputs = model(inputs)
			_, preds = torch.max(outputs, 1)

			for j in range(inputs.size()[0]):
				images_so_far += 1
				ax = plt.subplot(num_images//2, 2, images_so_far)
				ax.axis('off')
				ax.set_title('predicted: {}'.format(class_names[preds[j]]))
				imshow(inputs.cpu().data[j])

				if images_so_far == num_images:
					model.train(mode=was_training)
					return
		model.train(mode=was_training)

Finetuning the convnet

model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
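# Replace the final fully connected layer so the output size is 2 (ants and bees);
# nn.Linear(num_ftrs, len(class_names)) would generalize this to any number of classes.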
model_ft.fc = nn.Linear(num_ftrs, 2)

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

Train and evaluate

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)
Epoch 0/24
----------
train Loss: 0.4362 Acc: 0.7787
val Loss: 0.2196 Acc: 0.9020

Epoch 1/24
----------
train Loss: 0.4087 Acc: 0.8443
val Loss: 0.5274 Acc: 0.8301

Epoch 2/24
----------
train Loss: 0.6056 Acc: 0.7418
val Loss: 0.5974 Acc: 0.7712

Epoch 3/24
----------
train Loss: 0.6581 Acc: 0.7582
val Loss: 0.3794 Acc: 0.8497

Epoch 4/24
----------
train Loss: 0.4117 Acc: 0.8238
val Loss: 0.3582 Acc: 0.8758

Epoch 5/24
----------
train Loss: 0.4863 Acc: 0.8279
val Loss: 0.6489 Acc: 0.8366

Epoch 6/24
----------
train Loss: 0.5448 Acc: 0.8402
val Loss: 0.3037 Acc: 0.9150

Epoch 7/24
----------
train Loss: 0.3368 Acc: 0.8484
val Loss: 0.4420 Acc: 0.8431

Epoch 8/24
----------
train Loss: 0.2848 Acc: 0.8811
val Loss: 0.2793 Acc: 0.9085

Epoch 9/24
----------
train Loss: 0.2802 Acc: 0.8852
val Loss: 0.2977 Acc: 0.9085

Epoch 10/24
----------
train Loss: 0.2903 Acc: 0.8566
val Loss: 0.2894 Acc: 0.9020

Epoch 11/24
----------
train Loss: 0.3842 Acc: 0.8484
val Loss: 0.4083 Acc: 0.8627

Epoch 12/24
----------
train Loss: 0.2663 Acc: 0.8770
val Loss: 0.2967 Acc: 0.8954

Epoch 13/24
----------
train Loss: 0.3077 Acc: 0.8689
val Loss: 0.3020 Acc: 0.8954

Epoch 14/24
----------
train Loss: 0.3163 Acc: 0.8566
val Loss: 0.3288 Acc: 0.8954

Epoch 15/24
----------
train Loss: 0.2852 Acc: 0.8689
val Loss: 0.2925 Acc: 0.8889

Epoch 16/24
----------
train Loss: 0.2880 Acc: 0.8770
val Loss: 0.3031 Acc: 0.8954

Epoch 17/24
----------
train Loss: 0.2863 Acc: 0.8689
val Loss: 0.2741 Acc: 0.9020

Epoch 18/24
----------
train Loss: 0.3598 Acc: 0.8484
val Loss: 0.3041 Acc: 0.8954

Epoch 19/24
----------
train Loss: 0.2486 Acc: 0.8934
val Loss: 0.2920 Acc: 0.9020

Epoch 20/24
----------
train Loss: 0.2438 Acc: 0.9016
val Loss: 0.2994 Acc: 0.8954

Epoch 21/24
----------
train Loss: 0.2575 Acc: 0.8811
val Loss: 0.3426 Acc: 0.8824

Epoch 22/24
----------
train Loss: 0.2722 Acc: 0.8648
val Loss: 0.2950 Acc: 0.8889

Epoch 23/24
----------
train Loss: 0.2486 Acc: 0.9016
val Loss: 0.2783 Acc: 0.9085

Epoch 24/24
----------
train Loss: 0.2748 Acc: 0.8770
val Loss: 0.3877 Acc: 0.8758

Training complete in 26931165m 55s
Best val Acc : 0.915033
visualize_model(model_ft)

ConvNet as fixed feature extractor

This time, we freeze the parameters of the convolutional layers and train only the newly constructed final layer.

model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that only parameters of final layer are being optimized as
# opposed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=25)
Epoch 0/24
----------
train Loss: 0.6198 Acc: 0.6926
val Loss: 0.2351 Acc: 0.9346

Epoch 1/24
----------
train Loss: 0.5749 Acc: 0.7254
val Loss: 0.1985 Acc: 0.9216

Epoch 2/24
----------
train Loss: 0.4445 Acc: 0.8033
val Loss: 0.2443 Acc: 0.9020

Epoch 3/24
----------
train Loss: 0.3505 Acc: 0.8730
val Loss: 0.3841 Acc: 0.8562

Epoch 4/24
----------
train Loss: 0.4937 Acc: 0.7910
val Loss: 0.2694 Acc: 0.9020

Epoch 5/24
----------
train Loss: 0.5425 Acc: 0.7869
val Loss: 0.2807 Acc: 0.9085

Epoch 6/24
----------
train Loss: 0.5045 Acc: 0.7869
val Loss: 0.1828 Acc: 0.9477

Epoch 7/24
----------
train Loss: 0.4584 Acc: 0.7992
val Loss: 0.1709 Acc: 0.9542

Epoch 8/24
----------
train Loss: 0.3065 Acc: 0.8648
val Loss: 0.1867 Acc: 0.9542

Epoch 9/24
----------
train Loss: 0.3182 Acc: 0.8689
val Loss: 0.1881 Acc: 0.9477

Epoch 10/24
----------
train Loss: 0.3183 Acc: 0.8402
val Loss: 0.1840 Acc: 0.9608

Epoch 11/24
----------
train Loss: 0.3105 Acc: 0.8689
val Loss: 0.1691 Acc: 0.9608

Epoch 12/24
----------
train Loss: 0.4099 Acc: 0.8156
val Loss: 0.1992 Acc: 0.9412

Epoch 13/24
----------
train Loss: 0.3504 Acc: 0.8361
val Loss: 0.1898 Acc: 0.9412

Epoch 14/24
----------
train Loss: 0.3777 Acc: 0.8320
val Loss: 0.1912 Acc: 0.9477

Epoch 15/24
----------
train Loss: 0.3123 Acc: 0.8648
val Loss: 0.1818 Acc: 0.9608

Epoch 16/24
----------
train Loss: 0.3374 Acc: 0.8443
val Loss: 0.1955 Acc: 0.9477

Epoch 17/24
----------
train Loss: 0.3657 Acc: 0.8074
val Loss: 0.1858 Acc: 0.9346

Epoch 18/24
----------
train Loss: 0.3461 Acc: 0.8566
val Loss: 0.2153 Acc: 0.9281

Epoch 19/24
----------
train Loss: 0.3296 Acc: 0.8525
val Loss: 0.1992 Acc: 0.9412

Epoch 20/24
----------
train Loss: 0.4027 Acc: 0.8279
val Loss: 0.1844 Acc: 0.9542

Epoch 21/24
----------
train Loss: 0.2992 Acc: 0.8730
val Loss: 0.1966 Acc: 0.9477

Epoch 22/24
----------
train Loss: 0.3217 Acc: 0.8730
val Loss: 0.1856 Acc: 0.9542

Epoch 23/24
----------
train Loss: 0.3621 Acc: 0.8402
val Loss: 0.2070 Acc: 0.9412

Epoch 24/24
----------
train Loss: 0.3389 Acc: 0.8525
val Loss: 0.1876 Acc: 0.9477

Training complete in 26931172m 3s
Best val Acc : 0.960784
visualize_model(model_conv)

plt.ioff()
plt.show()