Staying Wary of What I Know


<DAFIT> 10 Introduction to Classification with PyTorch - 01 Preparing Libraries and Preprocessing Data

양갱맨 2019. 12. 18. 00:33

Fashion-MNIST dataset

  • TF

      from tensorflow import keras
      (x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
  • PyTorch

      from torchvision import datasets
      datasets.FashionMNIST('~/.pytorch/F_MNIST_data/',...)
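from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor(),                          # image -> float tensor in [0, 1]
    transforms.Normalize(mean=(0.5,), std=(0.5,)),  # (x - 0.5) / 0.5 -> [-1, 1]
])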

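ToTensor(): converts the image to a tensor, reordering it from (H, W, C) to (C, H, W) and rescaling pixel values from 0 to 255 down to the range 0 to 1.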

Normalize(mean=0.5, std=0.5): computes (x - mean) / std, i.e. (x - 0.5) / 0.5, which maps values in [0, 1] to [-1, 1].
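To make the arithmetic concrete, here is a minimal NumPy sketch of what ToTensor() followed by Normalize(0.5, 0.5) computes for one grayscale image. It assumes uint8 input (the case in which ToTensor rescales by 255) and is an illustration, not torchvision's actual implementation; the helper name `to_tensor_and_normalize` is hypothetical.

```python
import numpy as np

def to_tensor_and_normalize(img_uint8, mean=0.5, std=0.5):
    """Hypothetical helper mimicking ToTensor() + Normalize(mean, std) for one grayscale image."""
    # ToTensor: uint8 (H, W) in [0, 255] -> float32 (1, H, W) in [0, 1]
    x = img_uint8.astype(np.float32) / 255.0
    x = x[np.newaxis, :, :]  # add the channel dimension (C = 1)
    # Normalize: (x - mean) / std maps [0, 1] to [-1, 1] when mean = std = 0.5
    return (x - mean) / std

img = np.array([[0, 128, 255]], dtype=np.uint8)  # a tiny 1x3 "image"
out = to_tensor_and_normalize(img)
print(out.shape)  # (1, 1, 3)
print(out)        # black pixel -> -1.0, white pixel -> 1.0, mid-gray close to 0
```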

Because Fashion-MNIST images are grayscale (a single channel), one value of 0.5 is given for the mean and one for the std.

For images with 3 color channels, pass one value per channel instead: mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5).
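For the 3-channel case, the same formula is applied to each channel separately. The sketch below, again plain NumPy with a hypothetical `normalize_rgb` helper and a synthetic image, shows how one mean/std value per channel is broadcast over a (C, H, W) array; in torchvision this corresponds to `transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`.

```python
import numpy as np

def normalize_rgb(img_hwc_uint8, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Hypothetical helper mimicking ToTensor() + per-channel Normalize for an RGB image."""
    # ToTensor: (H, W, C) uint8 -> (C, H, W) float32 in [0, 1]
    x = img_hwc_uint8.astype(np.float32).transpose(2, 0, 1) / 255.0
    # Reshape mean/std to (C, 1, 1) so each channel is shifted and scaled by its own value
    m = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    s = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    return (x - m) / s

rgb = np.zeros((2, 2, 3), dtype=np.uint8)
rgb[..., 0] = 255  # synthetic pure-red 2x2 image
out = normalize_rgb(rgb)
print(out.shape)     # (3, 2, 2)
print(out[:, 0, 0])  # red channel 1.0, green and blue -1.0
```

The number of mean/std entries has to match the number of channels; with a single-channel image, a 1-element tuple such as (0.5,) plays the same role.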


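import torch
from torchvision import datasets, transforms

import cv2
from google.colab.patches import cv2_imshow  # Colab-only replacement for cv2.imshow
import numpy as np

trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # one mean/std value for the single grayscale channel
])

# train=True (the default) gives the 60,000-image training split, train=False the 10,000-image test split
train_dataset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, transform=trans)
test_dataset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, transform=trans, train=False)

# Check the data: show the first 10 training images.
# .data holds the raw 28x28 uint8 images, so the transform above is not applied here.
for i in range(10):
    cv2_imshow(np.array(train_dataset.data[i]))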
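Because `train_dataset.data` holds the raw images, the loop above does not show what the transform produced. Indexing the dataset with `train_dataset[i]` instead returns a normalized (1, 28, 28) tensor and an integer label. The sketch below assumes mean = std = 0.5 as above and uses hypothetical helper names (`denormalize_to_uint8`, `make_strip`): it undoes the normalization for display, maps labels to the Fashion-MNIST class names (the same order as torchvision's `FashionMNIST.classes`), and tiles several images into one array so a single `cv2_imshow` call can show them. The demo at the bottom uses a synthetic image so it runs without downloading the dataset.

```python
import numpy as np

# Fashion-MNIST label index -> class name
CLASSES = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

def denormalize_to_uint8(x, mean=0.5, std=0.5):
    """Undo Normalize(mean, std) and ToTensor() scaling: (1, H, W) in [-1, 1] -> (H, W) uint8."""
    x = np.asarray(x, dtype=np.float32) * std + mean  # back to [0, 1]
    x = np.clip(x, 0.0, 1.0) * 255.0                  # back to [0, 255]
    return np.rint(x).astype(np.uint8).squeeze(0)     # drop the channel dimension

def make_strip(images):
    """Place several (H, W) uint8 images side by side in one array."""
    return np.hstack(images)

# Usage in Colab with the datasets defined above (commented out so this
# sketch runs without downloading Fashion-MNIST):
# images = [denormalize_to_uint8(train_dataset[i][0].numpy()) for i in range(10)]
# labels = [CLASSES[train_dataset[i][1]] for i in range(10)]
# cv2_imshow(make_strip(images))
# print(labels)

# Synthetic demo: a normalized image whose left half is black and right half white
fake = np.full((1, 28, 28), -1.0, dtype=np.float32)
fake[0, :, 14:] = 1.0
img = denormalize_to_uint8(fake)
strip = make_strip([img, img])
print(img.shape, strip.shape)  # (28, 28) (28, 56)
print(CLASSES[9])              # Ankle boot
```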