AI/Machine Learning

[Machine Learning] 그리드 서치

byunghyun23 2022. 9. 13. 17:20

그리드 서치(grid search)는 학습 과정에서 하이퍼 파라미터의 후보군을 정하여 학습 후 모델 성능을 비교하여 최적의 하이퍼 파라미터를 선정하는 방법입니다.

예를 들어, KNN(K-Nearest Neighbor) 알고리즘에서 사용할 수 있는 k 값의 후보는 여러가지가 있을 수 있습니다.

k 값의 후보를 정해 놓고 모든 후보에 대한 모델 생성 및 학습 후, 성능을 비교하는 방법입니다.

from sklearn import datasets
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

raw_iris = datasets.load_iris()

X = raw_iris.data
y = raw_iris.target

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=7)

std_scaler = StandardScaler()
X_train = std_scaler.fit_transform(X_train)
X_test = std_scaler.transform(X_test)

best_acc = 0

final_k = None

for k in [1, 2, 3, 4, 5, 6, 7, 8, 9]:
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(X_train, y_train)

    y_pred = knn.predict(X_test)
    acc = accuracy_score(y_test, y_pred)

    if acc > best_acc:
        best_acc = acc
        final_k = k
    print('k:', k, 'acc:', acc)

print('final_k:', final_k)
print('best_acc:', best_acc)
k: 1 acc: 0.9210526315789473
k: 2 acc: 0.8947368421052632
k: 3 acc: 0.868421052631579
k: 4 acc: 0.8947368421052632
k: 5 acc: 0.868421052631579
k: 6 acc: 0.8947368421052632
k: 7 acc: 0.9473684210526315
k: 8 acc: 0.9473684210526315
k: 9 acc: 0.9210526315789473
final_k: 7
best_acc: 0.9473684210526315

k=7일 때 (k=8과 동일), 가장 좋은 성능을 보였습니다.