Unraveling DCGAN: Enhancements in Generative Adversarial Networks
In my earlier piece on Generative Adversarial Networks (GANs), we examined how these elegantly simple yet powerful models work, looked at their applications, and built a basic GAN to generate new Pokémon. If you haven't checked it out yet, here's the link:
Generative Adversarial Networks 101
How to build a simple GAN towardsdatascience.com
In this article, we aim to elevate our model by integrating convolutional layers that are more adept at handling images. We will construct a Deep Convolutional GAN (DCGAN) to evaluate its performance and any challenges it may introduce. Let’s dive in!
Deep Convolutional GAN
Transitioning from a basic feed-forward GAN to a convolutional architecture might seem straightforward: simply substitute the dense layers with convolutional ones. However, this task proved to be more complex than anticipated.
Since the introduction of the original GAN by Goodfellow et al. in 2014, researchers have sought to deepen GANs and use convolutions to generate larger, higher-quality images. This proved difficult: training was often unstable, and CNN architectures that excelled in supervised settings faltered in GANs. It took nearly two years before a successful deep convolutional GAN emerged. This model, aptly named DCGAN (Deep Convolutional GAN) and introduced by Radford et al. in 2015, was the result of extensive experimentation with different architectures. Its authors proposed several key guidelines:
- Employ only strided convolutional layers, avoiding pooling and fully-connected layers.
- Implement batch normalization.
- Utilize ReLU activation for all generator layers except the last one, where tanh is preferred.
- Apply LeakyReLU activation in the discriminator for all layers except the final one, which outputs a raw score (logit) and needs no activation.
Let’s explore why these guidelines hold significant importance.
Firstly, using strided convolutions instead of pooling lets the model learn its own downsampling (and, in the generator, upsampling) rather than relying on a fixed operation. Additionally, omitting fully-connected layers on top of convolutional features had become a trend, with global average pooling being favored instead. However, while global average pooling improved stability, it slowed convergence. The DCGAN authors found that connecting the highest convolutional features directly to the generator's input and the discriminator's output worked better.
Secondly, batch normalization helps stabilize the learning process by standardizing the inputs of network units to have a mean of zero and a variance of one. This mitigates issues stemming from poor parameter initialization and ensures that gradients remain effective during backpropagation. In GANs, batch normalization has also been shown to help combat mode collapse, a topic we will delve into shortly. Importantly, it should not be applied to the generator output layer or the discriminator input layer, as doing so would destabilize training.
Finally, regarding activations: in the generator, the DCGAN authors observed that ReLU in the hidden layers combined with a bounded tanh output helps the model learn to saturate and cover the color space of the training data more quickly, while LeakyReLU has proven effective in the discriminator, especially for higher-resolution images.
Following these principles, we will construct a DCGAN to generate new Pokémon!
Deep Convolutional Generator
The PyTorch implementation of the generator based on the aforementioned guidelines may resemble the following structure. We will utilize 4 channels for the output image, as the Pokémon sprite images in our dataset contain 4 color channels.
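A minimal sketch of such a generator is shown below, assuming 32×32 sprites; the latent size `z_dim`, the `hidden_dim` width, and the number of blocks are illustrative choices rather than fixed requirements:

```python
import torch.nn as nn

class Generator(nn.Module):
    """DCGAN-style generator: strided transposed convolutions, batch norm,
    ReLU in hidden layers, and a tanh output over 4 channels (RGBA sprites)."""

    def __init__(self, z_dim=64, hidden_dim=64, im_channels=4):
        super().__init__()
        self.z_dim = z_dim
        self.net = nn.Sequential(
            # Project the noise vector to a 4x4 feature map
            self._block(z_dim, hidden_dim * 4, kernel_size=4, stride=1, padding=0),
            self._block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=2, padding=1),  # 8x8
            self._block(hidden_dim * 2, hidden_dim, kernel_size=4, stride=2, padding=1),      # 16x16
            # Output layer: no batch norm, tanh activation, 4 channels
            nn.ConvTranspose2d(hidden_dim, im_channels, kernel_size=4, stride=2, padding=1),  # 32x32
            nn.Tanh(),
        )

    def _block(self, in_ch, out_ch, kernel_size, stride, padding):
        # Strided transposed convolution + batch norm + ReLU, per the guidelines
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, noise):
        # noise: (batch, z_dim) -> reshape to (batch, z_dim, 1, 1) for the conv stack
        return self.net(noise.view(len(noise), self.z_dim, 1, 1))
```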
Deep Convolutional Discriminator
The discriminator will also accept 4 channels as input. Other hyperparameters, like the count of blocks or hidden units per layer, can be chosen somewhat arbitrarily. Let's examine how this setup will perform with our dataset.
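A matching sketch of the discriminator, again assuming 32×32 inputs and illustrative hidden sizes:

```python
import torch.nn as nn

class Discriminator(nn.Module):
    """DCGAN-style discriminator: strided convolutions, batch norm,
    LeakyReLU in hidden layers, and a single raw logit as output."""

    def __init__(self, im_channels=4, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            # Input layer: no batch norm here, per the DCGAN guidelines
            nn.Conv2d(im_channels, hidden_dim, kernel_size=4, stride=2, padding=1),  # 32 -> 16
            nn.LeakyReLU(0.2, inplace=True),
            self._block(hidden_dim, hidden_dim * 2),      # 16 -> 8
            self._block(hidden_dim * 2, hidden_dim * 4),  # 8 -> 4
            # Final layer: collapse to a single logit, no activation
            nn.Conv2d(hidden_dim * 4, 1, kernel_size=4, stride=1, padding=0),
        )

    def _block(self, in_ch, out_ch):
        # Strided convolution + batch norm + LeakyReLU
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, image):
        # Returns a (batch, 1) tensor of real/fake logits
        return self.net(image).view(len(image), -1)
```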
Training
The training method closely resembles that of a basic feed-forward GAN. The PokemonDataset object is a creation of mine and is available on GitHub.
The key differences lie in our manual initialization of convolutional layer parameters and the specification of momentum parameters (betas) for the Adam optimizer. We also adopt a small learning rate of 0.0002, adhering to the recommendations from the DCGAN authors.
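In code, that setup could look roughly like the sketch below; the `weights_init` helper is an illustrative name, while the normal(0, 0.02) initialization, the learning rate of 0.0002, and beta_1 = 0.5 follow the DCGAN paper's recommendations:

```python
import torch
from torch import nn

def weights_init(module):
    # DCGAN-style initialization: normal(0, 0.02) for conv and batch-norm weights
    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(module.weight, 0.0, 0.02)
    elif isinstance(module, nn.BatchNorm2d):
        nn.init.normal_(module.weight, 0.0, 0.02)
        nn.init.constant_(module.bias, 0.0)

gen = Generator().apply(weights_init)
disc = Discriminator().apply(weights_init)

# Small learning rate and beta_1 = 0.5, as recommended by the DCGAN authors
lr, betas = 2e-4, (0.5, 0.999)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=betas)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=betas)
```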
The loss functions for the generator and discriminator, along with the training loop, remain the same as those used for the feed-forward GAN. Readers interested in this can refer to the previous article. The code can also be executed in this notebook.
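For reference, the binary cross-entropy losses look roughly like this sketch (the function names are illustrative; the full loop with the optimizer steps is in the previous article and the notebook):

```python
import torch
from torch import nn

criterion = nn.BCEWithLogitsLoss()  # binary cross-entropy on raw logits

def disc_loss(gen, disc, real, z_dim, device):
    # Discriminator: label real images 1 and generated images 0
    noise = torch.randn(len(real), z_dim, device=device)
    fake = gen(noise).detach()  # don't backprop into the generator here
    fake_loss = criterion(disc(fake), torch.zeros(len(real), 1, device=device))
    real_loss = criterion(disc(real), torch.ones(len(real), 1, device=device))
    return (fake_loss + real_loss) / 2

def gen_loss(gen, disc, batch_size, z_dim, device):
    # Generator: try to make the discriminator output 1 on fakes
    noise = torch.randn(batch_size, z_dim, device=device)
    return criterion(disc(gen(noise)), torch.ones(batch_size, 1, device=device))
```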
Now, let’s see what our GAN has generated! Below are images of Pokémon created by the GAN as training progressed.
However, all generated images appear identical! What went wrong?
Mode Collapse
This phenomenon is a classic instance of what is referred to as mode collapse. To comprehend this concept, consider the probability space from which the generator samples images. This space is typically multimodal, meaning it has several regions from which an image is likely to be created, alongside areas from which generation is less probable.
Imagine a mountainous terrain. The peaks represent the modes of the generator’s probability distribution. The generator is significantly more inclined to produce an image from one of these peaks rather than from a valley's bottom. Ideally, images produced from different peaks would resemble various types of training images.
Mode collapse occurs when the generator realizes that generating images from a specific peak effectively deceives the discriminator. These images need not be realistic; they merely need to confuse the discriminator enough that it misclassifies them as genuine. Consequently, the generator focuses on producing more of these images, leading to repetitive outputs.
As training continues, the discriminator may eventually learn to differentiate these identical fake images from real training examples. However, the generator will simply shift to another peak that the discriminator struggles with, perpetuating the cycle of generating identical images.
This scenario is what transpired with our Pokémon generation. For many epochs, the generator produced similar-looking shapes before abruptly shifting to a different form. This exemplifies the challenges inherent in training GANs.
Preventing Mode Collapse
Upon recognizing mode collapse, what strategies can one employ? One option is to adjust the model's architecture by altering the number of layers, units per layer, or tweaking hyperparameters like the learning rate or optimizer's momentum. The architecture employed here closely resembles the original DCGAN architecture tested on facial and bedroom datasets, but adapting it to our Pokémon sprites may require substantial effort.
Fortunately, there is an alternative. Many training challenges in GANs stem from the chosen loss function. Stay tuned for the next article, where we will substitute the binary cross-entropy loss used thus far with a metric called Wasserstein distance to construct a Wasserstein GAN (WGAN), aiming to enhance the realism of our generated Pokémon!
Acknowledgments
The training loop code for this DCGAN and the loss-calculating functions are derived from Coursera’s Generative Adversarial Networks (GANs) Specialization, presented by Sharon Zhou et al.
Thank you for reading!
If you enjoyed this article, consider subscribing for email updates on my new posts. By becoming a Medium member, you can support my writing and gain unlimited access to all stories from various authors, including my own.
Want to stay updated on the rapidly evolving field of machine learning and AI? Check out my new newsletter, AI Pulse. If you’re in need of consulting, feel free to reach out or book a 1:1 session here.
You might also like to explore some of my other articles. Unsure which one to choose? Here are a few suggestions:
On the Importance of Bayesian Thinking in Everyday Life
This simple mind-shift will help you better understand the uncertain world around you towardsdatascience.com
Explainable Boosting Machines
Keeping accuracy high while providing explanations that enhance understanding and debugging of data. pub.towardsai.net
8 Hazards Menacing Machine Learning Systems in Production
What to be cautious of when maintaining ML systems towardsdatascience.com