AI/Deep Learning

[Deep Learning] 적대적 생성 신경망 (GAN)

byunghyun23 2023. 6. 10. 02:41

적대적 생성 신경망(GAN: Generative Adversarial Networks)는 생성자(generator)와 판별자(discriminator)라고 하는 두 네트워크가 서로 경쟁함으로써 학습하는 네트워크입니다.

판별자는 데이터셋으로부터 생성되었는지 아니면 가짜로 생성되었는지를 확인합니다.

반면에 생성자의 임무는 판별자가 진짜와 가짜를 구분할 수 없도록 진짜 같은 가짜 데이터를 생성하는 것입니다.

즉, 생성자가 입력과 구별할 수 없는 가짜 데이터를 만들도록 학습하는 것이 GAN입니다.

 

GAN을 제안한 Ian Goodfellow은 경찰과 도둑을 예로 들어 적대적 신경망 학습을 설명했습니다.

도둑은 계속 위조 지폐를 생성하고, 경찰은 위조 지폐와 실제 지폐를 구별하는 것입니다.

 

 

즉, GAN의 아키텍처는 다음과 같습니다.

ref: 선형대수와 통계학으로 배우는 머신러닝 with 파이썬, 장철원, p564

 

그렇다면 생성자는 가짜 데이터를 어떻게 생성할까요?

먼저 생성자는 데이터셋과 관련없는 입력과 동일한 랜덤 데이터를 생성합니다.

그리고 판별자는 진짜 데이터와 생성자가 만들어낸 가짜 데이터를 입력으로 받아 어떤 것이 진짜이고 가짜인지 판별합니다.

학습이 진행되면서 생성자는 진짜 데이터와 비슷한 데이터를 만들어낼 것이고, 판별자가 진짜와 가짜를 구분하지 못할 때 네트워크가 잘 학습된 것이라고 볼 수 있습니다.

여기서 생성자는 오직 판별자를 속였는지 여부만으로 학습을 진행합니다. 즉, 판별자를 속였을 때 얻은 정보를 기반으로 진짜 데이터와 비슷한 데이터를 생성할 수 있도록 하는 것입니다.

 

생성자와 판별자의 손실 함수는 다음과 같습니다.

 

먼저 좌항 D(x)는 판별자가 진짜 데이터를 입력으로 받아 진짜로 판별하는 확률입니다. 따라서 판별자 입장에서는 좌항은  1에 가까울수록 좋습니다.

우항 D(G(z))는 생성자가 만들어낸 가짜 데이터를 판별자가 판별하는 확률입니다. 판별자는 가짜 데이터를 0에 가까운 확률로 판별하려 하고, log(1-x)는 0에서 1로 갈수록 값이 작아지기 때문에 판별자 입장에서는 우항이 작을수록 좋습니다. 

반대로 생성자는 D(G(z))가 1에 가까울수록 가짜 데이터를 진짜로 판별했다는 의미이기 때문에 우항이 클수록 좋습니다.

 

정리하면

생성자는 가짜 데이터를 판별자가 잘 분류하지 못하도록 매우 정교한 가짜 데이터를 만들어내는 것을 목표로 합니다.
즉, D(G(z))가 1이 되도록 학습하는 것이 생성자의 목표입니다.
판별자는 진짜 데이터는 진짜로, 가짜 데이터는 가짜라고 정확히 분류하는 것을 목표로 합니다.
즉, D(x)는 1이 되도록, D(G(z))는 0이 되도록 학습하는 것이 D의 목표입니다.

 

실습은 이곳을 참조해 주세요.