앎을 경계하기

Machine Learning/scikit-learn diary

3. 사이킷런의 Estimator object와 그 설정

양갱맨 2021. 9. 28. 17:19

사이킷런의 irist dataset(붓꽃데이터)는 (샘플 수, 피쳐 수) 사이즈로 구성되어 있다.

iris는 150개의 데이터를 가지고 있고 이 각 데이터는 4개의 features로 이루어져 있다.

features : 꽃받침과 꽃잎의 길이, 너비

>>> from sklearn import datasets
>>> iris = datasets.load_iris()
>>> data = iris.data
>>> data.shape
(150, 4)

지난 번에 함께 보았던 digits 데이터셋은 어떨까?
digits데이터는 iris와 다르게 이미지 데이터였다.

digits = datasets.load_digits()
digits.images.shape
(1797, 8, 8)

import matplotlib.pyplot as plt
plt.imshow(digits.images[-1], cmap=plt.cm.gray_r)

사이킷런에서 digits dataset을 사용할 때는 8x8 이미지를 64길이의 벡터로 변환하여 사용한다.

data = digits.images.reshape((digits.images.shape[0], -1))
data.shape
(1797,64)

사이킷런에서 estimator라고 하는 것은 classfication, regression, clustering 등의 알고리즘이나 주요 피쳐를 추출, 필터링하는 transformer이다.
estimator라는 모델이 있는게 아님!

모든 estimator들은 fit메소드를 사용해서 데이터셋을 적용할 수 있다.
estimator의 파라미터는 다음처럼 설정 및 확인이 가능하다.

estimator = Estimator(param1=1, param2=2)
estimator.param1

데이터셋에 학습이 완료된 파라미터를 얻으려면 estimator.estimated_param_을 사용하면 된다.