AI/Machine Learning

[Machine Learning] 스태킹 (Stacking)

byunghyun23 2023. 5. 31. 01:45

스태킹(stacking)은 앙상블 학습(ensemble learning) 방법 중 하나입니다.
앙상블 학습은 같은 데이터를 기반으로 학습한 여러 모델을 비교 및 결합하여 개별적인 모델보다 성능이 더 나은 최종 모델을 만드는 것입니다.

 

스태킹은 말 그대로 여러 가지 모델을 쌓아서 학습하는 방법으로, 베이스 모델과 메타 모델로 구성됩니다.

베이스 모델이 먼저 학습한 후 메타 모델은 베이스 모델의 예측을 피처(feature) 데이터로 활용해 최종 예측을 합니다.

베이스 모델은 여러 개의 모델들을 사용합니다.

 

위 그림에서 베이스 모델들은 Level0, 메타 모델은 Level1입니다.

각 베이스 모델을 A, B, C, 메타 모델을 D라고 할 때 D의 입력 feature는 A, B, C의 출력을 feature로 사용하고 target은 A, B, C에서 사용된 target을 그대로 사용합니다.

D의 입력 feature  A의출력값, B의출력값, C의출력값이며 3차원 데이터라고 할 수 있습니다.

 

예를 들어 현재 사용된 데이터가 (feature x,  target y)이고,

A, B, C의 출력이 (y1, y2, y3)라고 하면

D의 입력은 ((y1, y2, y3), y)가 됩니다.

 

실습은 sklearn StackingClassifier를 이용하여 유방암 데이터를 분류해 보겠습니다.

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from sklearn import svm
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import StackingClassifier

from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report


# 데이터 불러오기
raw_breast_cancer = datasets.load_breast_cancer()

# 피쳐, 타겟 데이터 지정
X = raw_breast_cancer.data
y = raw_breast_cancer.target

# 트레이닝/테스트 데이터 분할
X_tn, X_te, y_tn, y_te = train_test_split(X, y, random_state=0)

# 데이터 표준화
std_scale = StandardScaler()
std_scale.fit(X_tn)
X_tn_std = std_scale.transform(X_tn)
X_te_std = std_scale.transform(X_te)

# 스태킹 학습
clf1 = svm.SVC(kernel='linear', random_state=1)
clf2 = GaussianNB()

clf_stkg = StackingClassifier(
            estimators=[
                ('svm', clf1),
                ('gnb', clf2)
            ],
            final_estimator=LogisticRegression())
clf_stkg.fit(X_tn_std, y_tn)

# 예측
pred_stkg = clf_stkg.predict(X_te_std)
print(pred_stkg)

# 정확도
accuracy = accuracy_score(y_te, pred_stkg)
print(accuracy)

# confusion matrix 확인
conf_matrix = confusion_matrix(y_te, pred_stkg)
print(conf_matrix)

# 분류 레포트 확인
class_report = classification_report(y_te, pred_stkg)
print(class_report)