Module 13 of 26 · Deep Learning with PyTorch · Intermediate

Generative Adversarial Networks

Duration: 8 min

This module covers Generative Adversarial Networks (GANs), a class of neural networks that learn to generate new data instances resembling those in a training dataset. GANs are widely used for image generation and data augmentation, among other tasks. Understanding them lets you build models that generate realistic images, text, and other data types.

Understanding the Basics of GANs

A Generative Adversarial Network consists of two neural networks: the Generator and the Discriminator. The Generator maps random noise to new data instances, while the Discriminator estimates the probability that a given instance is real rather than generated. The two networks are trained together in a competitive setup, so each improves in response to the other over time.

import torch
from torch import nn
from torch.optim import Adam

# Define the Generator
class Generator(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(True),
            nn.Linear(128, 256),
            nn.ReLU(True),
            nn.Linear(256, output_dim),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)

# Define the Discriminator
class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x)

# Hyperparameters
input_dim = 100   # size of the noise vector fed to the generator
output_dim = 784  # 28 x 28 MNIST images, flattened
lr = 0.0002
batch_size = 64

# Initialize the networks
generator = Generator(input_dim, output_dim)
discriminator = Discriminator(output_dim)

# Optimizers
optimizer_G = Adam(generator.parameters(), lr=lr)
optimizer_D = Adam(discriminator.parameters(), lr=lr)

# Loss function
criterion = nn.BCELoss()


The snippet above produces no output. It only sets up the networks, optimizers, and loss function for the GAN.
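
A quick way to sanity-check a setup like this is to push a batch of random noise through both networks and inspect the shapes. The sketch below is an illustration rather than part of the lesson code: it rebuilds the same layer stacks as `nn.Sequential` stand-ins so that it runs on its own, and the names `latent_dim` and `image_dim` and the batch size of 16 are assumptions chosen for the example.

```python
import torch
from torch import nn

# Stand-ins with the same layer sizes as the Generator and Discriminator
# classes above, written as nn.Sequential so this snippet is self-contained.
latent_dim, image_dim = 100, 784  # assumed names for input_dim / output_dim

generator = nn.Sequential(
    nn.Linear(latent_dim, 128), nn.ReLU(True),
    nn.Linear(128, 256), nn.ReLU(True),
    nn.Linear(256, image_dim), nn.Tanh(),
)
discriminator = nn.Sequential(
    nn.Linear(image_dim, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 128), nn.LeakyReLU(0.2),
    nn.Linear(128, 1), nn.Sigmoid(),
)

# No gradients are needed for a shape check
with torch.no_grad():
    noise = torch.randn(16, latent_dim)   # 16 random noise vectors
    fake_images = generator(noise)        # flattened images in [-1, 1] (Tanh)
    scores = discriminator(fake_images)   # probabilities in (0, 1) (Sigmoid)

print(fake_images.shape)  # torch.Size([16, 784])
print(scores.shape)       # torch.Size([16, 1])
```

The generator's output width has to match the discriminator's input width (784 here), and the discriminator emits one probability per sample, which is the shape `nn.BCELoss` expects alongside a target of the same shape.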

Training the GAN

Training a GAN alternates between two steps: updating the discriminator and updating the generator. The discriminator is trained to classify real data as real and generated data as fake, while the generator is trained so that the discriminator labels its outputs as real. As the discriminator gets better at spotting fakes, the generator has to produce more convincing samples to fool it.
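
With `nn.BCELoss`, the target label decides which side of this contest a loss term serves. The sketch below uses made-up discriminator outputs (`d_out` is a hypothetical tensor, not produced by a real model) to show that a target of 1 gives the mean of -log(D), while a target of 0 gives the mean of -log(1 - D).

```python
import torch
from torch import nn

criterion = nn.BCELoss()

# Hypothetical discriminator outputs (probabilities that each sample is real)
d_out = torch.tensor([[0.9], [0.6], [0.2]])

# Target 1: the loss is small when the discriminator outputs values near 1
loss_target_real = criterion(d_out, torch.ones_like(d_out))
# Target 0: the loss is small when the discriminator outputs values near 0
loss_target_fake = criterion(d_out, torch.zeros_like(d_out))

# The same quantities written out by hand
manual_real = -torch.log(d_out).mean()
manual_fake = -torch.log(1 - d_out).mean()

print(torch.allclose(loss_target_real, manual_real))  # True
print(torch.allclose(loss_target_fake, manual_fake))  # True
```

The discriminator step sums both terms (target 1 on real data, target 0 on generated data). The generator step reuses the target-1 term on generated data, so the generator's loss falls as the discriminator is fooled.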

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

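# Load the dataset. Normalize([0.5], [0.5]) maps pixels from [0, 1] to [-1, 1],
# which matches the range of the generator's Tanh output.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])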
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training loop
num_epochs = 50
for epoch in range(num_epochs):
    for real_data, _ in dataloader:
        # The last batch can be smaller than batch_size (60000 is not a
        # multiple of 64), so read the size from the data itself.
        current_batch_size = real_data.size(0)
        real_labels = torch.ones(current_batch_size, 1)
        fake_labels = torch.zeros(current_batch_size, 1)

        # Train the discriminator
        optimizer_D.zero_grad()
        output_D_real = discriminator(real_data.view(current_batch_size, -1))
        loss_D_real = criterion(output_D_real, real_labels)

        noise = torch.randn(current_batch_size, input_dim)
        fake_data = generator(noise)
        # detach() keeps this step from backpropagating into the generator
        output_D_fake = discriminator(fake_data.detach())
        loss_D_fake = criterion(output_D_fake, fake_labels)
        loss_D = loss_D_real + loss_D_fake
        loss_D.backward()
        optimizer_D.step()

        # Train the generator: it is rewarded when the discriminator
        # labels its samples as real
        optimizer_G.zero_grad()
        noise = torch.randn(current_batch_size, input_dim)
        fake_data = generator(noise)
        output_D_fake = discriminator(fake_data)
        loss_G = criterion(output_D_fake, real_labels)
        loss_G.backward()
        optimizer_G.step()

    print(f'Epoch [{epoch+1}/{num_epochs}] Loss D: {loss_D.item():.4f}, Loss G: {loss_G.item():.4f}')
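
Once training has run, new digits come from feeding fresh noise to the generator. The sketch below shows one possible way to do that. To stay runnable on its own it builds an untrained stand-in with the same layer sizes, so its output is noise rather than digits; in practice you would use the trained `generator` from the loop above. The name `latent_dim` and the sample count of 16 are assumptions for the example.

```python
import torch
from torch import nn

latent_dim = 100  # assumed name; same value as input_dim in the lesson

# Untrained stand-in so the snippet runs by itself.
# Replace with the trained `generator` to get digit-like samples.
generator = nn.Sequential(
    nn.Linear(latent_dim, 128), nn.ReLU(True),
    nn.Linear(128, 256), nn.ReLU(True),
    nn.Linear(256, 784), nn.Tanh(),
)

generator.eval()
with torch.no_grad():  # sampling needs no gradients
    noise = torch.randn(16, latent_dim)
    samples = generator(noise)  # shape (16, 784), values in [-1, 1]

# Reshape flat vectors into image tensors and undo Normalize([0.5], [0.5])
images = samples.view(-1, 1, 28, 28)
images = (images + 1) / 2  # map [-1, 1] back to [0, 1]

print(images.shape)  # torch.Size([16, 1, 28, 28])
```

The resulting tensor has the (N, C, H, W) layout that image utilities such as `torchvision.utils.save_image` accept, which makes it easy to inspect samples every few epochs.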

💡 Tip: GAN training is sensitive to batch size and learning rate. Poorly chosen values can cause unstable training or mode collapse, where the generator produces only a narrow range of outputs.
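
One commonly cited starting point comes from the DCGAN paper, which pairs the learning rate of 0.0002 used in this module with an Adam `beta1` of 0.5 instead of the default 0.9. The sketch below shows how that would be configured. Whether it helps for this small fully connected GAN is an assumption to check experimentally, and the `nn.Linear` models are placeholders standing in for the real networks.

```python
from torch import nn
from torch.optim import Adam

# Placeholder models standing in for the Generator and Discriminator
generator = nn.Linear(100, 784)
discriminator = nn.Linear(784, 1)

lr = 0.0002
betas = (0.5, 0.999)  # beta1 lowered from Adam's default of 0.9

optimizer_G = Adam(generator.parameters(), lr=lr, betas=betas)
optimizer_D = Adam(discriminator.parameters(), lr=lr, betas=betas)

print(optimizer_G.defaults["betas"])  # (0.5, 0.999)
```

Changing one setting at a time and watching both losses makes it easier to tell which change affected stability.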

❓ What is the primary role of the discriminator in a GAN?

❓ Which loss function is commonly used for training the discriminator in a GAN?

