So you’ve read about GANs. Maybe you’ve even trained one. Maybe you’ve tried to train one and watched as the discriminator loss went down and down and down and “BOOM”, you’ve overfitted on the training data. You print out 100 images and 50 of them are the same malformed picture of a golden retriever. You’ve gone to your professor, maybe with a tear in your eye, and proclaimed:
‘MODE COLLAPSE, WHAT TO DO?’
‘Add more data,’ your professor says. ‘And maybe take a nap.’
Fear not! The recent paper “Differentiable Augmentation for Data-Efficient GAN Training” from MIT claims to be your salvation, or at least part of it (Zhao, Liu, Lin, Zhu & Han, 2020). The authors claim their method needs less data whilst still achieving state-of-the-art results, thanks to a special kind of data augmentation called ‘differentiable’ augmentation. In this blog post, I’ll spill the tea on this paper (which, if you can’t already tell, I’m very excited about). I’ll explain how the paper claims to improve GAN training and whether it actually works. So grab a cuppa and a notebook, and pray to the GAN gods that this helps vanquish the old monster of ‘mode collapse.’
The Problem with GANs: A Very Quick Recap
So here’s a figure from the paper that shows what’s wrong. One of the first and most important concepts you’ll learn about in machine learning is ‘overfitting,’ and that is essentially what’s going on here. The discriminator keeps improving on the training data but performs terribly on validation data, because it has begun to ‘memorize’ the training set. This does not necessarily lead to mode collapse, though it often does. If you observe mode collapse, it can be a sign that the discriminator has overfitted. The usual fix is to add more data, and of course this often helps… but that much data is not easy to collect. The paper gives a potent example: what if we are trying to generate images of a rare species? There simply is no more data to be had. We don’t need extreme edge cases like rare species to hit this problem, however. Even for everyday items like clothing, collecting data is expensive. Annotating data is expensive. It can take years. We want the model to work now.
The Solution: Augmentation?
Now we get to the solution the paper discusses. The paper observes what we usually do in supervised learning (say, a straightforward image classification problem) when overfitting arises and there is no more data to add: we augment the data. [As a side note, do feel free to read up on other remedies for overfitting, such as regularization.]
Image augmentation means flipping the picture horizontally, changing its colors a little, and so on. We change the pictures slightly so that we get more samples. But with GANs, augmentation doesn’t straightforwardly work. The authors describe two ways we might augment data during GAN training and explain why both fail to produce good output images. Then they provide a third option that does work (differentiable augmentation), which is what their paper is all about. Here are the two options that don’t work:
Solution 1: Augment Reals Only
As you’ll recall, when we’re training a GAN, we have input images that are actual pictures of actual objects. We feed these real images to the discriminator alongside the fakes that our generator makes. So in this first approach, we augment only the real images. Simple, right?
Wrong.
Zhao et al. (2020) report that random horizontal flips do improve the results moderately. But applying stronger augmentations such as translation and cutout to the reals only causes problems like distortion and weird coloring in the generated images. Vanilla augmentation may work for ordinary classification problems, but with GANs we are not classifying. We are trying to learn the true distribution of the data. If we distort the real input data, our generated outputs will be distorted in the same way, because the generator is encouraged to match the augmented, distorted distribution, NOT the true distribution. So what about option #2 for augmenting the data?
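A small NumPy sketch makes this distribution shift concrete. It is an illustration under assumptions, not the paper's setup. The images are random stand-ins with pixel values in [0.2, 1.0), so no pixel is ever black. The helper `cutout_inside` is hypothetical: it always keeps the half-size square fully inside the image, unlike the paper's cutout. If the discriminator treats `aug_reals` as its "real" class, a generator that fully fools it must also black out a quarter of every image it produces.

```python
import numpy as np

def cutout_inside(x, rng, ratio=0.5):
    """Zero one random square of side ratio * size, kept fully inside each image."""
    n, _, h, w = x.shape
    sy, sx = int(h * ratio), int(w * ratio)
    out = x.copy()
    for i in range(n):
        y0 = rng.integers(0, h - sy + 1)
        x0 = rng.integers(0, w - sx + 1)
        out[i, :, y0:y0 + sy, x0:x0 + sx] = 0.0
    return out

rng = np.random.default_rng(0)
# Stand-in "real" images of shape (N, C, H, W) with values in [0.2, 1.0): never black.
reals = rng.uniform(0.2, 1.0, size=(8, 3, 32, 32))
aug_reals = cutout_inside(reals, rng)

print((reals == 0).mean())      # 0.0
print((aug_reals == 0).mean())  # 0.25: a quarter of every "real" is now black
```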
Solution 2: Augment All Discriminator Inputs
In this option, we augment not just the reals but also the fakes that our generator outputs before they go into the discriminator. Interestingly, the discriminator learns to tell augmented reals from augmented fakes with an accuracy above 90%. Yet it fails to identify fakes that are not augmented, with an accuracy below 10%. This matters because the generator G receives its gradient from these non-augmented fake images, which the discriminator never sees during its own update. So we need some way of propagating the gradient of the augmented samples to our generator G. Otherwise, in the horror-inducing words of Zhao et al. (2020):
the generator completely fools the discriminator
Solution 3: Enter Differentiable Augmentation
This is where the authors present a type of augmentation that does work: differentiable augmentation. To fix the issues with both solutions 1 and 2, their method (1) augments both the real and the fake images fed to the discriminator, and (2) applies the same augmentation to the fakes when training the generator, so that it successfully “propagates the gradients of the augmented samples to G.” This avoids the problem we saw under solution 2, where the discriminator fails to identify fakes that are not augmented. This is the crux of the paper: to let the gradient reach the generator G, the authors simply make sure that the augmentation is, as the name says, differentiable. They give three primary examples of such a differentiable augmentation:
Translation (within [−1/8, 1/8] of the image size, padded with zeros), Cutout (masking with a random square of half image size), and Color (including random brightness within [−0.5, 0.5], contrast within [0.5, 1.5], and saturation within [0, 2]).
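Here is one way the three augmentations could be sketched in NumPy, using the ranges quoted above. It is not the authors' implementation. The function names are hypothetical, randomness comes from a NumPy `Generator`, and the cutout square's center is sampled uniformly over the image, with the square clipped at the border. NumPy does not track gradients. However, every step is an addition, a multiplication, a mean, or a zero-padded crop or mask. For fixed random draws, each augmentation is therefore an affine function of the pixels. Written with PyTorch or JAX tensors, the same operations would pass gradients back to the generator.

```python
import numpy as np

def rand_brightness(x, rng):
    """Add one random offset in [-0.5, 0.5) to each image."""
    return x + rng.uniform(-0.5, 0.5, size=(x.shape[0], 1, 1, 1))

def rand_saturation(x, rng):
    """Scale each pixel's deviation from its cross-channel mean by a factor in [0, 2)."""
    mean = x.mean(axis=1, keepdims=True)
    return (x - mean) * rng.uniform(0.0, 2.0, size=(x.shape[0], 1, 1, 1)) + mean

def rand_contrast(x, rng):
    """Scale each image's deviation from its overall mean by a factor in [0.5, 1.5)."""
    mean = x.mean(axis=(1, 2, 3), keepdims=True)
    return (x - mean) * rng.uniform(0.5, 1.5, size=(x.shape[0], 1, 1, 1)) + mean

def rand_translation(x, rng, ratio=0.125):
    """Shift each image by up to ratio * size pixels, filling the gap with zeros."""
    n, _, h, w = x.shape
    py, px = int(h * ratio + 0.5), int(w * ratio + 0.5)
    padded = np.pad(x, ((0, 0), (0, 0), (py, py), (px, px)))  # zero padding
    out = np.empty_like(x)
    for i in range(n):
        dy = rng.integers(-py, py + 1)
        dx = rng.integers(-px, px + 1)
        out[i] = padded[i, :, py - dy:py - dy + h, px - dx:px - dx + w]
    return out

def rand_cutout(x, rng, ratio=0.5):
    """Zero a random square of side ratio * size, clipped at the image border."""
    n, _, h, w = x.shape
    sy, sx = int(h * ratio + 0.5), int(w * ratio + 0.5)
    mask = np.ones((n, 1, h, w), dtype=x.dtype)
    for i in range(n):
        cy, cx = rng.integers(0, h), rng.integers(0, w)  # square center
        y0, x0 = max(cy - sy // 2, 0), max(cx - sx // 2, 0)
        y1, x1 = min(cy - sy // 2 + sy, h), min(cx - sx // 2 + sx, w)
        mask[i, :, y0:y1, x0:x1] = 0
    return x * mask  # mask broadcasts over channels

AUGMENTS = {
    "color": (rand_brightness, rand_saturation, rand_contrast),
    "translation": (rand_translation,),
    "cutout": (rand_cutout,),
}

def diff_augment(x, rng, policy=("color", "translation", "cutout")):
    """Apply the chosen augmentations to a batch x of shape (N, C, H, W)."""
    for name in policy:
        for fn in AUGMENTS[name]:
            x = fn(x, rng)
    return x
```

For example, `diff_augment(images, np.random.default_rng(0), policy=("translation", "cutout"))` applies only translation and cutout. Under DiffAugment, a call like this, with fresh random draws, would be applied to the real batch and to the generated batch alike.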
Results
So does this work? YES. Yes, it seems to work. The authors list some pretty cool results and I will list some of them here. I strongly encourage you to look at the rest of the results in the main paper — they well and truly blew my mind.
Achievement #1: CIFAR-10 and CIFAR-100. The authors applied differentiable augmentation to two well-known GANs, BigGAN and StyleGAN2, and tried several dataset sizes (100%, 20%, and 10% of the data). To keep the comparison fair, they even used regularization and horizontal flips in the baseline method. On both CIFAR-10 and CIFAR-100, they report improvements over the baseline and new state-of-the-art results.
Achievement #2: ImageNet. Differentiable augmentation advanced the state of the art both on the full dataset and on reduced-size subsets.
The reason I love the solution presented in this paper is that it is so logical. The authors (1) tried different augmentation methods for GANs, (2) pinpointed exactly why they failed, and (3) promptly fixed the problem with a specific kind of augmentation. That augmentation is differentiable, applied to both reals and fakes, and thus lets the gradient propagate to the generator. And yet this logically deduced solution does so much. Now anyone training a GAN can add ‘use differentiable augmentation’ to their toolbox, alongside rules such as ‘add noise to discriminator inputs’ and ‘penalize (i.e., regularize) discriminator weights’ (“Generative Adversarial Networks”, n.d.).
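The logic of point (3) can be checked on a deliberately tiny NumPy toy. Everything in it is a hypothetical stand-in rather than the paper's setup:

- a linear "discriminator" `D(x) = x @ w`;
- a one-parameter "generator" `G(theta) = theta * z`;
- an augmentation `T(x) = c * x + b`, whose random scale and shift are drawn once and then held fixed. A real implementation would draw fresh random parameters for every batch.

The three modes mirror solutions 1 to 3 above. Only in the DiffAugment mode does the generator loss pass the fakes through `T`, and the chain-rule gradient then picks up the factor dT/dx = c. A finite-difference estimate confirms that the gradient really does flow through the augmentation. In practice, an autodiff framework such as PyTorch would compute these gradients for you.

```python
import numpy as np

def softplus(t):
    return np.logaddexp(0.0, t)  # log(1 + e^t), numerically stable

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

rng = np.random.default_rng(0)
dim, batch = 16, 8
w = rng.normal(size=dim)                         # toy discriminator weights
z = rng.normal(size=(batch, dim))                # generator input noise
reals = rng.normal(1.0, 0.5, size=(batch, dim))  # toy "real" samples
c = rng.uniform(0.5, 1.5, size=(batch, 1))       # fixed random scale (contrast-like)
b = rng.uniform(-0.5, 0.5, size=(batch, 1))      # fixed random shift (brightness-like)

def T(x):
    """Toy differentiable augmentation with fixed random parameters: dT/dx = c."""
    return c * x + b

def D(x):
    """Toy linear discriminator: one logit per sample."""
    return x @ w

def G(theta):
    """Toy one-parameter generator."""
    return theta * z

def d_loss(theta, mode):
    """Non-saturating logistic loss for D's update under each strategy."""
    fake = G(theta)
    if mode == "reals_only":  # Solution 1: only the reals pass through T
        return softplus(-D(T(reals))).mean() + softplus(D(fake)).mean()
    # Solutions 2 ("d_only") and 3 ("diffaugment") augment reals and fakes
    return softplus(-D(T(reals))).mean() + softplus(D(T(fake))).mean()

def g_loss(theta, mode):
    """G's loss: only DiffAugment also passes the fakes through T here."""
    fake = G(theta)
    if mode == "diffaugment":
        fake = T(fake)
    return softplus(-D(fake)).mean()

def g_grad_chain_rule(theta, mode):
    """dL_G/dtheta = mean(-sigmoid(-u) * du/dtheta), where u is D's logit."""
    zw = z @ w
    if mode == "diffaugment":
        u = D(T(G(theta)))
        du = c[:, 0] * zw  # chain rule picks up dT/dx = c
    else:
        u = D(G(theta))
        du = zw
    return np.mean(-sigmoid(-u) * du)

def g_grad_numeric(theta, mode, eps=1e-6):
    """Central finite-difference estimate of dL_G/dtheta."""
    return (g_loss(theta + eps, mode) - g_loss(theta - eps, mode)) / (2 * eps)

for mode in ("d_only", "diffaugment"):
    print(mode, g_grad_chain_rule(0.3, mode), g_grad_numeric(0.3, mode))
```

Within each printed line, the two numbers should agree. The two lines differ from each other because only the DiffAugment gradient carries the factor `c` contributed by the augmentation.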
This paper really made me excited, prompting me to write a post about it at 1 am. I hope that my discussion of it helps you to understand how the solution works and gets you hyped up as well!
To end, a nice diagram from the paper which clearly displays their methodology:
References
Zhao, S., Liu, Z., Lin, J., Zhu, J.-Y., & Han, S. (2020). Differentiable Augmentation for Data-Efficient GAN Training. arXiv preprint arXiv:2006.10738.
Generative Adversarial Networks (n.d.). Common Problems. Google Developers. Retrieved from https://developers.google.com/machine-learning/gan/problems