Gradients of GAN Objectives
This technical post will offer a new view of common training objectives for generative adversarial networks (GANs), including a justification for the widely-used non-saturating loss.
The gradient of the discriminator
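First, let’s look at the original GAN loss function and show that it’s simpler than it looks. As defined in Goodfellow et al. (2014), it’s
$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))],$$
where $p_{\text{data}}$ is the data distribution, $p_z$ is the noise distribution, $D$ is the discriminator, and $G$ is the generator.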
To analyze this loss function, we’ll use the same idea as in my post on the cross entropy loss: instead of looking at the function itself, we’ll look at its gradient.
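We’ll assume the discriminator uses the standard method of obtaining the output probability: it applies the logistic sigmoid function $\sigma$ to a scalar output $f(x)$, the logit:
$$D(x) = \sigma(f(x)) = \frac{1}{1 + e^{-f(x)}}.$$
Taking the logarithm:
$$\log D(x) = -\log\left(1 + e^{-f(x)}\right).$$
And the derivative:
$$\frac{\partial \log D(x)}{\partial f(x)} = \frac{e^{-f(x)}}{1 + e^{-f(x)}} = 1 - D(x).$$
This is quite surprising: taking the derivative of $\log D(x)$ with respect to $f(x)$ gives simply $1 - D(x)$, the probability the discriminator assigns to the wrong label when $x$ is real.
We can do the same calculation for the flipped probability $1 - D(x) = \sigma(-f(x))$:
$$\frac{\partial \log(1 - D(x))}{\partial f(x)} = -D(x).$$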
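As a sanity check on the two derivatives above, the following plain-Python sketch compares each closed form with a central finite difference at a few logit values. The helper names (`sigmoid`, `numerical_derivative`, `log_d`, `log_one_minus_d`) are introduced here for illustration only and do not come from any library.

```python
import math


def sigmoid(f: float) -> float:
    """Logistic sigmoid: maps a logit f(x) to the probability D(x)."""
    return 1.0 / (1.0 + math.exp(-f))


def numerical_derivative(g, f: float, eps: float = 1e-6) -> float:
    """Central finite-difference approximation of g'(f)."""
    return (g(f + eps) - g(f - eps)) / (2.0 * eps)


def log_d(f: float) -> float:
    """log D(x) as a function of the logit f(x)."""
    return math.log(sigmoid(f))


def log_one_minus_d(f: float) -> float:
    """log(1 - D(x)) as a function of the logit f(x)."""
    return math.log(1.0 - sigmoid(f))


if __name__ == "__main__":
    for f in [-3.0, -0.5, 0.0, 1.2, 4.0]:
        d = sigmoid(f)
        # Closed forms from the text: d/df log D = 1 - D, d/df log(1 - D) = -D
        print(
            f"f={f:+.1f}  "
            f"d/df log D: closed {1.0 - d:.6f}, numeric {numerical_derivative(log_d, f):.6f}  |  "
            f"d/df log(1-D): closed {-d:.6f}, numeric {numerical_derivative(log_one_minus_d, f):.6f}"
        )
```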
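We’re going to use these facts to analyze the discriminator’s objective, which it maximizes:
$$V_D = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))].$$
First, some change of notation: let’s write $G(z)$ as $\tilde{x}$, a sample from the generator’s distribution $p_g$, to make it clear which input produced each value of $D$ and $f$. Let $\theta$ denote the discriminator’s parameters.
Now we can write the gradient of $V_D$ with respect to $\theta$ as follows:
$$\nabla_\theta V_D = \mathbb{E}_{x \sim p_{\text{data}}}\big[(1 - D(x))\,\nabla_\theta f(x)\big] - \mathbb{E}_{\tilde{x} \sim p_g}\big[D(\tilde{x})\,\nabla_\theta f(\tilde{x})\big].$$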
This gives a clearer view of the GAN objective than looking at the original loss function. From this, we see that the discriminator’s goal is simply to increase $f$ for real inputs and decrease $f$ for fake inputs, with the caveat that each input is weighted by its probability of being misclassified: $1 - D(x)$ for a real input and $D(G(z))$ for a fake one.
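To see the misclassification weighting concretely, here is a minimal sketch that assumes we already have the discriminator’s logits for a small batch of real samples and a small batch of generated samples. The batch values and function names are made up for illustration. The sketch computes the per-logit gradient of the batch estimate of the discriminator’s objective in closed form and checks it against finite differences.

```python
import math


def sigmoid(f):
    return 1.0 / (1.0 + math.exp(-f))


def discriminator_objective(real_logits, fake_logits):
    """Batch estimate of E[log D(x)] + E[log(1 - D(G(z)))], given logits f."""
    real_term = sum(math.log(sigmoid(f)) for f in real_logits) / len(real_logits)
    fake_term = sum(math.log(1.0 - sigmoid(f)) for f in fake_logits) / len(fake_logits)
    return real_term + fake_term


def objective_grad_wrt_logits(real_logits, fake_logits):
    """Closed-form d(objective)/df for every logit in the batch.

    Real samples get +(1 - D(x)) / n: push f up, weighted by P(misclassified as fake).
    Fake samples get -D(G(z)) / m: push f down, weighted by P(misclassified as real).
    """
    n, m = len(real_logits), len(fake_logits)
    real_grads = [(1.0 - sigmoid(f)) / n for f in real_logits]
    fake_grads = [-sigmoid(f) / m for f in fake_logits]
    return real_grads, fake_grads


def finite_difference_grads(real_logits, fake_logits, eps=1e-6):
    """Numerical check: perturb one logit at a time."""
    def bump(logits, i, delta):
        return [f + delta if j == i else f for j, f in enumerate(logits)]

    real_grads = [
        (discriminator_objective(bump(real_logits, i, eps), fake_logits)
         - discriminator_objective(bump(real_logits, i, -eps), fake_logits)) / (2 * eps)
        for i in range(len(real_logits))
    ]
    fake_grads = [
        (discriminator_objective(real_logits, bump(fake_logits, i, eps))
         - discriminator_objective(real_logits, bump(fake_logits, i, -eps))) / (2 * eps)
        for i in range(len(fake_logits))
    ]
    return real_grads, fake_grads


if __name__ == "__main__":
    real_logits = [3.0, -1.0]   # the second real sample is misclassified as fake
    fake_logits = [-4.0, 2.0]   # the second fake sample is misclassified as real
    print("closed form:", objective_grad_wrt_logits(real_logits, fake_logits))
    print("numerical:  ", finite_difference_grads(real_logits, fake_logits))
```

In this toy batch, the misclassified real sample (logit -1.0) and the misclassified fake sample (logit 2.0) receive much larger gradients than the confidently correct ones.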
This is an example of a general principle in machine learning, which is that learning algorithms should focus on examples the model gets wrong over examples the model gets right. Other examples of this principle include:
- The perceptron learning algorithm, which only updates on mistakes.
- The hinge loss, which has no gradient when the prediction has the right sign and a magnitude of at least 1.
- Logistic regression, which has the same loss function as a GAN discriminator.
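The three examples can be compared side by side by writing each one’s gradient magnitude as a function of the margin $m = y f(x)$, where $y \in \{-1, +1\}$ is the label. This sketch assumes that margin convention, and it also assumes the perceptron treats $m \le 0$ as a mistake. The function names are illustrative.

```python
import math


def perceptron_weight(margin):
    """Perceptron: update only on mistakes (margin <= 0)."""
    return 1.0 if margin <= 0.0 else 0.0


def hinge_weight(margin):
    """|d/dm max(0, 1 - m)|: nonzero only while the margin is below 1."""
    return 1.0 if margin < 1.0 else 0.0


def logistic_weight(margin):
    """|d/dm log(1 + exp(-m))| = sigma(-m): the probability of the wrong label."""
    return 1.0 / (1.0 + math.exp(margin))


if __name__ == "__main__":
    print(f"{'margin':>7} {'perceptron':>11} {'hinge':>6} {'logistic':>9}")
    for m in [-2.0, -0.5, 0.5, 2.0]:
        print(f"{m:>7.1f} {perceptron_weight(m):>11.1f} "
              f"{hinge_weight(m):>6.1f} {logistic_weight(m):>9.3f}")
```

All three put their weight on wrong or barely-right predictions. The perceptron and hinge loss use a hard cutoff, while logistic regression uses a smooth one.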
The non-saturating loss
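If we accept the principle that a model should weight its mistakes more than its successes, then there’s something not quite right about the generator’s objective in the original formulation, which is to minimize $\mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$. The gradient of $\log(1 - D(G(z)))$ with respect to the logit $f(G(z))$ is $-D(G(z))$, so each generated sample is weighted by the probability that the discriminator classifies it as real. In other words, the generator concentrates on the samples that already fool the discriminator (its successes) rather than on the samples the discriminator rejects (its mistakes). The fix is to have the generator maximize $\mathbb{E}_{z \sim p_z}[\log D(G(z))]$ instead.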
Though Goodfellow et al. motivate it differently, referring to vanishing gradients early in training, the idea is precisely the same: maximizing $\log D(G(z))$ leads to weighting by $1 - D(G(z))$, the probability that $G(z)$ is correctly classified as fake by the discriminator. This is called the non-saturating loss, and it is used ubiquitously, for example in StyleGAN 2 (Karras et al., 2020).
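The difference between the two generator objectives shows up directly in the weight each generated sample receives on its logit $f(G(z))$. The sketch below computes the size of the push toward "real" under each objective. It uses plain Python, and the function names are hypothetical. Early in training, the discriminator rejects fakes easily, so $f(G(z))$ is very negative. There, the original objective’s weight $D(G(z))$ nearly vanishes, while the non-saturating weight $1 - D(G(z))$ stays near 1.

```python
import math


def sigmoid(f):
    return 1.0 / (1.0 + math.exp(-f))


def saturating_weight(fake_logit):
    """Generator minimizes log(1 - D(G(z))).
    Gradient descent moves f(G(z)) by -d/df log(1 - sigma(f)) = +D(G(z))."""
    return sigmoid(fake_logit)


def non_saturating_weight(fake_logit):
    """Generator maximizes log D(G(z)).
    Gradient ascent moves f(G(z)) by d/df log(sigma(f)) = 1 - D(G(z))."""
    return 1.0 - sigmoid(fake_logit)


if __name__ == "__main__":
    # Very negative logits: the discriminator confidently (and correctly) says "fake".
    for f in [-8.0, -4.0, 0.0, 4.0]:
        print(f"f(G(z))={f:+.1f}  saturating={saturating_weight(f):.4f}  "
              f"non-saturating={non_saturating_weight(f):.4f}")
```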
The hinge loss
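The hinge loss is another popular GAN objective. Like the non-saturating loss, it uses different loss functions for the generator and the discriminator:
$$\mathcal{L}_D = \mathbb{E}_{x \sim p_{\text{data}}}\big[\max(0, 1 - f(x))\big] + \mathbb{E}_{z \sim p_z}\big[\max(0, 1 + f(G(z)))\big],$$
$$\mathcal{L}_G = -\mathbb{E}_{z \sim p_z}\big[f(G(z))\big].$$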
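The gradient of the discriminator loss with respect to the discriminator’s parameters $\theta$ is
$$\nabla_\theta \mathcal{L}_D = -\mathbb{E}_{x \sim p_{\text{data}}}\big[\mathbb{1}[f(x) < 1]\,\nabla_\theta f(x)\big] + \mathbb{E}_{z \sim p_z}\big[\mathbb{1}[f(G(z)) > -1]\,\nabla_\theta f(G(z))\big].$$
This is similar to the gradient of the original discriminator objective, except that the soft weights $1 - D(x)$ and $D(G(z))$ are replaced by hard 0/1 weights. (The overall sign is flipped because $\mathcal{L}_D$ is minimized rather than maximized.) You can see the similarity to the original: an example contributes to the gradient only if it is misclassified or is classified correctly with a margin smaller than 1.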
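You can also see why it’s necessary to use a different loss function for the generator. If the generator simply maximized the discriminator’s loss on fake examples, $\mathbb{E}_{z \sim p_z}[\max(0, 1 + f(G(z)))]$, any sample with $f(G(z)) < -1$ would contribute nothing to the gradient. In other words, the generator would get no gradient from examples the discriminator confidently classifies as fake, which is precisely the opposite of what we want.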
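The hinge loss gradients with respect to a single logit can be written out in a few lines. This sketch uses plain Python and hypothetical function names, and it assumes a subgradient of 0 at the kinks. It also includes a "naive" generator that simply maximizes the discriminator’s fake-sample term, to show how that generator stalls on confidently rejected samples.

```python
def d_loss_real_grad(f):
    """d/df max(0, 1 - f): real samples push f up until f >= 1."""
    return -1.0 if f < 1.0 else 0.0


def d_loss_fake_grad(f):
    """d/df max(0, 1 + f): fake samples push f down until f <= -1."""
    return 1.0 if f > -1.0 else 0.0


def g_loss_grad(f):
    """d/df of the generator loss -f: every sample gets the same push."""
    return -1.0


def naive_g_loss_grad(f):
    """Hypothetical generator minimizing -max(0, 1 + f(G(z))), i.e. reusing the
    discriminator's fake-sample term. Its gradient vanishes once f <= -1."""
    return -1.0 if f > -1.0 else 0.0


if __name__ == "__main__":
    for f in [-3.0, 0.0, 3.0]:
        print(f"f={f:+.1f}  D real={d_loss_real_grad(f):+.0f}  "
              f"D fake={d_loss_fake_grad(f):+.0f}  G={g_loss_grad(f):+.0f}  "
              f"naive G={naive_g_loss_grad(f):+.0f}")
```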
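You might ask whether the reasoning from the previous section on the non-saturating loss can also be applied here. In fact, there is a direct analogue of the non-saturating loss:
$$\mathcal{L}_G = \mathbb{E}_{z \sim p_z}\big[\max(0, 1 - f(G(z)))\big].$$
This cuts off the gradient using the same rule as the discriminator’s hinge term for real examples. A generated sample contributes only while $f(G(z)) < 1$, that is, until the discriminator confidently classifies it as real.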
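A batch version of this alternate generator loss can be sketched in a few lines. The function name and the example logits are illustrative only, and averaging over the batch is an assumption. Samples that the discriminator already scores at or above 1 drop out of the gradient, while confidently rejected samples keep a full gradient.

```python
def hinge_non_saturating_generator(fake_logits):
    """Alternate generator loss: mean over the batch of max(0, 1 - f(G(z))).

    Returns the loss and d(loss)/df for each logit. A sample contributes
    gradient only while f(G(z)) < 1, i.e. while it is not yet confidently
    classified as real (subgradient 0 chosen at f = 1).
    """
    m = len(fake_logits)
    loss = sum(max(0.0, 1.0 - f) for f in fake_logits) / m
    grads = [-1.0 / m if f < 1.0 else 0.0 for f in fake_logits]
    return loss, grads


if __name__ == "__main__":
    # One confidently rejected fake, one borderline fake, one that already fools D.
    loss, grads = hinge_non_saturating_generator([-3.0, 0.5, 2.0])
    print("loss:", loss, "grads:", grads)
```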
I’m not aware of anyone using this alternate loss. If you’ve tried it, let me know.
Adding a W
Now let’s consider the Wasserstein GAN (Arjovsky et al., 2017). WGAN can be obtained from the original GAN by making the following modifications:
- Stop weighting by the probability of being misclassified. Instead, weight all examples equally.
- Enforce a Lipschitz constraint on $f$.
In the original paper, WGAN is presented differently. Arjovsky et al. motivate it by arguing that we should minimize the Wasserstein distance between the data distribution and the generator distribution. They then propose approximating that distance with the discriminator, using an alternative (dual) representation of the Wasserstein distance. However, I prefer this presentation because it makes it clear that WGAN is very similar to the original GAN.
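Why is my formulation of WGAN equivalent to that of Arjovsky et al.? Their formulation is to learn a discriminator (which they call a critic) $f$, constrained to be Lipschitz, that maximizes
$$\mathbb{E}_{x \sim p_{\text{data}}}[f(x)] - \mathbb{E}_{z \sim p_z}[f(G(z))],$$
and the gradient of this quantity is
$$\mathbb{E}_{x \sim p_{\text{data}}}[\nabla_\theta f(x)] - \mathbb{E}_{z \sim p_z}[\nabla_\theta f(G(z))],$$
which is equal to the gradient of the original discriminator objective without the weighting by misclassification probability.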
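To make the "same gradient, minus the weighting" point concrete, consider a toy discriminator with a single parameter $w$ and logit $f(x) = w x$ on scalar inputs, so $\nabla_w f(x) = x$. The model and the sample values are assumptions chosen for illustration. The sketch computes the parameter gradient of both the original discriminator objective and the WGAN objective, and checks the former against a finite difference.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def gan_objective(w, real, fake):
    """Mean log D(x) + mean log(1 - D(x~)) with D = sigmoid(w * x)."""
    return (sum(math.log(sigmoid(w * x)) for x in real) / len(real)
            + sum(math.log(1.0 - sigmoid(w * x)) for x in fake) / len(fake))


def gan_grad(w, real, fake):
    """Each term x = df/dw is weighted by its misclassification probability."""
    return (sum((1.0 - sigmoid(w * x)) * x for x in real) / len(real)
            - sum(sigmoid(w * x) * x for x in fake) / len(fake))


def wgan_grad(w, real, fake):
    """Gradient of mean f(x) - mean f(x~) with f = w * x: the same terms with
    every weight set to 1 (for this linear critic it does not depend on w)."""
    return sum(real) / len(real) - sum(fake) / len(fake)


if __name__ == "__main__":
    real, fake, w, eps = [1.0, 2.0], [-1.0, 0.5], 0.3, 1e-6
    numeric = (gan_objective(w + eps, real, fake) - gan_objective(w - eps, real, fake)) / (2 * eps)
    print("GAN  grad:", gan_grad(w, real, fake), "numeric:", numeric)
    print("WGAN grad:", wgan_grad(w, real, fake))
```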
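Why does WGAN add the Lipschitz constraint? The WGAN paper motivates it by the fact that a Lipschitz constraint is necessary to estimate the Wasserstein distance. Specifically, by Kantorovich-Rubinstein duality, the Wasserstein distance between $p_{\text{data}}$ and the generator’s distribution $p_g$ is
$$W(p_{\text{data}}, p_g) = \sup_{\|f\|_L \le 1} \; \mathbb{E}_{x \sim p_{\text{data}}}[f(x)] - \mathbb{E}_{x \sim p_g}[f(x)],$$
where the supremum is taken over all 1-Lipschitz functions $f$. Without the constraint, this supremum is infinite whenever the two distributions differ, because scaling up any $f$ with a positive gap scales up the gap. The discriminator’s objective would then no longer estimate a distance.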
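A one-dimensional toy case shows why the constraint matters. For illustration only, assume $p_{\text{data}}$ is a point mass at $a$ and $p_g$ is a point mass at $b$, so the Wasserstein distance is $|a - b|$. For a linear critic $f(x) = c x$, whose Lipschitz constant is $|c|$, the critic objective is $c(a - b)$. Restricting the slope to $|c| \le 1$ recovers $|a - b|$, while letting $c$ grow makes the objective grow without bound.

```python
def critic_objective(c, a, b):
    """E_{p_data}[f] - E_{p_g}[f] for f(x) = c * x, with point masses at a and b."""
    return c * a - c * b


def best_lipschitz_critic_value(a, b, steps=200):
    """Maximize over linear critics with |c| <= 1 on a grid of slopes.
    For point masses, a 1-Lipschitz linear critic already attains the supremum,
    since any 1-Lipschitz f satisfies f(a) - f(b) <= |a - b|."""
    slopes = [-1.0 + 2.0 * i / steps for i in range(steps + 1)]
    return max(critic_objective(c, a, b) for c in slopes)


if __name__ == "__main__":
    a, b = 3.0, 1.0
    print("constrained max:", best_lipschitz_critic_value(a, b), "vs |a - b| =", abs(a - b))
    for c in [1.0, 10.0, 100.0]:
        print(f"unconstrained slope c={c:g}: objective {critic_objective(c, a, b):g}")
```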
But there’s another way to look at the Lipschitz constraint: we can view it as a regularizer on the discriminator that keeps its gradients bounded and makes the generator easier to optimize. This view is supported by Kurach et al. (2019), who found that spectral normalization, a common way of enforcing the Lipschitz constraint (Miyato et al., 2018), helped all GAN variants in their study, not just WGAN.
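Spectral normalization (Miyato et al., 2018) divides each weight matrix by an estimate of its largest singular value, obtained by power iteration. This makes each linear layer 1-Lipschitz with respect to the Euclidean norm. The sketch below is a simplified plain-Python illustration, not the reference implementation. It runs many power-iteration steps from a fixed starting vector, whereas Miyato et al. run one step per training update and reuse the vector across updates. All helper names are hypothetical.

```python
import math


def matvec(W, v):
    """W v for W given as a list of rows."""
    return [sum(w_ij * v_j for w_ij, v_j in zip(row, v)) for row in W]


def rmatvec(W, u):
    """W^T u."""
    return [sum(W[i][j] * u[i] for i in range(len(W))) for j in range(len(W[0]))]


def normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def spectral_norm(W, iters=50):
    """Estimate the largest singular value of W by power iteration.
    Assumes the starting vector is not orthogonal to the top right-singular vector."""
    v = normalize([1.0] * len(W[0]))
    for _ in range(iters):
        u = normalize(matvec(W, v))
        v = normalize(rmatvec(W, u))
    # sigma ~= ||W v|| once v approximates the top right-singular vector
    return math.sqrt(sum(x * x for x in matvec(W, v)))


def spectral_normalize(W, iters=50):
    """Return W / sigma(W), whose largest singular value is approximately 1."""
    sigma = spectral_norm(W, iters)
    return [[w / sigma for w in row] for row in W]


if __name__ == "__main__":
    W = [[1.0, 2.0], [3.0, 4.0]]
    print("sigma(W) ~", spectral_norm(W))
    print("sigma(W / sigma(W)) ~", spectral_norm(spectral_normalize(W)))
```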
Hopefully this post has given you a new way to think about GAN objectives. If you liked it, you might also like this other post, which applies the same idea to the cross entropy loss.