Generative Adversarial Networks (GANs)
Duration: 7 min
This module covers Generative Adversarial Networks (GANs), a class of machine learning frameworks introduced by Ian Goodfellow and his colleagues in 2014. GANs learn to generate new data that resembles a given training dataset, and they have applications in image synthesis, text generation, and more. They are a foundational topic for anyone studying advanced machine learning techniques and generative models.
Understanding GAN Architecture
GANs consist of two neural networks: the Generator and the Discriminator. The Generator maps random noise vectors to new data instances, while the Discriminator receives a sample and outputs the probability that it came from the real dataset rather than from the Generator. The two networks are trained together in a minimax game: the Generator tries to fool the Discriminator, and the Discriminator tries to correctly classify real and generated data.
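Formally, the original 2014 formulation writes this min-max game with a value function V(D, G). Here D(x) is the Discriminator's estimated probability that x is real, G(z) is the Generator's output for a noise vector z drawn from a prior p_z (the code in this module uses 100-dimensional standard normal noise), and p_data is the data distribution:

```latex
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
```

The Discriminator maximizes V by pushing D(x) toward 1 on real data and D(G(z)) toward 0 on generated data; the Generator minimizes V by pushing D(G(z)) toward 1.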
```python
import tensorflow as tf
from tensorflow.keras import layers

# Define the Generator: maps a 100-dimensional noise vector to a 28x28x1 image
def build_generator():
    model = tf.keras.Sequential()
    model.add(tf.keras.Input(shape=(100,)))
    model.add(layers.Dense(256))
    model.add(layers.LeakyReLU(0.2))  # slope 0.2 for negative inputs
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(512))
    model.add(layers.LeakyReLU(0.2))
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(1024))
    model.add(layers.LeakyReLU(0.2))
    model.add(layers.BatchNormalization())
    # tanh keeps outputs in [-1, 1], so real images must be scaled to the same range
    model.add(layers.Dense(784, activation='tanh'))
    model.add(layers.Reshape((28, 28, 1)))
    return model

# Define the Discriminator: maps a 28x28x1 image to the probability that it is real
def build_discriminator():
    model = tf.keras.Sequential()
    model.add(tf.keras.Input(shape=(28, 28, 1)))
    model.add(layers.Flatten())
    model.add(layers.Dense(512))
    model.add(layers.LeakyReLU(0.2))
    model.add(layers.Dense(256))
    model.add(layers.LeakyReLU(0.2))
    model.add(layers.Dense(1, activation='sigmoid'))
    return model

# Build and compile the Discriminator
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
```

Calling `build_generator().summary()` and `discriminator.summary()` reports the following layers and parameter counts (Keras assigns layer names automatically, so they are omitted here):

```text
Generator
Layer (type)                 Output Shape          Param #
============================================================
Dense                        (None, 256)           25,856
LeakyReLU                    (None, 256)           0
BatchNormalization           (None, 256)           1,024
Dense                        (None, 512)           131,584
LeakyReLU                    (None, 512)           0
BatchNormalization           (None, 512)           2,048
Dense                        (None, 1024)          525,312
LeakyReLU                    (None, 1024)          0
BatchNormalization           (None, 1024)          4,096
Dense                        (None, 784)           803,600
Reshape                      (None, 28, 28, 1)     0
============================================================
Total params: 1,493,520
Trainable params: 1,489,936
Non-trainable params: 3,584

Discriminator
Layer (type)                 Output Shape          Param #
============================================================
Flatten                      (None, 784)           0
Dense                        (None, 512)           401,920
LeakyReLU                    (None, 512)           0
Dense                        (None, 256)           131,328
LeakyReLU                    (None, 256)           0
Dense                        (None, 1)             257
============================================================
Total params: 533,505
Trainable params: 533,505
Non-trainable params: 0
```
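Each layer's parameter count can be derived by hand from two rules: a Dense layer with n inputs and m units has n·m weights plus m biases, and a BatchNormalization layer over C features stores four vectors of length C (gamma and beta, which are trained, plus a moving mean and moving variance, which are not). The sketch below recomputes the totals for both networks; the helper names `dense_params` and `batchnorm_params` are illustrative, not part of Keras.

```python
def dense_params(n_in, n_units):
    """Weights plus biases of a fully connected (Dense) layer."""
    return n_in * n_units + n_units

def batchnorm_params(channels):
    """Return (trainable, non_trainable): gamma/beta vs. moving mean/variance."""
    return 2 * channels, 2 * channels

# Generator: 100 -> 256 -> 512 -> 1024 -> 784, BatchNormalization after each hidden layer
gen_dense = [dense_params(100, 256), dense_params(256, 512),
             dense_params(512, 1024), dense_params(1024, 784)]
gen_bn = [batchnorm_params(c) for c in (256, 512, 1024)]
gen_trainable = sum(gen_dense) + sum(t for t, _ in gen_bn)
gen_non_trainable = sum(nt for _, nt in gen_bn)

# Discriminator: 784 (flattened 28x28x1) -> 512 -> 256 -> 1, no BatchNormalization
disc_total = dense_params(784, 512) + dense_params(512, 256) + dense_params(256, 1)

print(f"Generator: total={gen_trainable + gen_non_trainable:,}, "
      f"trainable={gen_trainable:,}, non-trainable={gen_non_trainable:,}")
print(f"Discriminator: total={disc_total:,}")
```

This prints a Generator total of 1,493,520 parameters (1,489,936 trainable, 3,584 non-trainable) and a Discriminator total of 533,505.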
Training GANs
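Training GANs involves alternating between training the Discriminator and the Generator. In each iteration, the Discriminator is first trained on a batch of real images labeled 1 and a batch of generated images labeled 0, which improves its ability to tell them apart. The Generator is then trained, with the Discriminator's weights held fixed, to produce images that the Discriminator labels as real. The two steps repeat for a chosen number of iterations; GAN training is not guaranteed to converge, so sample quality is usually judged by inspecting generated images as training progresses.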
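Training the Discriminator with binary cross-entropy on real samples labeled 1 and generated samples labeled 0 is how the min-max game is implemented in practice: the average of the two batch losses equals minus one half of the empirical value V = mean(log D(x)) + mean(log(1 - D(G(z)))), so minimizing the loss maximizes V. The NumPy sketch below checks this identity on made-up Discriminator outputs (hypothetical numbers, not results from the model in this module).

```python
import numpy as np

def binary_crossentropy(y_true, y_pred):
    """Mean binary cross-entropy, the loss passed to discriminator.compile."""
    return float(np.mean(-(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))))

# Hypothetical Discriminator outputs: D(x) on real images, D(G(z)) on generated images
d_real = np.array([0.9, 0.8, 0.7, 0.95])
d_fake = np.array([0.1, 0.3, 0.2, 0.05])

# Discriminator loss as in the training loop: average of the real and fake batch losses
d_loss = 0.5 * (binary_crossentropy(np.ones_like(d_real), d_real)
                + binary_crossentropy(np.zeros_like(d_fake), d_fake))

# Empirical value function V(D, G) of the min-max game
value = np.mean(np.log(d_real)) + np.mean(np.log(1 - d_fake))

print(d_loss, -0.5 * value)  # equal up to floating-point rounding
```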
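```python
import numpy as np

# Load MNIST and scale pixels from [0, 255] to [-1, 1] to match the Generator's tanh output
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = (x_train.astype('float32') - 127.5) / 127.5
x_train = np.expand_dims(x_train, axis=-1)  # shape (60000, 28, 28, 1)

# Build the Generator
generator = build_generator()

# Freeze the Discriminator inside the combined model so the Generator step updates
# only the Generator. In tf.keras the trainable flag is captured at compile time,
# so the already-compiled discriminator still learns in its own train_on_batch calls.
discriminator.trainable = False

# Build the GAN model: noise -> Generator -> Discriminator
gan_input = tf.keras.Input(shape=(100,))
gan_output = discriminator(generator(gan_input))
gan = tf.keras.Model(gan_input, gan_output)
gan.compile(loss='binary_crossentropy', optimizer='adam')

# Training function: each iteration processes one batch, not one full pass over the data
def train_gan(gan, generator, discriminator, x_train, iterations=10000, batch_size=128):
    real_labels = np.ones((batch_size, 1))
    fake_labels = np.zeros((batch_size, 1))
    for iteration in range(iterations):
        # Train Discriminator on real images (label 1) and generated images (label 0)
        idx = np.random.randint(0, x_train.shape[0], batch_size)
        real_images = x_train[idx]
        noise = np.random.normal(0, 1, size=(batch_size, 100))
        fake_images = generator.predict(noise, verbose=0)
        d_loss_real = discriminator.train_on_batch(real_images, real_labels)
        d_loss_fake = discriminator.train_on_batch(fake_images, fake_labels)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)  # [loss, accuracy]

        # Train Generator: label generated images as real so it learns to fool the Discriminator
        noise = np.random.normal(0, 1, size=(batch_size, 100))
        g_loss = gan.train_on_batch(noise, real_labels)

        # Print the progress
        if iteration % 1000 == 0:
            print(f"Iteration {iteration}, D loss: {d_loss[0]}, G loss: {g_loss}")

# Train the GAN
train_gan(gan, generator, discriminator, x_train)
```

💡 Tip: When training GANs, monitor both the Generator and Discriminator losses. A Discriminator loss that falls toward zero, usually together with a rising Generator loss, means the Discriminator is separating real from generated samples almost perfectly, so the Generator receives little useful gradient. Loss values alone do not show whether samples are realistic or diverse, so also inspect generated images periodically; a Generator that keeps producing nearly identical outputs is a sign of mode collapse.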
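The Generator step labels every generated image as real, so binary cross-entropy makes the Generator minimize -log D(G(z)) instead of the min-max term log(1 - D(G(z))). The 2014 GAN paper recommends this swap because, when the Discriminator confidently rejects generated samples, D(G(z)) is close to 0 and log(1 - D(G(z))) saturates. The NumPy sketch below compares the gradient of each loss with respect to the Discriminator's pre-sigmoid logit a, where D(G(z)) = sigmoid(a); the function name `generator_loss_grads` is illustrative.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def generator_loss_grads(logit):
    """Return D(G(z)) and d(loss)/d(logit) for both Generator losses (both minimized)."""
    d = sigmoid(logit)               # D(G(z)): probability the Discriminator assigns to "real"
    grad_minimax = -d                # derivative of log(1 - sigmoid(a))
    grad_non_saturating = d - 1.0    # derivative of -log(sigmoid(a))
    return d, grad_minimax, grad_non_saturating

# A very negative logit means the Discriminator confidently rejects the generated sample
for logit in (-6.0, -2.0, 0.0):
    d, g_mm, g_ns = generator_loss_grads(logit)
    print(f"D(G(z))={d:.4f}  minimax grad={g_mm:+.4f}  non-saturating grad={g_ns:+.4f}")
```

At a logit of -6 the minimax gradient is about -0.0025 while the non-saturating gradient is about -0.9975, which is why labeling generated images as real keeps the Generator learning even when the Discriminator is winning.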
❓ What are the two main components of a GAN?
❓ What is the primary goal of the Generator in a GAN?