AI/Deep Learning

[Deep Learning] 메타 러닝 (Meta Learning)

byunghyun23 2023. 12. 1. 03:11

일반적으로 딥러닝은 많은 데이터를 필요로 합니다. 고양이와 강아지를 분류하는 모델은 두 클래스에 대하여 충분한 데이터를 학습해야 입력 이미지가 무엇인지 분류할 수 있습니다.

 

사람은 데이터를 몇 번만 보고도 빠르게 학습할 수 있으나, 딥러닝 모델은 수많은 데이터를 기반으로 학습  시행착오를 겪어야 비로소 어떤 개념을 학습할 수 있습니다.

"딥러닝 모델도 사람처럼 '적은 데이터'만으로도 '빠르게' 학습할 수는 없을까?"라는 질문에 대한 답변으로 제안된 방법이 바로 메타 러닝(meta learning)입니다.

 

메타 러닝은 새로운 개념 또는 태스크를 빠르게 학습하기 위해 '학습을 학습(learning to learn)'하는 방법입니다.

메타 러닝의 핵심 아이디어는 모델이 단순히 해당 데이터를 학습하는 것뿐만 아니라 자신의 학습 능력을 스스로 향상시킨다는 것입니다.

 

이와 유사한 개념으로는 퓨샷 러닝(few-shot learning)이 있습니다. 고양이와 강아지 사진을 각각 3장씩만 보고, 새로운 사진이 고양이인지 강아지인지 분류할 수 있도록 하는 학습 방법을 말합니다.(적은 데이터로 학습)

또한 해당 예시에서는  2개의 클래스와 3개의 데이터를 학습했기 때문에 2-way 3-shot 방식이라고 부릅니다.

 

이러한 퓨삿 러닝을 잘 하기 위한 방법 중 하나가 바로 메타 러닝입니다. (+ 퓨샷 러닝을 잘 하기 위한 방법에는 Transfer Learning도 있습니다.)

 

메타 러닝은 크게 모델 기반(Model-based), 메트릭 기반(Metric-based), 최적화 기반(Optimization-based)으로 나눌 수 있습니다. 본 포스팅은 회귀, 분류, 강화 학습 등 다양한 머신 러닝 알고리즘에 적용 가능한 최적화 기반의 MAML(Model-Agnostic Meta-Learning)만 다루겠습니다.

 

MAML은 'Model-Agnostic (모델과 상관없이)'하게, 즉, 대부분의 머신 러닝 모델(지도 학습, 강화 학습 등)에 적용할 수 있습니다.

 

 

 

학습은 Parameter θ를 찾아나가는 방식으로 Gradient Descent를 진행합니다.

θ가 가리키는 곳이 Task에 대한 최적은 아니지만, Task를 빠르게 Adaptation 할 수 있는 곳이기 때문에 Parameter θ가 화살표가 가리키는 점으로 이동합니다.

이후, 얻은 θ에서 새로운 Task에 맞는 최적의 Parameter θ∗를 찾아가는 방식으로 Gradient Descent를 진행합니다.

 

지도 학습에 적용할 경우, Loss 함수만 MSE로 변경하면 됩니다.

 

 

적은 데이터로 최선의 결과를 얻기 위한 방법인 Meta Learning, Contrastive Learning, Transfer Learning 등은 알아두면 굉장히 도움이 될 것 같습니다.

 

코드 구현: https://github.com/byunghyun23/meta-learning

 

GitHub - byunghyun23/meta-learning: Meta Learning with MAML

Meta Learning with MAML. Contribute to byunghyun23/meta-learning development by creating an account on GitHub.

github.com

※ torch, torchmeta 버전을 확인해주세요. (torch=1.7.1+cu110, torchmeta==1.7.0)

torch 1.7, torchmeta 1.8이면 실행 중 오류 발생합니다.