Module 16 of 25 · TensorFlow & Keras · Intermediate

Generative Adversarial Networks (GANs)

Duration: 7 min

This module introduces Generative Adversarial Networks (GANs), a class of machine learning frameworks proposed by Ian Goodfellow and his colleagues in 2014. A GAN learns to generate new data that resembles its training dataset, and GANs have been applied to image synthesis, text generation, and other generative tasks. The module walks through the two-network architecture and a basic Keras training loop for 28×28 grayscale images.

Understanding the GAN Architecture

A GAN consists of two neural networks: the Generator and the Discriminator. The Generator maps a random noise vector to a synthetic data sample. The Discriminator receives a sample and outputs the probability that it came from the real dataset rather than from the Generator. The two networks are trained together in a minimax game: the Generator tries to fool the Discriminator, and the Discriminator tries to classify real and generated samples correctly.
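
In the original paper, this game is written as a single value function V(D, G). Here D(x) is the Discriminator's estimated probability that x is real, G(z) is the Generator's output for a noise vector z, p_data is the data distribution, and p_z is the noise prior (a 100-dimensional standard normal in this module's training code):

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
$$

The Discriminator maximizes V, which is equivalent to minimizing binary cross-entropy with label 1 for real samples and label 0 for generated ones. The Generator would minimize V, but the paper notes that log(1 - D(G(z))) saturates early in training, so in practice the Generator instead maximizes log D(G(z)). Training the Generator with "real" labels on its own samples, as the training code in this module does, implements that variant.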

import tensorflow as tf
from tensorflow.keras import layers

# Define the Generator
def build_generator():
    model = tf.keras.Sequential()
    model.add(layers.Dense(256, input_dim=100))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(512))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(1024))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.BatchNormalization())
    # 784 = 28 * 28 pixels; tanh keeps outputs in [-1, 1], the range real images must be scaled to
    model.add(layers.Dense(784, activation='tanh'))
    model.add(layers.Reshape((28, 28, 1)))
    return model

# Define the Discriminator
def build_discriminator():
    model = tf.keras.Sequential()
    model.add(layers.Flatten(input_shape=(28, 28, 1)))
    model.add(layers.Dense(512))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dense(256))
    model.add(layers.LeakyReLU(alpha=0.2))
    # A single sigmoid unit: the estimated probability that the input image is real
    model.add(layers.Dense(1, activation='sigmoid'))
    return model

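# Print the Generator's architecture (a throwaway instance; the Generator
# used for training is built in the next code block)
build_generator().summary()

# Build and compile the Discriminator, then print its architecture
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
discriminator.summary()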


Model: "sequential"
__________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 256)               25856    
__________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 256)               0        
__________________________________________________________________
batch_normalization (BatchNo (None, 256)               1024     
__________________________________________________________________
dense_1 (Dense)              (None, 512)               131584   
__________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 512)               0        
__________________________________________________________________
batch_normalization_1 (Batch (None, 512)               2048     
__________________________________________________________________
dense_2 (Dense)              (None, 1024)              525312   
__________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 1024)              0        
__________________________________________________________________
batch_normalization_2 (Batch (None, 1024)              4096     
__________________________________________________________________
dense_3 (Dense)              (None, 784)               803600   
__________________________________________________________________
reshape (Reshape)            (None, 28, 28, 1)         0        
=================================================================
Total params: 1,493,520
Trainable params: 1,489,936
Non-trainable params: 3,584
__________________________________________________________________

Model: "sequential_1"
__________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten (Flatten)            (None, 784)               0        
__________________________________________________________________
dense_4 (Dense)              (None, 512)               401920   
__________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 512)               0        
__________________________________________________________________
dense_5 (Dense)              (None, 256)               131328   
__________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 256)               0        
__________________________________________________________________
dense_6 (Dense)              (None, 1)                 257      
=================================================================
Total params: 533,505
Trainable params: 532,993
Non-trainable params: 0
__________________________________________________________________
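
Each Param # entry follows from the layer's shape. A Dense layer with n_in inputs and n_out units has n_in × n_out weights plus n_out biases, and a BatchNormalization layer over n features stores 4n values: trainable gamma and beta, plus a non-trainable moving mean and moving variance. The sketch below recomputes both models' totals by hand; `dense_params` and `batchnorm_params` are hypothetical helper names introduced for this illustration, not Keras APIs.

```python
# Hypothetical helpers (not Keras APIs) that reproduce Keras's parameter counts
# for the two layer types with weights used in this module.

def dense_params(n_in, n_out):
    """A Dense layer has an (n_in x n_out) kernel plus one bias per output unit."""
    return n_in * n_out + n_out


def batchnorm_params(n):
    """BatchNormalization stores 4 values per feature: gamma and beta (trainable),
    plus moving_mean and moving_variance (non-trainable)."""
    return 4 * n


# Generator: 100 -> 256 -> 512 -> 1024 -> 784, with BatchNormalization after the first three Dense layers
gen_widths = [100, 256, 512, 1024, 784]
gen_dense = sum(dense_params(a, b) for a, b in zip(gen_widths, gen_widths[1:]))
gen_bn = sum(batchnorm_params(n) for n in gen_widths[1:4])
gen_total = gen_dense + gen_bn
gen_non_trainable = sum(2 * n for n in gen_widths[1:4])  # moving_mean + moving_variance

# Discriminator: 784 (flattened 28x28x1 image) -> 512 -> 256 -> 1, no normalization layers
disc_widths = [784, 512, 256, 1]
disc_total = sum(dense_params(a, b) for a, b in zip(disc_widths, disc_widths[1:]))

print(f"Generator:     total={gen_total:,}  non-trainable={gen_non_trainable:,}")
# → Generator:     total=1,493,520  non-trainable=3,584
print(f"Discriminator: total={disc_total:,}")
# → Discriminator: total=533,505
```

The only non-trainable parameters are the BatchNormalization moving statistics, which is why the Discriminator, having no normalization layers, reports zero.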

Training GANs

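Training a GAN alternates between two updates. First, the Discriminator is trained on a batch of real images (labeled 1) and a batch of generated images (labeled 0), improving its ability to tell them apart. Then the Generator is trained through a combined model that feeds its output into the Discriminator, whose weights are frozen for this step. The generated images are labeled 1, so the gradients push the Generator toward outputs that the Discriminator classifies as real. There is no built-in stopping criterion: in practice you train for a fixed number of steps and judge progress by inspecting generated samples.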

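import numpy as np

# Load MNIST and scale pixels from [0, 255] to [-1, 1] to match the Generator's tanh output
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = (x_train.astype('float32') - 127.5) / 127.5
x_train = np.expand_dims(x_train, axis=-1)  # shape: (60000, 28, 28, 1)

# Build the Generator
generator = build_generator()

# Build the combined GAN model: noise -> Generator -> Discriminator.
# Freeze the Discriminator before compiling the GAN so that gan.train_on_batch only
# updates the Generator. tf.keras (Keras 2) keeps the trainable state each model had
# when it was compiled, so discriminator.train_on_batch still updates the Discriminator.
discriminator.trainable = False
gan_input = tf.keras.Input(shape=(100,))
gan_output = discriminator(generator(gan_input))
gan = tf.keras.Model(gan_input, gan_output)
gan.compile(loss='binary_crossentropy', optimizer='adam')

# Training function: each step uses one batch, so `steps` counts batches, not epochs
def train_gan(gan, generator, discriminator, x_train, steps=10000, batch_size=128, latent_dim=100):
    real_labels = np.ones((batch_size, 1))
    fake_labels = np.zeros((batch_size, 1))
    for step in range(steps):
        # Train the Discriminator on a random batch of real images (label 1)...
        idx = np.random.randint(0, x_train.shape[0], batch_size)
        real_images = x_train[idx]
        # ...and on a batch of generated images (label 0)
        noise = np.random.normal(0, 1, size=(batch_size, latent_dim))
        fake_images = generator.predict(noise, verbose=0)
        d_loss_real = discriminator.train_on_batch(real_images, real_labels)
        d_loss_fake = discriminator.train_on_batch(fake_images, fake_labels)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)  # [loss, accuracy]

        # Train the Generator through the frozen Discriminator, labeling its fakes as real (1)
        noise = np.random.normal(0, 1, size=(batch_size, latent_dim))
        g_loss = gan.train_on_batch(noise, real_labels)

        # Print the progress
        if step % 1000 == 0:
            print(f"Step {step}, D loss: {d_loss[0]:.4f}, D acc: {d_loss[1]:.2%}, G loss: {g_loss:.4f}")

# Train the GAN
train_gan(gan, generator, discriminator, x_train)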
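
The printed D loss and G loss are binary cross-entropy values. The NumPy sketch below computes them by hand for a few made-up Discriminator outputs (the probabilities are illustrative, not taken from a real run), using the same labeling as train_gan: real = 1 and generated = 0 for the Discriminator step, generated = 1 for the Generator step. It also computes the chance-level value log 2 ≈ 0.693, which both losses take when the Discriminator outputs 0.5 for every input.

```python
import numpy as np


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy on probabilities, clipped to avoid log(0)."""
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return float(np.mean(-(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))))


# Illustrative Discriminator outputs: D(x) for 4 real images, D(G(z)) for 4 generated ones
d_real = np.array([0.9, 0.8, 0.7, 0.95])
d_fake = np.array([0.2, 0.1, 0.3, 0.05])

# Discriminator step: real images labeled 1, generated images labeled 0, losses averaged
d_loss = 0.5 * (binary_crossentropy(np.ones(4), d_real) + binary_crossentropy(np.zeros(4), d_fake))

# Generator step: generated images labeled 1, so the loss is -mean(log D(G(z)))
g_loss = binary_crossentropy(np.ones(4), d_fake)

# Reference point: a Discriminator that outputs 0.5 everywhere scores log(2)
chance = binary_crossentropy(np.ones(4), np.full(4, 0.5))

print(f"D loss: {d_loss:.3f}, G loss: {g_loss:.3f}, chance level: {chance:.3f}")
# → D loss: 0.184, G loss: 2.028, chance level: 0.693
```

A Discriminator loss well below the chance level paired with a large Generator loss, as in this example, means the Discriminator is currently winning the game.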

💡 Tip: When training GANs, monitor both the Generator and Discriminator losses. If the Discriminator loss falls toward zero while the Generator loss keeps climbing, the Discriminator is separating real from generated images almost perfectly, and the Generator receives little useful gradient. The losses do not reliably reveal mode collapse (the Generator producing only a few kinds of samples), so also inspect generated images periodically.
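
GAN losses alone are a weak signal of sample quality, so it helps to look at generated images at regular intervals. The following sketch tiles a batch of Generator outputs into one grayscale image. `make_grid` is a hypothetical helper, and it assumes the (batch, 28, 28, 1) shape and [-1, 1] value range produced by build_generator().

```python
import numpy as np


def make_grid(images, n_cols=8):
    """Tile a batch of Generator outputs into one 2-D uint8 image.

    Assumes `images` has shape (batch, height, width, 1) with values in [-1, 1],
    as produced by the tanh output layer of build_generator().
    """
    batch, h, w, _ = images.shape
    n_rows = int(np.ceil(batch / n_cols))
    # Map [-1, 1] back to [0, 255] pixel intensities
    pixels = ((images[..., 0] + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
    grid = np.zeros((n_rows * h, n_cols * w), dtype=np.uint8)  # unused slots stay black
    for i in range(batch):
        r, c = divmod(i, n_cols)
        grid[r * h:(r + 1) * h, c * w:(c + 1) * w] = pixels[i]
    return grid


# Random values standing in for generator.predict(noise, verbose=0)
fake_batch = np.random.uniform(-1, 1, size=(16, 28, 28, 1))
grid = make_grid(fake_batch, n_cols=4)
print(grid.shape)  # → (112, 112)
```

In a training loop you might call `make_grid(generator.predict(fixed_noise, verbose=0))` every 1,000 steps, where `fixed_noise` is a noise batch drawn once before training, and display the result with matplotlib's `plt.imshow(grid, cmap='gray')`. Reusing the same noise makes it easier to see whether the samples improve or collapse onto a few shapes.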

❓ What are the two main components of a GAN?

❓ What is the primary goal of the Generator in a GAN?
