
[Pytorch Tutorials] Image and Video - DCGAN Tutorial


Introduction

This tutorial walks through DCGAN with a worked example. We train a Generative Adversarial Network (GAN) to generate new people after showing it photos of real celebrities. The DCGAN implementation follows https://github.com/pytorch/examples.

Generative Adversarial Networks

What is a GAN?

A GAN learns the distribution of a training dataset so that it can generate new data from that same distribution. It consists of two distinct models: a generator and a discriminator.

The generator produces "fake" images that look like the training images.

The discriminator looks at an image, including the generator's fake output, and judges whether it is a real image or a generated fake.

During training, the generator produces ever-better fakes to fool the discriminator, while the discriminator learns to keep telling real training data apart from those improved fakes.

The equilibrium between generator and discriminator is reached when the generator produces perfect fakes and the discriminator is left guessing, always at 50% confidence.
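This can be made precise. As shown in the original GAN paper (a step not spelled out in this post), for a fixed generator the optimal discriminator is

D*(x) = p_data(x) / (p_data(x) + p_g(x))

so once the generator's distribution p_g matches p_data, the best the discriminator can do is D*(x) = 1/2 everywhere.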

Here, x is training image data, and D(x) is the discriminator network, which outputs the probability that x came from the training data. The input image to D(x) has size 3*64*64 (Channel*Height*Width).

If x is training data, D(x) should be high; if x was produced by the generator, it should be low.

Since D(x) only has to distinguish real from fake, a traditional binary classifier suffices.

The generator takes as input a latent space vector z sampled from a standard normal distribution. G(z) is the generator model, which maps the vector z to data space. The goal of G is to estimate a fake-data distribution with the same shape as the training data distribution.

The GAN loss function is the minimax objective:

min_G max_D V(D, G) = E_{x~p_data(x)}[log D(x)] + E_{z~p_z(z)}[log(1 - D(G(z)))]

Of the inputs passed to the discriminator, the first expectation uses real training data and the second uses the generator's fake data.

Put simply: the discriminator improves by calling real data real and fake data fake, while the generator is generating well when the discriminator starts calling its fakes real.

What is a DCGAN?

A DCGAN is a GAN that uses convolution and convolutional-transpose layers in its architecture.

DCGAN paper: https://arxiv.org/pdf/1511.06434.pdf

The discriminator is built from strided convolution layers, batch norm layers, and LeakyReLU activations. The generator is built from convolutional-transpose layers, batch norm layers, and ReLU activations; it takes a latent vector z as input and outputs a 3*64*64 RGB image.
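The convolutional-transpose layers double the spatial size at each step; to see why (and why the first layer maps a 1x1 latent vector to 4x4), here is a small sketch, not from the tutorial, of PyTorch's ConvTranspose2d output-size formula (assuming dilation=1 and output_padding=0):

def convtranspose2d_out(size, kernel=4, stride=2, padding=1):
    # ConvTranspose2d output size with dilation=1, output_padding=0:
    # out = (in - 1) * stride - 2 * padding + kernel
    return (size - 1) * stride - 2 * padding + kernel

print(convtranspose2d_out(1, stride=1, padding=0))  # 1 -> 4   (first generator layer)
print(convtranspose2d_out(4))                       # 4 -> 8
print(convtranspose2d_out(8))                       # 8 -> 16
print(convtranspose2d_out(16))                      # 16 -> 32
print(convtranspose2d_out(32))                      # 32 -> 64  (final layer)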

from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

Inputs

  • dataroot - path to the dataset folder.
  • workers - number of worker threads for loading data with the DataLoader.
  • batch_size - batch size used in training. The DCGAN paper uses 128.
  • image_size - spatial size of the training images; all images are resized to this size (64 here).
  • nc - number of color channels in the input images (3 for color images).
  • nz - length of the latent vector.
  • ngf - depth of the feature maps carried through the generator.
  • ndf - depth of the feature maps carried through the discriminator.
  • num_epochs - number of training epochs.
  • lr - learning rate for training.
  • beta1 - beta1 hyperparameter for the Adam optimizers. The paper uses 0.5.
  • ngpu - number of GPUs available; 0 runs in CPU mode.

# Root directory for dataset
dataroot = "data/celeba"

# Number of workers for dataloader
workers = 2

# Batch size during training
batch_size = 128

# Spatial size of training images. All images will be resized to this
#   size using a transformer.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 5

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

Data

We use the Celeb-A Faces dataset. After reading the data, we create a dataloader.

ImageFolder and DataLoader make the data easy to handle in PyTorch. Note that ImageFolder expects subdirectories inside the dataset's root folder, so the layout below is assumed.
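With the Celeb-A zip extracted into dataroot, the directory layout looks like this (directory name as in the official tutorial):

data/celeba
└── img_align_celeba
    ├── 000001.jpg
    ├── 000002.jpg
    ├── 000003.jpg
    └── ...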

dataset = dset.ImageFolder(root=dataroot,
                          transform=transforms.Compose([
                              transforms.Resize(image_size),
                              transforms.CenterCrop(image_size),
                              transforms.ToTensor(),
                              transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))
                          ]))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                        shuffle=True, num_workers=workers)

device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu>0) else "cpu")

real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

Implementation

With the dataset and parameters prepared, we move on to the implementation.

Weight Initialization

Per the DCGAN paper, all model weights are randomly initialized from a normal distribution with mean=0 and std=0.02.

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:        # Conv-type layers (Conv or Conv-Transpose)
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1: # BatchNorm layers: weight ~ N(1, 0.02), bias = 0
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

Generator

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is the latent vector z, going into a convolution
            # ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, ...)
            # nz = 100
            # nc = 3
            # ngf = 64
            
            # input size: nz (=100) x 1 x 1
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            
            # size 512 x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            
            # size 256 x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            
            # size 128 x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            
            # size 64 x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
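            # output size: nc (=3) x 64 x 64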
        )
        
    def forward(self, input):
        return self.main(input)

netG = Generator(ngpu).to(device)

# Use multiple GPUs if available and requested
if (device.type=='cuda') and (ngpu>1):
    netG = nn.DataParallel(netG, list(range(ngpu)))
    
netG.apply(weights_init)

print(netG)
Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)
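
As a quick sanity check (not part of the original tutorial), we can push a batch of latent vectors through netG and confirm the expected output shape:

# Assumes netG, nz, and device from above.
z = torch.randn(16, nz, 1, 1, device=device)
with torch.no_grad():
    out = netG(z)
print(out.shape)  # expected: torch.Size([16, 3, 64, 64])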

Discriminator

The discriminator takes an image as input and outputs the scalar probability that the image is real rather than fake.

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # nc = 3
            # ndf = 64
            
            # size 3 x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            
            # size 64 x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            
            # size 128 x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            
            # size 256 x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            
            # size 512 x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
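            # output size: 1 x 1 x 1 (a scalar probability)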
        )
        
    def forward(self, input):
        return self.main(input)

netD = Discriminator(ngpu).to(device)

if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))
    
netD.apply(weights_init)

print(netD)
Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)
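
Again as a quick check (not in the original tutorial), the discriminator maps a batch of images to one probability per image; reusing the fake batch out from the generator check above:

with torch.no_grad():
    probs = netD(out).view(-1)
print(probs.shape)  # expected: torch.Size([16]), each entry in (0, 1)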

Loss Functions and Optimizers

We set up the loss function and optimizers for D and G, using Binary Cross Entropy loss (BCELoss). We also create a fixed batch of latent vectors (fixed_noise) so the generator's progress can be visualized on the same inputs throughout training, plus one Adam optimizer per network.
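For reference, BCELoss computes, per the PyTorch docs,

l(x, y) = -[ y * log(x) + (1 - y) * log(1 - x) ]

so with label y=1 (real) it reduces to -log(x) = -log(D(input)), and with y=0 (fake) to -log(1 - x). Both parts of the GAN objective can therefore be computed with the same criterion just by switching the label.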

criterion = nn.BCELoss()

# Create a batch of latent vectors used to visualize the generator's progression
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

real_label = 1.
fake_label = 0.

optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

Training

Part 1 - Train the Discriminator

The discriminator's goal is to maximize the probability of correctly classifying a given input image as real or fake.

First, a batch of real training samples is forward-passed through D, the loss (log(D(x))) is computed, and the gradients are obtained with a backward pass.

Next, a batch of fake samples produced by the current generator is forward-passed through D and the loss (log(1 - D(G(z)))) is computed; the gradients from this backward pass accumulate on top of those from the real batch. With the gradients accumulated from both the all-real and all-fake batches, we take a step of the discriminator's optimizer. (In the code below, fake.detach() keeps this backward pass from propagating gradients into the generator.)

Part 2 - Train the Generator

To produce better fakes, the generator should be trained to minimize log(1 - D(G(z))). However, this does not provide sufficient gradients early in training, so we instead maximize log(D(G(z))): the discriminator classifies the generator output from Part 1, the generator's loss is computed using real labels as the ground truth, the generator's gradients are computed in a backward pass, and finally the generator's parameters are updated via an optimizer step.

(Running this on a CPU is very slow..)

img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")

for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        # update D network : maximize log(D(x)) + log(1-D(G(z)))
        # train with all-real
        netD.zero_grad()
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()
        
        # train with all-fake
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()
        
        # update G network : maximize log(D(G(z)))
        netG.zero_grad()
        label.fill_(real_label)
        output = netD(fake).view(-1)
        errG = criterion(output,label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()
        
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
            
        iters += 1
Starting Training Loop...
[0/5][0/1583]   Loss_D: 1.9847  Loss_G: 5.5914  D(x): 0.6004    D(G(z)): 0.6680 / 0.0062
[0/5][50/1583]  Loss_D: 0.2652  Loss_G: 15.9832 D(x): 0.8800    D(G(z)): 0.0000 / 0.0000
[0/5][100/1583] Loss_D: 0.3224  Loss_G: 9.6934  D(x): 0.8650    D(G(z)): 0.0095 / 0.0037
[0/5][150/1583] Loss_D: 0.3354  Loss_G: 5.3241  D(x): 0.8957    D(G(z)): 0.1428 / 0.0123
[0/5][200/1583] Loss_D: 0.2916  Loss_G: 3.2344  D(x): 0.8690    D(G(z)): 0.0772 / 0.0652
[0/5][250/1583] Loss_D: 0.5850  Loss_G: 6.1474  D(x): 0.8365    D(G(z)): 0.1927 / 0.0050
[0/5][300/1583] Loss_D: 3.0240  Loss_G: 3.2477  D(x): 0.1510    D(G(z)): 0.0036 / 0.0788
[0/5][350/1583] Loss_D: 0.5418  Loss_G: 3.6662  D(x): 0.6844    D(G(z)): 0.0498 / 0.0524
[0/5][400/1583] Loss_D: 0.4191  Loss_G: 3.9014  D(x): 0.8552    D(G(z)): 0.1811 / 0.0318
[0/5][450/1583] Loss_D: 0.4397  Loss_G: 6.1456  D(x): 0.9380    D(G(z)): 0.2634 / 0.0043
[0/5][500/1583] Loss_D: 0.5576  Loss_G: 2.9677  D(x): 0.6896    D(G(z)): 0.0696 / 0.0915
[0/5][550/1583] Loss_D: 0.4624  Loss_G: 5.7491  D(x): 0.8529    D(G(z)): 0.1966 / 0.0066
[0/5][600/1583] Loss_D: 2.1937  Loss_G: 7.7897  D(x): 0.2468    D(G(z)): 0.0003 / 0.0025
[0/5][650/1583] Loss_D: 0.2656  Loss_G: 5.3016  D(x): 0.8623    D(G(z)): 0.0520 / 0.0101
[0/5][700/1583] Loss_D: 0.6262  Loss_G: 6.5893  D(x): 0.8710    D(G(z)): 0.3298 / 0.0024
[0/5][750/1583] Loss_D: 0.3854  Loss_G: 4.1290  D(x): 0.7930    D(G(z)): 0.0558 / 0.0280
[0/5][800/1583] Loss_D: 0.3838  Loss_G: 3.8788  D(x): 0.8528    D(G(z)): 0.1572 / 0.0368
[0/5][850/1583] Loss_D: 0.2807  Loss_G: 3.4945  D(x): 0.8757    D(G(z)): 0.0948 / 0.0534
[0/5][900/1583] Loss_D: 0.4041  Loss_G: 4.5653  D(x): 0.8529    D(G(z)): 0.1725 / 0.0214
[0/5][950/1583] Loss_D: 0.6492  Loss_G: 5.6423  D(x): 0.8552    D(G(z)): 0.3022 / 0.0081
[0/5][1000/1583]        Loss_D: 1.5301  Loss_G: 8.6701  D(x): 0.9478    D(G(z)): 0.6222 / 0.0006
[0/5][1050/1583]        Loss_D: 0.3216  Loss_G: 4.1127  D(x): 0.8192    D(G(z)): 0.0774 / 0.0298
[0/5][1100/1583]        Loss_D: 0.3462  Loss_G: 2.9686  D(x): 0.8035    D(G(z)): 0.0803 / 0.0813
[0/5][1150/1583]        Loss_D: 0.3262  Loss_G: 5.0737  D(x): 0.8973    D(G(z)): 0.1739 / 0.0117
[0/5][1200/1583]        Loss_D: 0.6552  Loss_G: 3.3029  D(x): 0.6320    D(G(z)): 0.0344 / 0.0907
[0/5][1250/1583]        Loss_D: 0.2856  Loss_G: 4.5012  D(x): 0.9309    D(G(z)): 0.1689 / 0.0205
[0/5][1300/1583]        Loss_D: 0.9179  Loss_G: 3.9231  D(x): 0.5341    D(G(z)): 0.0226 / 0.0402
[0/5][1350/1583]        Loss_D: 0.7714  Loss_G: 2.6326  D(x): 0.6035    D(G(z)): 0.0600 / 0.1194
[0/5][1400/1583]        Loss_D: 0.9019  Loss_G: 6.1882  D(x): 0.9673    D(G(z)): 0.5248 / 0.0037
[0/5][1450/1583]        Loss_D: 0.4428  Loss_G: 4.9077  D(x): 0.8761    D(G(z)): 0.2400 / 0.0118
[0/5][1500/1583]        Loss_D: 0.4039  Loss_G: 3.9760  D(x): 0.8304    D(G(z)): 0.1394 / 0.0365
[0/5][1550/1583]        Loss_D: 0.5804  Loss_G: 4.2960  D(x): 0.8497    D(G(z)): 0.2725 / 0.0210
[1/5][0/1583]   Loss_D: 0.3821  Loss_G: 3.3792  D(x): 0.8240    D(G(z)): 0.1222 / 0.0475
[1/5][50/1583]  Loss_D: 0.7003  Loss_G: 1.7546  D(x): 0.6499    D(G(z)): 0.1080 / 0.2183
[1/5][100/1583] Loss_D: 0.8822  Loss_G: 4.5926  D(x): 0.7927    D(G(z)): 0.3824 / 0.0184
[1/5][150/1583] Loss_D: 0.4028  Loss_G: 3.2279  D(x): 0.7987    D(G(z)): 0.1057 / 0.0629
[1/5][200/1583] Loss_D: 0.4314  Loss_G: 3.3842  D(x): 0.7311    D(G(z)): 0.0661 / 0.0517
[1/5][250/1583] Loss_D: 0.2275  Loss_G: 4.3143  D(x): 0.8797    D(G(z)): 0.0607 / 0.0243
[1/5][300/1583] Loss_D: 0.4614  Loss_G: 4.9470  D(x): 0.9173    D(G(z)): 0.2759 / 0.0124
[1/5][350/1583] Loss_D: 1.3237  Loss_G: 8.0344  D(x): 0.9241    D(G(z)): 0.6221 / 0.0010
[1/5][400/1583] Loss_D: 0.9322  Loss_G: 6.5259  D(x): 0.9439    D(G(z)): 0.5017 / 0.0051
[1/5][450/1583] Loss_D: 0.4483  Loss_G: 4.0047  D(x): 0.8207    D(G(z)): 0.1705 / 0.0295
[1/5][500/1583] Loss_D: 0.7877  Loss_G: 5.2322  D(x): 0.9344    D(G(z)): 0.4512 / 0.0093
[1/5][550/1583] Loss_D: 0.5864  Loss_G: 5.7760  D(x): 0.9018    D(G(z)): 0.3078 / 0.0080
[1/5][600/1583] Loss_D: 0.5254  Loss_G: 2.8334  D(x): 0.7670    D(G(z)): 0.1487 / 0.1127
[1/5][650/1583] Loss_D: 0.5434  Loss_G: 5.3955  D(x): 0.8802    D(G(z)): 0.2904 / 0.0077
[1/5][700/1583] Loss_D: 0.7707  Loss_G: 2.6770  D(x): 0.5634    D(G(z)): 0.0373 / 0.1157
[1/5][750/1583] Loss_D: 0.6900  Loss_G: 1.4504  D(x): 0.6034    D(G(z)): 0.0552 / 0.2950
[1/5][800/1583] Loss_D: 0.4646  Loss_G: 3.1006  D(x): 0.7311    D(G(z)): 0.0799 / 0.0696
[1/5][850/1583] Loss_D: 0.4232  Loss_G: 3.6385  D(x): 0.7432    D(G(z)): 0.0575 / 0.0497
[1/5][900/1583] Loss_D: 0.8402  Loss_G: 5.6566  D(x): 0.9698    D(G(z)): 0.4983 / 0.0060
[1/5][950/1583] Loss_D: 0.3548  Loss_G: 3.6078  D(x): 0.8763    D(G(z)): 0.1749 / 0.0395
[1/5][1000/1583]        Loss_D: 0.6420  Loss_G: 4.0902  D(x): 0.8398    D(G(z)): 0.3085 / 0.0259
[1/5][1050/1583]        Loss_D: 1.9727  Loss_G: 2.7134  D(x): 0.2709    D(G(z)): 0.0099 / 0.1298
[1/5][1100/1583]        Loss_D: 1.4046  Loss_G: 0.8421  D(x): 0.3452    D(G(z)): 0.0348 / 0.5092
[1/5][1150/1583]        Loss_D: 0.5905  Loss_G: 2.2227  D(x): 0.6522    D(G(z)): 0.0642 / 0.1602
[1/5][1200/1583]        Loss_D: 0.6719  Loss_G: 4.4221  D(x): 0.8853    D(G(z)): 0.3627 / 0.0195
[1/5][1250/1583]        Loss_D: 0.8745  Loss_G: 3.9972  D(x): 0.8924    D(G(z)): 0.4437 / 0.0362
[1/5][1300/1583]        Loss_D: 0.6208  Loss_G: 4.0644  D(x): 0.8923    D(G(z)): 0.3534 / 0.0285
[1/5][1350/1583]        Loss_D: 0.5547  Loss_G: 2.7381  D(x): 0.7136    D(G(z)): 0.1350 / 0.0962
[1/5][1400/1583]        Loss_D: 0.6572  Loss_G: 1.7537  D(x): 0.6746    D(G(z)): 0.1582 / 0.2278
[1/5][1450/1583]        Loss_D: 0.6838  Loss_G: 1.6344  D(x): 0.5980    D(G(z)): 0.0645 / 0.2483
[1/5][1500/1583]        Loss_D: 0.5892  Loss_G: 4.7122  D(x): 0.9430    D(G(z)): 0.3709 / 0.0138
[1/5][1550/1583]        Loss_D: 0.4792  Loss_G: 2.5298  D(x): 0.7543    D(G(z)): 0.1329 / 0.1141
[2/5][0/1583]   Loss_D: 0.5904  Loss_G: 3.1206  D(x): 0.8520    D(G(z)): 0.3055 / 0.0614
[2/5][50/1583]  Loss_D: 0.5367  Loss_G: 1.9870  D(x): 0.6923    D(G(z)): 0.1047 / 0.1818
[2/5][100/1583] Loss_D: 0.8669  Loss_G: 1.2615  D(x): 0.5328    D(G(z)): 0.1086 / 0.3520
[2/5][150/1583] Loss_D: 0.6199  Loss_G: 2.2105  D(x): 0.6467    D(G(z)): 0.0964 / 0.1488
[2/5][200/1583] Loss_D: 1.2154  Loss_G: 1.6290  D(x): 0.4046    D(G(z)): 0.0761 / 0.2653
[2/5][250/1583] Loss_D: 0.6157  Loss_G: 2.1240  D(x): 0.6744    D(G(z)): 0.1374 / 0.1637
[2/5][300/1583] Loss_D: 0.5913  Loss_G: 2.8705  D(x): 0.8179    D(G(z)): 0.2770 / 0.0791
[2/5][350/1583] Loss_D: 0.6857  Loss_G: 2.4697  D(x): 0.7082    D(G(z)): 0.2123 / 0.1226
[2/5][400/1583] Loss_D: 0.7663  Loss_G: 1.8449  D(x): 0.6793    D(G(z)): 0.2401 / 0.1961
[2/5][450/1583] Loss_D: 2.4996  Loss_G: 0.3235  D(x): 0.1284    D(G(z)): 0.0044 / 0.7800
[2/5][500/1583] Loss_D: 0.7485  Loss_G: 1.8201  D(x): 0.6402    D(G(z)): 0.1848 / 0.1964
[2/5][550/1583] Loss_D: 0.7465  Loss_G: 4.1430  D(x): 0.8827    D(G(z)): 0.4152 / 0.0247
[2/5][600/1583] Loss_D: 0.5524  Loss_G: 2.3417  D(x): 0.7176    D(G(z)): 0.1467 / 0.1343
[2/5][650/1583] Loss_D: 0.8758  Loss_G: 3.0139  D(x): 0.8958    D(G(z)): 0.4540 / 0.0774
[2/5][700/1583] Loss_D: 0.7126  Loss_G: 3.1231  D(x): 0.8119    D(G(z)): 0.3488 / 0.0593
[2/5][750/1583] Loss_D: 1.0471  Loss_G: 4.4106  D(x): 0.8977    D(G(z)): 0.5552 / 0.0161
[2/5][800/1583] Loss_D: 0.6491  Loss_G: 3.6195  D(x): 0.8855    D(G(z)): 0.3686 / 0.0360
[2/5][850/1583] Loss_D: 0.5847  Loss_G: 2.5494  D(x): 0.7037    D(G(z)): 0.1575 / 0.1056
[2/5][900/1583] Loss_D: 0.5743  Loss_G: 2.8365  D(x): 0.7724    D(G(z)): 0.2197 / 0.0797
[2/5][950/1583] Loss_D: 0.8805  Loss_G: 5.0843  D(x): 0.9232    D(G(z)): 0.4993 / 0.0152
[2/5][1000/1583]        Loss_D: 0.5223  Loss_G: 1.5597  D(x): 0.7171    D(G(z)): 0.1324 / 0.2597
[2/5][1050/1583]        Loss_D: 0.5691  Loss_G: 1.6808  D(x): 0.6826    D(G(z)): 0.1094 / 0.2361
[2/5][1100/1583]        Loss_D: 0.9202  Loss_G: 2.5241  D(x): 0.7229    D(G(z)): 0.3759 / 0.1065
[2/5][1150/1583]        Loss_D: 0.5358  Loss_G: 2.8091  D(x): 0.8339    D(G(z)): 0.2697 / 0.0769
[2/5][1200/1583]        Loss_D: 1.0419  Loss_G: 1.0625  D(x): 0.4409    D(G(z)): 0.0617 / 0.4066
[2/5][1250/1583]        Loss_D: 1.1722  Loss_G: 1.1264  D(x): 0.4239    D(G(z)): 0.0707 / 0.3723
[2/5][1300/1583]        Loss_D: 0.6667  Loss_G: 0.7631  D(x): 0.6007    D(G(z)): 0.0807 / 0.5147
[2/5][1350/1583]        Loss_D: 0.5174  Loss_G: 2.9906  D(x): 0.8329    D(G(z)): 0.2571 / 0.0664
[2/5][1400/1583]        Loss_D: 0.7723  Loss_G: 1.6059  D(x): 0.5522    D(G(z)): 0.0609 / 0.2519
[2/5][1450/1583]        Loss_D: 0.5614  Loss_G: 1.8621  D(x): 0.6704    D(G(z)): 0.1028 / 0.1899
[2/5][1500/1583]        Loss_D: 0.7175  Loss_G: 2.3705  D(x): 0.7208    D(G(z)): 0.2681 / 0.1261
[2/5][1550/1583]        Loss_D: 0.7763  Loss_G: 3.9854  D(x): 0.7579    D(G(z)): 0.3363 / 0.0269
[3/5][0/1583]   Loss_D: 0.7060  Loss_G: 3.1357  D(x): 0.8826    D(G(z)): 0.3991 / 0.0608
[3/5][50/1583]  Loss_D: 0.7001  Loss_G: 1.8399  D(x): 0.6205    D(G(z)): 0.1307 / 0.2050
[3/5][100/1583] Loss_D: 0.6360  Loss_G: 2.4409  D(x): 0.7582    D(G(z)): 0.2572 / 0.1131
[3/5][150/1583] Loss_D: 0.5670  Loss_G: 2.6579  D(x): 0.7605    D(G(z)): 0.2156 / 0.1024
[3/5][200/1583] Loss_D: 0.7861  Loss_G: 1.6922  D(x): 0.5788    D(G(z)): 0.1118 / 0.2417
[3/5][250/1583] Loss_D: 1.6397  Loss_G: 4.1940  D(x): 0.9722    D(G(z)): 0.7472 / 0.0208
[3/5][300/1583] Loss_D: 0.5237  Loss_G: 2.3733  D(x): 0.8396    D(G(z)): 0.2561 / 0.1244
[3/5][350/1583] Loss_D: 2.1766  Loss_G: 4.2685  D(x): 0.9679    D(G(z)): 0.8276 / 0.0245
[3/5][400/1583] Loss_D: 0.5153  Loss_G: 2.7309  D(x): 0.8185    D(G(z)): 0.2429 / 0.0838
[3/5][450/1583] Loss_D: 0.8268  Loss_G: 1.4034  D(x): 0.5334    D(G(z)): 0.0975 / 0.3035
[3/5][500/1583] Loss_D: 0.6022  Loss_G: 2.8113  D(x): 0.8246    D(G(z)): 0.3081 / 0.0741
[3/5][550/1583] Loss_D: 0.6876  Loss_G: 3.4770  D(x): 0.8769    D(G(z)): 0.3842 / 0.0424
[3/5][600/1583] Loss_D: 0.5468  Loss_G: 2.2875  D(x): 0.7716    D(G(z)): 0.2153 / 0.1260
[3/5][650/1583] Loss_D: 0.6838  Loss_G: 1.6692  D(x): 0.7002    D(G(z)): 0.2266 / 0.2235
[3/5][700/1583] Loss_D: 0.7285  Loss_G: 1.8428  D(x): 0.6887    D(G(z)): 0.2422 / 0.1950
[3/5][750/1583] Loss_D: 0.5003  Loss_G: 2.4744  D(x): 0.7993    D(G(z)): 0.2120 / 0.1048
[3/5][800/1583] Loss_D: 0.9057  Loss_G: 4.0852  D(x): 0.8667    D(G(z)): 0.4780 / 0.0253
[3/5][850/1583] Loss_D: 0.6373  Loss_G: 2.5060  D(x): 0.7095    D(G(z)): 0.2056 / 0.1019
[3/5][900/1583] Loss_D: 0.3341  Loss_G: 2.8532  D(x): 0.8151    D(G(z)): 0.1038 / 0.0801
[3/5][950/1583] Loss_D: 1.1489  Loss_G: 1.0212  D(x): 0.4311    D(G(z)): 0.1474 / 0.3983
[3/5][1000/1583]        Loss_D: 0.6599  Loss_G: 1.6092  D(x): 0.5938    D(G(z)): 0.0665 / 0.2463
[3/5][1050/1583]        Loss_D: 1.0521  Loss_G: 1.7065  D(x): 0.6596    D(G(z)): 0.3813 / 0.2446
[3/5][1100/1583]        Loss_D: 0.7600  Loss_G: 4.2516  D(x): 0.8864    D(G(z)): 0.4245 / 0.0211
[3/5][1150/1583]        Loss_D: 0.5433  Loss_G: 1.9638  D(x): 0.7442    D(G(z)): 0.1832 / 0.1752
[3/5][1200/1583]        Loss_D: 1.0476  Loss_G: 3.4951  D(x): 0.8485    D(G(z)): 0.5330 / 0.0437
[3/5][1250/1583]        Loss_D: 0.4937  Loss_G: 1.7732  D(x): 0.7623    D(G(z)): 0.1674 / 0.2067
[3/5][1300/1583]        Loss_D: 0.6659  Loss_G: 3.0364  D(x): 0.8329    D(G(z)): 0.3464 / 0.0643
[3/5][1350/1583]        Loss_D: 0.5310  Loss_G: 1.9677  D(x): 0.8048    D(G(z)): 0.2339 / 0.1716
[3/5][1400/1583]        Loss_D: 0.5985  Loss_G: 2.8069  D(x): 0.8378    D(G(z)): 0.3057 / 0.0792
[3/5][1450/1583]        Loss_D: 0.5504  Loss_G: 2.4549  D(x): 0.7730    D(G(z)): 0.2165 / 0.1102
[3/5][1500/1583]        Loss_D: 0.7319  Loss_G: 3.3080  D(x): 0.8730    D(G(z)): 0.4114 / 0.0460
[3/5][1550/1583]        Loss_D: 0.5684  Loss_G: 2.4365  D(x): 0.7602    D(G(z)): 0.2184 / 0.1095
[4/5][0/1583]   Loss_D: 0.6895  Loss_G: 2.5095  D(x): 0.8091    D(G(z)): 0.3500 / 0.1000
[4/5][50/1583]  Loss_D: 0.6066  Loss_G: 3.2294  D(x): 0.8720    D(G(z)): 0.3434 / 0.0533
[4/5][100/1583] Loss_D: 0.8148  Loss_G: 1.2435  D(x): 0.5607    D(G(z)): 0.1286 / 0.3332
[4/5][150/1583] Loss_D: 0.5761  Loss_G: 1.5608  D(x): 0.6560    D(G(z)): 0.0854 / 0.2631
[4/5][200/1583] Loss_D: 0.6059  Loss_G: 2.4068  D(x): 0.8255    D(G(z)): 0.3035 / 0.1078
[4/5][250/1583] Loss_D: 0.4133  Loss_G: 2.4450  D(x): 0.7319    D(G(z)): 0.0736 / 0.1076
[4/5][300/1583] Loss_D: 0.8157  Loss_G: 2.0052  D(x): 0.6484    D(G(z)): 0.2162 / 0.1775
[4/5][350/1583] Loss_D: 0.6577  Loss_G: 2.7243  D(x): 0.7826    D(G(z)): 0.3012 / 0.0839
[4/5][400/1583] Loss_D: 0.5437  Loss_G: 2.9919  D(x): 0.8378    D(G(z)): 0.2778 / 0.0667
[4/5][450/1583] Loss_D: 0.5760  Loss_G: 1.8268  D(x): 0.7353    D(G(z)): 0.1975 / 0.2002
[4/5][500/1583] Loss_D: 1.8671  Loss_G: 2.1033  D(x): 0.6212    D(G(z)): 0.6070 / 0.1775
[4/5][550/1583] Loss_D: 0.5170  Loss_G: 2.3967  D(x): 0.7297    D(G(z)): 0.1405 / 0.1259
[4/5][600/1583] Loss_D: 0.8346  Loss_G: 3.0076  D(x): 0.8999    D(G(z)): 0.4764 / 0.0664
[4/5][650/1583] Loss_D: 1.0365  Loss_G: 0.4779  D(x): 0.4517    D(G(z)): 0.0811 / 0.6576
[4/5][700/1583] Loss_D: 0.6032  Loss_G: 2.8317  D(x): 0.8016    D(G(z)): 0.2860 / 0.0754
[4/5][750/1583] Loss_D: 0.8192  Loss_G: 3.8873  D(x): 0.9060    D(G(z)): 0.4613 / 0.0291
[4/5][800/1583] Loss_D: 0.8764  Loss_G: 1.4569  D(x): 0.4994    D(G(z)): 0.0480 / 0.2974
[4/5][850/1583] Loss_D: 0.6431  Loss_G: 3.3888  D(x): 0.8812    D(G(z)): 0.3722 / 0.0445
[4/5][900/1583] Loss_D: 1.0377  Loss_G: 0.7569  D(x): 0.4218    D(G(z)): 0.0400 / 0.5229
[4/5][950/1583] Loss_D: 0.7470  Loss_G: 1.6481  D(x): 0.5752    D(G(z)): 0.0956 / 0.2336
[4/5][1000/1583]        Loss_D: 0.6285  Loss_G: 3.4266  D(x): 0.8209    D(G(z)): 0.3082 / 0.0444
[4/5][1050/1583]        Loss_D: 3.3041  Loss_G: 0.3051  D(x): 0.0697    D(G(z)): 0.0096 / 0.7713
[4/5][1100/1583]        Loss_D: 0.9699  Loss_G: 4.2771  D(x): 0.8788    D(G(z)): 0.5109 / 0.0197
[4/5][1150/1583]        Loss_D: 0.7104  Loss_G: 1.4323  D(x): 0.6035    D(G(z)): 0.1188 / 0.2866
[4/5][1200/1583]        Loss_D: 1.0160  Loss_G: 4.1287  D(x): 0.9435    D(G(z)): 0.5681 / 0.0213
[4/5][1250/1583]        Loss_D: 0.7865  Loss_G: 1.6221  D(x): 0.6140    D(G(z)): 0.1791 / 0.2309
[4/5][1300/1583]        Loss_D: 0.6572  Loss_G: 1.4624  D(x): 0.5997    D(G(z)): 0.0739 / 0.2720
[4/5][1350/1583]        Loss_D: 0.6109  Loss_G: 2.0715  D(x): 0.6398    D(G(z)): 0.0932 / 0.1798
[4/5][1400/1583]        Loss_D: 1.1801  Loss_G: 4.6358  D(x): 0.9242    D(G(z)): 0.6060 / 0.0150
[4/5][1450/1583]        Loss_D: 0.7747  Loss_G: 3.1557  D(x): 0.8952    D(G(z)): 0.4376 / 0.0586
[4/5][1500/1583]        Loss_D: 0.6088  Loss_G: 2.1813  D(x): 0.6904    D(G(z)): 0.1688 / 0.1416
[4/5][1550/1583]        Loss_D: 0.7178  Loss_G: 2.8539  D(x): 0.8350    D(G(z)): 0.3687 / 0.0780

Result

Visualize the training loss of each model.

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

Visualize the images the generator produced over the course of training, using the snapshots saved on fixed_noise.

#%%capture
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())

Finally, we compare real images and fake images side by side, as sketched below.
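
A side-by-side grid of a real batch and the final batch of fakes, following the official tutorial's closing cell:

# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))

# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()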