k-최근접 이웃 알고리즘(KNN)은 비교 대상이 되는 데이터 주변에 가까이 존재하는 k개의 데이터와 비교해 가장 가까운 데이터로 판별하는 방법입니다.
예를 들어 위 그림에서 판별해야 할 데이터가 빨간색 삼각형 이라고 할 때, k=1로 설정할 경우 삼각형은 초록색 원으로 판별됩니다.
만약 k=3으로 설정하면 노란색 사각형 개수가 초록색 원의 개수보다 많기 때문에 삼각형은 사각형으로 판별됩니다.
다시 말해, k는 판별해야 할 데이터로부터 가장 가까운 데이터의 개수를 의미하고 이를 기준으로 판별하는 방법입니다.
만약 위 예시처럼 분류가 아닌 회귀(연속형)라면 가장 가까운 k개의 데이터 평균값으로 예측할 수 있습니다.
이러한 KNN 알고리즘은 다른 머신러닝 알고리즘과 달리 학습 데이터 전체를 메모리에 로드 후 테스트 데이터가 입력으로 들어왔을 때, 즉시 클래스 분류 또는 연속형 값을 예측합니다.
따라서 공간 비용은 크지만 시간 비용은 적다는 점이 특징입니다.
사이킷런의 KNeighborsClassifier 클래스를 사용하여 iris 꽃 품종을 분류해보겠습니다.
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
raw_iris = datasets.load_iris()
X = raw_iris.data
y = raw_iris.target
# X.shape, y.shape: (150, 4) (150,)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=144)
# X_train.shape, X_test.shape, y_train.shape, y_test.shape: (112, 4) (38, 4) (112,) (38,)
std_scaler = StandardScaler()
X_train = std_scaler.fit_transform(X_train)
X_test = std_scaler.transform(X_test)
clf_knn = KNeighborsClassifier(n_neighbors=3)
clf_knn.fit(X_train, y_train)
clf_knn_pred = clf_knn.predict(X_test)
# [1 2 0 2 2 0 1 2 2 0 1 1 2 0 1 0 2 1 1 2 0 0 1 1 1 2 2 1 2 1 1 0 2 1 2 0 2 1]
accuracy = accuracy_score(y_test, clf_knn_pred)
# 1.0
conf_matrix = confusion_matrix(y_test, clf_knn_pred)
# [[ 9 0 0]
# [ 0 15 0]
# [ 0 0 14]]
class_report = classification_report(y_test, clf_knn_pred)
# precision recall f1-score support
#
# 0 1.00 1.00 1.00 9
# 1 1.00 1.00 1.00 15
# 2 1.00 1.00 1.00 14
#
# accuracy 1.00 38
# macro avg 1.00 1.00 1.00 38
# weighted avg 1.00 1.00 1.00 38
테스트 데이터 38개를 각 클래스 별로(Class 0-9개, Class 1-15개, Class 2-14개)로 정확하게 분류하였습니다.
'AI > Machine Learning' 카테고리의 다른 글
[Machine Learning] 라쏘, 릿지, 엘라스틱넷 (Ridge, Lasso, ElasticNet) (0) | 2022.09.28 |
---|---|
[Machine Learning] 선형 회귀 (Linear Regression) (0) | 2022.09.28 |
[Machine Learning] 모델 성능 평가 - 분류, 회귀 (0) | 2022.09.27 |
[Machine Learning] 오차 행렬 (Confusion Matrix) (0) | 2022.09.27 |
[Machine Learning] 손실 함수 - 비용 함수, 목적 함수 (0) | 2022.09.27 |