AI/Machine Learning

[Machine Learning] K-최근접 이웃 알고리즘 (KNN)

byunghyun23 2022. 9. 28. 15:11

k-최근접 이웃 알고리즘(KNN)은 비교 대상이 되는 데이터 주변에 가까이 존재하는 k개의 데이터와 비교해 가장 가까운 데이터로 판별하는 방법입니다.

예를 들어 위 그림에서 판별해야 할 데이터가 빨간색 삼각형 이라고 할 때, k=1로 설정할 경우 삼각형은 초록색 원으로 판별됩니다.

만약 k=3으로 설정하면 노란색 사각형 개수가 초록색 원의 개수보다 많기 때문에 삼각형은 사각형으로 판별됩니다.

다시 말해, k는 판별해야 할 데이터로부터 가장 가까운 데이터의 개수를 의미하고 이를 기준으로 판별하는 방법입니다.

 

만약 위 예시처럼 분류가 아닌 회귀(연속형)라면 가장 가까운 k개의 데이터 평균값으로 예측할 수 있습니다.

 

이러한 KNN 알고리즘은 다른 머신러닝 알고리즘과 달리 학습 데이터 전체를 메모리에 로드 후 테스트 데이터가 입력으로 들어왔을 때, 즉시 클래스 분류 또는 연속형 값을 예측합니다.

따라서 공간 비용은 크지만 시간 비용은 적다는 점이 특징입니다.

 

사이킷런의 KNeighborsClassifier 클래스를 사용하여 iris 꽃 품종을 분류해보겠습니다.

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report

raw_iris = datasets.load_iris()

X = raw_iris.data
y = raw_iris.target
# X.shape, y.shape: (150, 4) (150,)

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=144)
# X_train.shape, X_test.shape, y_train.shape, y_test.shape: (112, 4) (38, 4) (112,) (38,)

std_scaler = StandardScaler()
X_train = std_scaler.fit_transform(X_train)
X_test = std_scaler.transform(X_test)

clf_knn = KNeighborsClassifier(n_neighbors=3)
clf_knn.fit(X_train, y_train)

clf_knn_pred = clf_knn.predict(X_test)
# [1 2 0 2 2 0 1 2 2 0 1 1 2 0 1 0 2 1 1 2 0 0 1 1 1 2 2 1 2 1 1 0 2 1 2 0 2 1]

accuracy = accuracy_score(y_test, clf_knn_pred)
# 1.0

conf_matrix = confusion_matrix(y_test, clf_knn_pred)
# [[ 9  0  0]
#  [ 0 15  0]
#  [ 0  0 14]]

class_report = classification_report(y_test, clf_knn_pred)
#               precision    recall  f1-score   support
# 
#            0       1.00      1.00      1.00         9
#            1       1.00      1.00      1.00        15
#            2       1.00      1.00      1.00        14
# 
#     accuracy                           1.00        38
#    macro avg       1.00      1.00      1.00        38
# weighted avg       1.00      1.00      1.00        38

테스트 데이터 38개를 각 클래스 별로(Class 0-9개, Class 1-15개, Class 2-14개)로 정확하게 분류하였습니다.