Training Generative Adversarial Networks (GANs) from scratch in PyTorch


Deep neural networks are used mainly for supervised learning: classification or regression. Generative Adversarial Networks or GANs, however, use neural networks for a very different purpose: generative modeling.

Generative modeling is an unsupervised learning task that involves automatically discovering and learning the patterns in input data so that the model can be used to generate new examples that plausibly could have been drawn from the original dataset. — Source

While there are many approaches used for generative modeling, GANs take the following approach for learning:

There are two neural networks: a Generator and a Discriminator. The generator generates a fake sample given a random vector/matrix, and the discriminator attempts to detect whether a given sample is real (picked from the training data) or fake (generated by the generator).

Training happens in tandem: we train the discriminator for a few epochs, then train the generator for a few epochs, and repeat. This way, both the generator and the discriminator get better at doing their jobs. This rather simple approach can lead to some astounding results. The following images, for instance, were all generated using GANs:


GANs, however, can be notoriously difficult to train, and are extremely sensitive to hyperparameters, activation functions, and regularization. In this tutorial, we’ll train a GAN to generate images of handwritten digits similar to those from the MNIST database.

MNIST Database

System Setup

The code for this tutorial is available as a Jupyter notebook that can be found here: jovian.ml/aakashns/06-mnist-gan. It is hosted on Jovian.ml, a platform for sharing Jupyter notebooks and data science projects.

To run the notebook online, click on the “Run” dropdown and select “Run on Kaggle” or “Run on Binder”. You can also run the notebook on your laptop or cloud GPU machine by running the following commands on your terminal:

pip install jovian --upgrade       # Install the helper library
jovian clone aakashns/06-mnist-gan # Download the code & resources
cd 06-mnist-gan # Enter the project folder
jovian install # Install dependencies
conda activate 06-mnist-gan # Activate virtual environment
jupyter notebook # Start Jupyter notebook

Make sure you have the Conda distribution of Python installed before running the above. Check out the library docs for more details. Most of the code for this tutorial is borrowed from this excellent repository of PyTorch tutorials.

Loading the Data

We begin by downloading and importing the data as a PyTorch dataset using the MNIST helper class from torchvision.datasets. We also transform the images into PyTorch tensors.

Note that we are transforming the pixel values from the range [0, 1] to the range [-1, 1]. The reason for doing this will become clear when we define the generator network. Let's look at a sample image tensor from the data.
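Here’s a rough sketch of what that cell looks like (the variable name mnist and the data directory are our choices, not requirements):

import torch
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

mnist = MNIST(root='data',
              train=True,
              download=True,
              transform=transforms.Compose([
                  transforms.ToTensor(),                         # [0, 255] -> [0, 1]
                  transforms.Normalize(mean=(0.5,), std=(0.5,))  # [0, 1] -> [-1, 1]
              ]))

img, label = mnist[0]
print('Label:', label)
print('Pixel range:', img.min().item(), 'to', img.max().item())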

As expected, the pixel values range from -1 to 1. Let’s define a helper to denormalize the tensors.
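A minimal version of such a helper simply inverts the normalization and clamps any values that overshoot the valid range:

def denorm(x):
    # Map pixel values back from [-1, 1] to [0, 1].
    out = (x + 1) / 2
    return out.clamp(0, 1)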

This function will also be useful for viewing the generated images.

Finally, let’s create a data loader to load the images in batches.

The data loader can be used within a for loop to load a batch of images and labels from the training dataset.
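Here’s one way to set this up (the batch size of 100 is an assumption; any reasonable value works):

from torch.utils.data import DataLoader

batch_size = 100
data_loader = DataLoader(mnist, batch_size, shuffle=True)

for images, labels in data_loader:
    print('images shape:', images.shape)   # torch.Size([100, 1, 28, 28])
    print('labels shape:', labels.shape)   # torch.Size([100])
    break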

We’ll also create a device which can be used to move the data and models to a GPU, if one is available.

torch.cuda.is_available() is used to check whether a GPU is available. It allows us to write code which can be executed both with and without a GPU, without any modification.
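A common way to define it:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')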

Discriminator Network

The discriminator takes an image as input, and tries to classify it as “real” or “generated”. In this sense, it’s like any other neural network. While we can use a CNN for the discriminator, we’ll use a simple feedforward network with 3 linear layers to keep things simple. We’ll treat each 28x28 image as a vector of size 784. Let’s first define a couple of constants that we’ll need later.
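For example (the hidden layer size of 256 is one reasonable choice, not the only one):

image_size = 784    # 28 x 28 pixels, flattened into a vector
hidden_size = 256   # size of the hidden layers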

We can now create the network using nn.Sequential and nn.Linear.
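Here’s a sketch of such a network (the negative slope of 0.2 for Leaky ReLU is a common choice for GAN discriminators):

import torch.nn as nn

D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid())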

Note that we’re using the Leaky ReLU activation for the discriminator.

Different from the regular ReLU function, Leaky ReLU allows a small gradient signal to pass for negative values. As a result, the gradients from the discriminator flow more strongly into the generator. Instead of passing a gradient (slope) of 0 in the back-prop pass, it passes a small negative gradient. — Advanced GANs

Just like any other binary classification model, the output of the discriminator is a single number between 0 and 1, which can be interpreted as the probability of the input image being drawn from the actual MNIST dataset.

Let’s move the discriminator model to the chosen device.

Generator Network

The input to the generator is typically a random vector or matrix, which is used as a seed for generating an image. Once again, to keep things simple, we’ll use a feedforward neural network with 3 layers, and the output will be a vector of size 784, which can be transformed into a 28x28 px image. We’ll use latent vectors of size 64 as inputs to the generator.

We can now create the network using nn.Sequential and nn.Linear.
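A sketch mirroring the discriminator, with Tanh on the output layer:

latent_size = 64   # size of the random input vector

G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh())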

We use the Tanh activation function for the output layer of the generator.

“The output layer uses the Tanh activation function. Using a bounded activation allows the model to learn more quickly to saturate and cover the color space of the training distribution” — Stack Overflow

Note that the outputs of the Tanh activation lie in the range [-1, 1]. This is why we have applied the same transformation to the images in the training dataset. Let's generate an output vector using the generator and transform it into an image by denormalizing the output.

We can now view the generated image using plt.imshow. Here’s a sample:
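Putting those two steps together (the variable names here are our own):

import matplotlib.pyplot as plt

y = G(torch.randn(1, latent_size))            # a single random latent vector
gen_img = denorm(y.reshape(28, 28).detach())  # detach before plotting

plt.imshow(gen_img, cmap='gray');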

As one might expect, the output from the generator is basically random noise. Let’s define a helper function which can save a batch of outputs from the generator to a file.

Let’s move the generator to the chosen device.

Discriminator Training

Since the discriminator is a binary classification model, we can use the binary cross entropy loss function to quantify how well it is able to differentiate between real and generated images.


Let’s instantiate the loss function and optimizers for the discriminator and generator. We’ll use the Adam optimizer with a learning rate of 0.0002.
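Something along these lines:

criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)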

Let’s define helper functions to reset gradients and train the discriminator.
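Here’s a sketch of those helpers, assuming the models, loss, optimizers, and data loader defined above (with D and G already moved to the device):

def reset_grad():
    # Clear gradients accumulated on both optimizers.
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

def train_discriminator(images):
    # Target labels: 1 for real images, 0 for generated (fake) ones.
    real_labels = torch.ones(batch_size, 1).to(device)
    fake_labels = torch.zeros(batch_size, 1).to(device)

    # Loss for a batch of real images.
    outputs = D(images)
    d_loss_real = criterion(outputs, real_labels)
    real_score = outputs

    # Loss for a batch of generated images.
    z = torch.randn(batch_size, latent_size).to(device)
    fake_images = G(z)
    outputs = D(fake_images)
    d_loss_fake = criterion(outputs, fake_labels)
    fake_score = outputs

    # Combine the losses and adjust only the discriminator's weights.
    d_loss = d_loss_real + d_loss_fake
    reset_grad()
    d_loss.backward()
    d_optimizer.step()

    return d_loss, real_score, fake_score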

Here are the steps involved in training the discriminator.

  • We expect the discriminator to output 1 if the image was picked from the real MNIST dataset, and 0 if it was generated.
  • We first pass a batch of real images, and compute the loss, setting the target labels to 1.
  • Then, we generate a batch of fake images using the generator, pass them into the discriminator, and compute the loss, setting the target labels to 0.
  • Finally we add the two losses and use the overall loss to perform gradient descent to adjust the weights of the discriminator.

It’s important to note that we don’t change the weights of the generator model while training the discriminator (d_optimizer only affects D.parameters()).

Generator Training

Since the outputs of the generator are images, it’s not obvious how we can train the generator. This is where we employ a rather elegant trick, which is to use the discriminator as a part of the loss function. Here’s how it works:

  • We generate a batch of images using the generator, and pass them into the discriminator.
  • We calculate the loss by setting the target labels to 1, i.e. “real”. We do this because the generator’s objective is to “fool” the discriminator.
  • We use the loss to perform gradient descent, i.e. change the weights of the generator, so that it gets better at generating realistic images.

Here’s what this looks like in code.
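A sketch of a train_generator helper, in the same style as train_discriminator above:

def train_generator():
    # Generate fake images and compute the loss against "real" targets,
    # so the generator is rewarded for fooling the discriminator.
    z = torch.randn(batch_size, latent_size).to(device)
    fake_images = G(z)
    labels = torch.ones(batch_size, 1).to(device)
    g_loss = criterion(D(fake_images), labels)

    # Back-propagate and adjust only the generator's weights.
    reset_grad()
    g_loss.backward()
    g_optimizer.step()

    return g_loss, fake_images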

Once again, note that the weights of the discriminator are not affected while training the generator.

Training the Model

Let’s create a directory where we can save intermediate outputs from the generator to visually inspect the progress of the model.

Let’s save a batch of real images that we can use for visual comparison while looking at the generated images.

We’ll also define a helper function to save a batch of generated images to disk at the end of every epoch. We’ll use a fixed set of input vectors to the generator to see how the individual generated images evolve over time as we train the model.
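Here’s a sketch covering these three steps (the directory name, grid layout, and file-name pattern are our choices):

import os
from torchvision.utils import save_image

sample_dir = 'samples'
os.makedirs(sample_dir, exist_ok=True)

# Save one batch of real images for visual comparison.
for images, _ in data_loader:
    save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'), nrow=10)
    break

# A fixed set of latent vectors, so we can watch the same "seeds" evolve over time.
sample_vectors = torch.randn(batch_size, latent_size).to(device)

def save_fake_images(index):
    with torch.no_grad():
        fake_images = G(sample_vectors)
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    fake_fname = 'fake_images-{0:0=4d}.png'.format(index)
    save_image(denorm(fake_images), os.path.join(sample_dir, fake_fname), nrow=10)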

We are now ready to train the model. In each epoch, we train the discriminator first, and then the generator. The training might take a while if you’re not using a GPU.
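Here’s a sketch of the loop (300 epochs matches the samples shown below; reduce it if you’re experimenting without a GPU):

num_epochs = 300
d_losses, g_losses = [], []

for epoch in range(num_epochs):
    for images, _ in data_loader:
        # Flatten the images and move them to the chosen device.
        images = images.reshape(batch_size, -1).to(device)

        # Train the discriminator, then the generator, on each batch.
        d_loss, real_score, fake_score = train_discriminator(images)
        g_loss, fake_images = train_generator()

    # Record the last batch's losses and save sampled images every epoch.
    d_losses.append(d_loss.item())
    g_losses.append(g_loss.item())
    print('Epoch [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}'.format(
        epoch + 1, num_epochs, d_loss.item(), g_loss.item()))
    save_fake_images(epoch + 1)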

We can save the weights of the trained models to avoid retraining from scratch in the future.
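For example (the file names are arbitrary):

torch.save(G.state_dict(), 'G.pth')
torch.save(D.state_dict(), 'D.pth')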

Here’s how the generated images look, after the 10th, 50th, 100th and 300th epochs of training.

We can visualize the training process by combining the sample images generated after each epoch into a video using OpenCV.
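A sketch using cv2.VideoWriter (the output file name and frame rate are our choices; it assumes the file-name pattern used by save_fake_images above):

import cv2

files = sorted(f for f in os.listdir(sample_dir) if f.startswith('fake_images'))
frames = [cv2.imread(os.path.join(sample_dir, f)) for f in files]

height, width, _ = frames[0].shape
video = cv2.VideoWriter('gan_training.avi',
                        cv2.VideoWriter_fourcc(*'MJPG'), 8, (width, height))
for frame in frames:
    video.write(frame)
video.release()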

We can also visualize how the loss changes over time. Visualizing losses is quite useful for debugging the training process. For GANs, we expect the generator’s loss to reduce over time, without the discriminator’s loss getting too high.
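Assuming the d_losses and g_losses lists recorded in the training loop above:

plt.plot(d_losses, label='Discriminator loss')
plt.plot(g_losses, label='Generator loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()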

Save the Experiment and Results

We can capture a snapshot of our work (including the Jupyter notebook, sample images and trained models) using the jovian Python library.

Running jovian.commit uploads the Jupyter notebook and the associated files to your Jovian.ml account. You can share the project online, or collaborate privately with your friends & colleagues. Jovian also automatically captures the Python environment (dependencies & library versions), so that you & others can reproduce your work easily.

Jovian includes a powerful commenting interface for discussion.

To learn more, visit https://www.jovian.ml.

Founder, Jovian