Lab41 Reading Group: Generative Adversarial Nets

This article originally appeared on Lab41's blog: Gab41. It is reposted here with permission.


One of the great things about working at Lab41 is that we are always taking on new and exciting projects, but this also means we need to stay knowledgeable about many different subjects and techniques. To stay up to date and expand our horizons, we read and discuss roughly one paper a week. Most of these papers are about machine learning, but any topic we find interesting and useful could show up. We thought you might find these papers interesting too, so we plan to start posting about our findings on Gab41.

This week’s paper, by Ian Goodfellow and his co-authors, introduces the concept of Generative Adversarial Nets (GANs).

Many of the problems tackled by machine learning fall into one of two types: discriminative and generative. A discriminative task seeks to classify some input, while a generative task seeks to build a model that produces new data resembling the training data. For example, an algorithm that identifies cats in photos is performing a discriminative task, while an algorithm that generates images of cats is performing a generative task.

Neural networks have proven to be extremely effective at discriminative tasks, exceeding 95% accuracy on various benchmarks, but they have been less effective at generative tasks. Goodfellow’s key idea was to turn image generation (normally a generative task) into a partially discriminative task so that it could benefit from this effectiveness. A GAN model contains two networks. One network, the generative network G, learns to generate fake images, while the other network, the discriminative network D, learns to detect fake images. The model is adversarial because the two networks “fight”: the generative network makes forgeries, the discriminative network tries to detect them, and each network is constantly learning to beat the other.

The model is trained using samples from a data distribution X and a noise distribution Z. The generative network takes in noise and outputs images: G(Z). A toy example of an image from the training set (left) and an output of the generative network (right) is shown below.

An image from the training set X (left) and a fake image generated by the network, G(Z) (right). Cat picture CC-BY-SA from Zimin.V.G., cat cartoon CC-BY XXspiritwolf2000XX.
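To make G(Z) concrete, here is a minimal NumPy sketch of a generator. It assumes a tiny fully connected network with one hidden layer that maps a 100-dimensional noise vector to a flattened 28x28 grayscale image (MNIST-sized). The layer sizes, the uniform noise, and the function names `init_generator`, `sample_noise`, and `generate` are illustrative choices of ours, not details from the paper. With random, untrained weights the output is static rather than a cat.

```python
import numpy as np

NOISE_DIM = 100      # size of each noise vector z (illustrative choice)
HIDDEN_DIM = 256     # width of the single hidden layer (illustrative choice)
IMAGE_DIM = 28 * 28  # a flattened 28x28 grayscale image, MNIST-sized


def init_generator(rng):
    """Randomly initialize the weights of a tiny two-layer generator G."""
    return {
        "W1": rng.normal(0.0, 0.02, size=(NOISE_DIM, HIDDEN_DIM)),
        "b1": np.zeros(HIDDEN_DIM),
        "W2": rng.normal(0.0, 0.02, size=(HIDDEN_DIM, IMAGE_DIM)),
        "b2": np.zeros(IMAGE_DIM),
    }


def sample_noise(rng, batch_size):
    """Draw a batch of noise vectors z from the noise distribution Z (uniform here)."""
    return rng.uniform(-1.0, 1.0, size=(batch_size, NOISE_DIM))


def generate(params, z):
    """Compute G(z): map each noise vector to a fake image with pixels in (0, 1)."""
    hidden = np.maximum(0.0, z @ params["W1"] + params["b1"])  # ReLU hidden layer
    logits = hidden @ params["W2"] + params["b2"]
    return 1.0 / (1.0 + np.exp(-logits))  # sigmoid squashes pixels into (0, 1)


rng = np.random.default_rng(0)
G = init_generator(rng)
fake_images = generate(G, sample_noise(rng, batch_size=16))
print(fake_images.shape)  # (16, 784)
```

Training would adjust `W1`, `b1`, `W2`, and `b2` so that these outputs start to resemble the training images. Practical generators are much larger, but the contract is the same: noise in, image out.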

The discriminative network takes in both real images and fake images and classifies them, producing D(x) and D(G(z)) respectively. It should return 1 for real images and 0 for fake images. A well-trained discriminative network would score images as shown in the image below.

The discriminative network determines whether images are real (scoring them a 1) or fake (scoring them a 0).
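The discriminator can be sketched as a small NumPy network whose final sigmoid squashes a raw score into a value between 0 and 1, read as the estimated probability that an image is real. The layer sizes and the names `init_discriminator` and `discriminate` are illustrative assumptions, and the random batches below are stand-ins for real images and generator outputs, used only to show the shapes involved.

```python
import numpy as np

IMAGE_DIM = 28 * 28  # flattened image size (illustrative choice)
HIDDEN_DIM = 256     # hidden-layer width (illustrative choice)


def init_discriminator(rng):
    """Randomly initialize a tiny two-layer discriminator D."""
    return {
        "W1": rng.normal(0.0, 0.02, size=(IMAGE_DIM, HIDDEN_DIM)),
        "b1": np.zeros(HIDDEN_DIM),
        "w2": rng.normal(0.0, 0.02, size=HIDDEN_DIM),
        "b2": 0.0,
    }


def discriminate(params, images):
    """Compute D(x): the estimated probability that each image is real."""
    hidden = np.maximum(0.0, images @ params["W1"] + params["b1"])  # ReLU
    score = hidden @ params["w2"] + params["b2"]  # one raw score per image
    return 1.0 / (1.0 + np.exp(-score))  # sigmoid: ~1 means "real", ~0 means "fake"


rng = np.random.default_rng(0)
D = init_discriminator(rng)
real_batch = rng.uniform(0.0, 1.0, size=(8, IMAGE_DIM))  # stand-in for images from X
fake_batch = rng.uniform(0.0, 1.0, size=(8, IMAGE_DIM))  # stand-in for G(z)
print(discriminate(D, real_batch))  # D(x): untrained, so values sit near 0.5
print(discriminate(D, fake_batch))  # D(G(z))
```

Because the weights are untrained, both batches score close to 0.5; training is what pushes D(x) toward 1 and D(G(z)) toward 0.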

The two networks are trained with different objective functions. The generative network tries to minimize log(1-D(G(z))). This quantity is minimized when the discriminative network incorrectly classifies a fake image as real, that is, when D(G(z)) approaches 1. The discriminative network tries to maximize log(D(x)) + log(1-D(G(z))). The first term is maximized when the network correctly identifies real images (D(x) near 1), and the second is maximized when it correctly identifies fakes (D(G(z)) near 0). Note that both networks are optimizing log(1-D(G(z))), but one wants to make it small while the other wants to make it large!
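Both objectives are simple to compute once D has produced its scores. In the paper they are expectations over the data and noise distributions, estimated in practice by averaging over a minibatch, which is what this sketch does. The D(x) and D(G(z)) values below are made up for illustration, and the small `EPS` that keeps `log(0)` from appearing is our own safeguard.

```python
import numpy as np

EPS = 1e-12  # guards against log(0) when a probability hits exactly 0 or 1


def discriminator_objective(d_real, d_fake):
    """Value D tries to MAXIMIZE: mean of log(D(x)) + mean of log(1 - D(G(z)))."""
    return np.mean(np.log(d_real + EPS)) + np.mean(np.log(1.0 - d_fake + EPS))


def generator_objective(d_fake):
    """Value G tries to MINIMIZE: mean of log(1 - D(G(z)))."""
    return np.mean(np.log(1.0 - d_fake + EPS))


# Hypothetical discriminator outputs for a batch of 4 real and 4 fake images.
d_real = np.array([0.9, 0.8, 0.95, 0.85])  # D(x): D is fairly sure these are real
d_fake = np.array([0.1, 0.2, 0.05, 0.15])  # D(G(z)): D is fairly sure these are fake

print(discriminator_objective(d_real, d_fake))  # about -0.27, near its maximum of 0: D is winning
print(generator_objective(d_fake))              # about -0.14, near 0: bad for G

# If G improves so that D is fooled (D(G(z)) near 1), G's objective drops sharply.
fooled = np.array([0.9, 0.8, 0.95, 0.85])
print(generator_objective(fooled))  # about -2.20: much better for G
```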

Goodfellow and his co-authors trained their model on several datasets, including MNIST, TFD (the Toronto Face Database), and CIFAR-10. The results for MNIST and TFD are shown below:


Each row contains a few samples generated by the model, while the rightmost column (highlighted with a yellow border) contains the closest examples from the training set.

One challenge of GAN models is that it can be difficult to keep the generative and discriminative networks in sync. If one network is much better at its task than the other, then the network that falls behind will have trouble learning because it always loses. Additionally, GANs do not explicitly model the probability of the data distribution; if that probability is required, it must be approximated from the model.
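A small calculation shows why the losing network struggles. Write the discriminator's output as D = sigmoid(s), where s is its raw score for a fake image. The gradient of the generator's objective log(1-D(G(z))) with respect to s is -D, so when the discriminator confidently rejects fakes (D near 0) the generator receives almost no learning signal. The original paper notes this and suggests training G to maximize log(D(G(z))) instead; that objective's gradient with respect to s is 1 - D, which stays large in exactly that regime. This sketch checks both formulas with finite differences; the function names are our own.

```python
import numpy as np


def sigmoid(s):
    return 1.0 / (1.0 + np.exp(-s))


def saturating_objective(s):
    """log(1 - D(G(z))) as a function of D's raw score s for a fake image."""
    return np.log(1.0 - sigmoid(s))


def non_saturating_objective(s):
    """log(D(G(z))), the alternative generator objective suggested in the paper."""
    return np.log(sigmoid(s))


def numerical_grad(f, s, h=1e-6):
    """Central finite-difference estimate of df/ds."""
    return (f(s + h) - f(s - h)) / (2.0 * h)


# Scores from a discriminator that is increasingly sure the fake is fake.
for s in [0.0, -2.0, -5.0, -10.0]:
    d = sigmoid(s)
    g_sat = numerical_grad(saturating_objective, s)      # should equal -D
    g_non = numerical_grad(non_saturating_objective, s)  # should equal 1 - D
    print(f"D(G(z))={d:.5f}  grad log(1-D)={g_sat:+.5f}  grad log(D)={g_non:+.5f}")
```

As D(G(z)) falls toward 0, the first gradient shrinks toward zero while the second approaches 1, which is the "always loses" problem in numbers.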