본문 바로가기

deep learning

[Chapter 1]starGAN 리뷰

 

StarGAN이란

starGAN[1]은 multi-domain image-to-image translation 방법이다. 헤어, 성별, 나이 등 다양한 도메인 간 변환을 수행해야 할 때 cycleGAN[2]의 경우 각각의 변환당 generator를 생성해야 하지만, starGAN의 경우 unified architecture model로 하나의 generator로 다양한 도메인 간 변환이 수행 가능하다.

cycleGAN의 경우 머리와 성별, 머리와 나이 등 한 이미지를 다양한 도메인으로 변환을 하려는 경우, 여러 번의 변환을 거쳐야 한다. 이런 방식은 비효율적이고, 결과의 품질도 낮다. 반면에 starGAN은 한 번에 여러 도메인으로 변환을 수행할 수 있다. starGAN은 유연한 변환이 가능해 효율적이며, 품질이 우수하다.

 

StarGAN은 cross-domain model인 cycleGAN과 비슷한 구조다. 다만 domain 구분을 위해 사용하는 classification 부분이 추가되며, generator에 input은 image와 domain정보를 포함한다. cycleGAN을 이용해 다양한 도메인 간 training을 진행하려면 각 도메인 간 변환마다 generator가 필요하다. k 도메인 간 변환을 수행하려면 k(k-1) 개의 generator가 있어야 하는데, starGAN은 단 1개의 generator만드로 모든 변환이 수행 가능하다.

 

또한, mask vector를 이용해 다양한 데이터셋을 이용해 네트워크를 training이 가능하다. 이 경우, 단일 데이터셋을 이용해 training을 진행했을 때 보다 품질이 우수할 수 있다. 공통적인 feature인 global feature를 이용할 수 있기 때문이다. 특히, 사용하려는 데이터셋의 크기가 적은 경우 비슷한 task의 데이터셋을 이용해 함께 training을 진행하면 보다 좋은 결과를 얻을 수 있을 것이다.

starGAN은 main contribution을 다음과 같이 설명한다.

  • We propose StarGAN, a novel generative adversarial network that learns the mappings among multiple domains using only a single generator and a discriminator, training effectively from images of all domains.
  • We demonstrate how we can successfully learn multidomain image translation between multiple datasets by utilizing a mask vector method that enables StarGAN to control all available domain labels.
  • We provide both qualitative and quantitative results on facial attribute transfer and facial expression synthesis tasks using StarGAN, showing its superiority over baseline models.

 

StarGAN의 구성

(a)는 disctiminator가 training될 때이고, (b)(c)(d)는 generator가 training 될 때의 도식이다. StarGAN에서는 optimize를 위해 adversarial loss, domain classification loss, reconstruction loss를 이용해 objective functions를 구성

  • x: input image
  • c: target domain
  • c`: original domain

 

Adversarial Loss. real image와 구분하기 어려운 fake image를 만들기 위해 사용한다. generator는 fake image를 real image와 최대한 근사하려는 목적을 가진다. fake image를 discriminator가 real에 가깝다고 판단할수록 적은 loss를 갖는다. discriminator는 generator가 생성한 fake image를 fake라고 구분하고, real images를 real이라고 구분할수록 적은 loss를 갖는다.

starGAN에서는 training을 안정적으로 진행하여 higher quality image를 generate하기 위해 adversarial loss를 wasserstein GAN objective with gradient penalty[5]로 대체한다. 따라서 log 부분이 사라지고, gp부분이 추가된다.

 

Domain Classification Loss. generator를 이용해 target domain에서 주어진 정보의 형태로 변환되기를 바란다. 이를 위해 discriminator에 classfier 부분을 추가한다. classifier는 주어진 image의 domain정보를 예측한 결과를 feedback 하여 generator가 주어진 domain에 해당하는 generated images를 만들도록 하는 목적이다.

objective는 discriminator를 optimize하기 위한 real images의 domain을 classification 하는 부분과, generator를 optimize 하기 위해 generated images의 domain을 classification 하는 부분을 사용한다. discriminator에서 adversarial loss의 경우 real images와 generated image를 모두 이용해 loss를 계산한다. 하지만 domain classification loss의 경우 real images만 사용한다. 이는 generated images가 real images에 가까워지는 게 목적이라, training 중 분포가 real images set과 가깝게 이동하는 generated images가 아닌 real images에서 domain을 잘 예측해야 하기 때문이라 추측한다.

real image에 대한 domain classification loss 계산
fake image에 대한 domain classification loss 계산

 

Reconstruction Loss. reconstruction loss는 다음 두 가지 이유로 사용할 수 있다.

  1. starGAN에서는 image의 content는 유지한채 style만 변화시키길 원한다. 하지만 discriminator는 속일 수 있는 generated images는 content를 보존하지 않고, 완전히 다른 형상으로 변환이 수행될 수 있다. 누군지 알아보기 힘들 정도로 변환이 수행되는 것은 원하지 않기 때문에 reconstruction loss를 이용해 generated images가 input image의 domain과 함께 generator에 통과되면 input image와 같은 형상으로 복구할 수 있도록 reconstruction loss를 사용한다.
  2. mode collapse 현상을 방지해야 한다. mode collapse 현상은 생성자가 판별자를 속이는 적은 수의 샘플을 찾을 때 일어난다. 따라서 한정된 이 sample 이외에는 다른 것을 생성하지 못한다. 판별자의 가중치를 업데이트하지 않고 몇 번의 배치를 하는 동안 생성자를 훈련한다고 가정할 때 generator는 이를 항상 속이는 하나의 generated image(이를 모드(mode)라고 부른다)를 찾으려는 경향이 있고 latent space의 모든 point를 이 sample에 매핑할 수 있다. 이 말은 '손실 함수의 그레이디언트가 0에 가까운 값으로 무너진다(collapse)'는 뜻이다. 하나의 point에 속아 넘어가지 못하도록 discriminator를 다시 훈련하더라고 생성자는 판별자를 속이는 또 다른 mode를 쉽게 찾을 것이다. generator가 이미 input에 무감각해져서 다양한 output을 만들 이유가 없기 때문이다.[6]

 

Full Objective. 모든 loss를 이용해 최종적으로 generator와 discriminator를 optimize 하기 위한 objective functions는 다음과 같다.

 

 

Datasets

본 논문에서는 두 가지 데이터셋을 사용한다. celeba와 rafd이다. CelebA[7]: 얼굴 속성에 대한 202,599개의 연예인 dataset이다. image는 40개의 label로 annotated되어있다. RaFD[8]: 67명의 인물으로 수집한 표정에 대한 4,824개의 데이터셋이다. 각 인물은 세 가지 다른 시선 방향으로 8개의 얼굴 표정을 만들어 세 가지 다른 각도에서 만든다.

 

 

실험 결과

평가를 위해 사람이 직접 평가하는 Amazon Mechanical Turk (AMT)[9]를 이용한다. starGAN은 AMT 결과의 모든 부분에서 압도적인 표를 받았다. starGAN은 여러 datasets을 이용해 global feature를 활용할 수 있기 때문에 큰 격차가 발생한다. 1000개의 images를 보유한 a dataset과, 3000개의 images를 보유한 b의 dataset이 있을 때 a에 해당하는 변환으로 train 될 때 DIAT나 cycleGAN 같은 경우에는 오직 1000 training images를 이용할 수 있지만, starGAN은 mask vector를 이용해 사용 가능한 모든 dataset (a와 b)에서 4000 개의 images 모두를 사용할 수 있다. starGAN은 여러 dataset으로 global feature를 활용해 training시 높은 quaulity를 얻을 수 있다.

 

StarGAN 이외의 방법은 blurry하거나 어색하며, 원 얼굴 형태를 일부만 보존한다.

 

여러 dataset을 이용한 starGAN-JNT는 높은 시각적 quality 표정을 표현한다. 반면에 StarGAN-SNG는 표정 변화가 나타나지만 blurry 하며 배경이 gray 한 image를 생성한다. StarGAN-JNT는 low-level-task 공유로 개선하기 위해 두 dataset을 활용할 수 있다.

 

 

결론

starGAN은 multi-domain image-to-image translation을 효과적으로 수행한다. 여러 domain간 변환을 수행할 여러 개의 generator를 사용하는 [cycleGAN, discoGAN[10], pix2pix[2], cGAN[11]]과 같은 기존 방법과 달리 starGAN은 단 하나의 generator를 사용한다. 하나의 generator를 이용함에도 결과가 매우 우수하고, 유연한 확장이 가능하다. mask vector를 이용해 다양한 datasets으로 training을 진행할 수도 있는데, 이는 각 dataset의 공통된 feature를 통해 더 향상된 결과를 얻을 수 있도록 돕는다. 

 

 

참고자료:

[1]StarGAN: Unified Generative Adversarial Networks for Multi-Domain Image-to-Image Translation

[2]Image-to-Image Translation with Conditional Adversarial Networks

[3]Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks

[4]Wasserstein GAN

[5]Improved Training of Wasserstein GANs

[6]미술관에 GAN 딥러닝 실전 프로젝트

[7]Large-scale CelebFaces Attributes (CelebA) Dataset

[8]Radboud Faces Database

[9]Amazon Mechanical Turk

[10]Learning to Discover Cross-Domain Relations with Generative Adversarial Networks

[11]Conditional Generative Adversarial Nets

 

 

 

 

'deep learning' 카테고리의 다른 글

EfficientNet 리뷰  (0) 2021.02.19
Skip connection 정리  (2) 2021.02.02
[Chapter 2]starGAN 코드 레벨 분석  (0) 2021.01.28
[Style Trasfer]Instance normalization  (0) 2021.01.22
Anomaly detection  (0) 2021.01.22