Module 4 of 7 · PyTorch on Apple Silicon · Intermediate

Training CNNs on Apple Silicon

Duration: 25 min

Building a CNN for CIFAR-10

Let's build and train a convolutional neural network (CNN) on CIFAR-10 using MPS (Apple's Metal Performance Shaders backend), falling back to the CPU when MPS is unavailable:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Set device: prefer Apple's Metal Performance Shaders (MPS) backend and fall
# back to the CPU when it is not available
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Data preprocessing: convert images to tensors, then normalize each RGB
# channel with mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Load the CIFAR-10 training split (download=True fetches it into ./data if missing)
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
# 128 images per batch, reshuffled every epoch
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# Define CNN
class SimpleCNN(nn.Module):
    """Small CIFAR-10 classifier: two conv blocks followed by two linear layers.

    Each conv block is a size-preserving 3x3 convolution (padding=1), a ReLU
    and a 2x2 max-pool that halves height and width. ``fc1`` expects
    64 * 8 * 8 inputs, so the network is sized for 3-channel 32x32 images.
    ``forward`` returns unnormalized scores (no softmax) of shape ``(batch, 10)``.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # Classifier head
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        features = x
        for conv in (self.conv1, self.conv2):
            features = self.pool(self.relu(conv(features)))
        # Collapse (C, H, W) into one feature vector per sample
        flat = features.view(features.size(0), -1)
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)

# Initialize model, loss function and optimizer
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop: every 100 batches, report the mean batch loss so far this epoch
num_epochs = 5
for epoch in range(num_epochs):
    running_loss = 0
    for step, (images, labels) in enumerate(train_loader, start=1):
        images = images.to(device)
        labels = labels.to(device)

        # Clear old gradients, run the forward pass, backpropagate, update
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if step % 100 == 0:
            avg_loss = running_loss / step
            print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{step}], Loss: {avg_loss:.4f}")

    print(f"Epoch {epoch+1} completed")

Measuring Training Speed

Compare how long one training epoch takes on the CPU versus MPS:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time

def train_epoch(model, device, train_loader, criterion, optimizer):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: network to train; switched to training mode first.
        device: where each batch is moved before the forward pass.
        train_loader: iterable of ``(images, labels)`` batches that also
            supports ``len()``.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters, stepped once per batch.

    Returns:
        The summed batch losses divided by ``len(train_loader)``, i.e. the
        mean per-batch loss when the loader's length equals its batch count.
    """
    model.train()
    running = 0
    for batch_images, batch_labels in train_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        # Clear stale gradients, then forward, backward and update
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_images), batch_labels)
        batch_loss.backward()
        optimizer.step()

        running += batch_loss.item()

    return running / len(train_loader)

# Setup: tensor conversion plus per-channel (x - 0.5) / 0.5 normalization
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 training split, batched and reshuffled each epoch
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# Model
class SimpleCNN(nn.Module):
    """Small CIFAR-10 classifier using functional ReLU activations.

    Two blocks of size-preserving 3x3 convolution (padding=1), ReLU and 2x2
    max-pooling reduce a 3x32x32 input to 64x8x8, which ``fc1`` expects.
    ``forward`` returns unnormalized scores (no softmax) of shape ``(batch, 10)``.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # Classifier head
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        features = x
        for conv in (self.conv1, self.conv2):
            features = self.pool(torch.relu(conv(features)))
        batch_size = features.size(0)
        hidden = torch.relu(self.fc1(features.view(batch_size, -1)))
        return self.fc2(hidden)

# Benchmark on CPU
print("Training on CPU...")
model_cpu = SimpleCNN().to("cpu")
criterion = nn.CrossEntropyLoss()
optimizer_cpu = optim.Adam(model_cpu.parameters(), lr=0.001)

# perf_counter() is a monotonic, high-resolution clock intended for measuring
# elapsed time; time.time() is wall-clock time and can jump if the clock is adjusted.
start = time.perf_counter()
train_epoch(model_cpu, "cpu", train_loader, criterion, optimizer_cpu)
cpu_time = time.perf_counter() - start
print(f"CPU time: {cpu_time:.2f} seconds")

# Benchmark on MPS
if torch.backends.mps.is_available():
    print("Training on MPS...")
    model_mps = SimpleCNN().to("mps")
    optimizer_mps = optim.Adam(model_mps.parameters(), lr=0.001)

    # NOTE: this single timed epoch likely also includes one-time MPS warm-up
    # cost; a stricter benchmark would run a few untimed batches first.
    start = time.perf_counter()
    train_epoch(model_mps, "mps", train_loader, criterion, optimizer_mps)
    # GPU work is queued asynchronously; wait for every queued MPS kernel to
    # finish before stopping the clock so the whole epoch is measured.
    torch.mps.synchronize()
    mps_time = time.perf_counter() - start
    print(f"MPS time: {mps_time:.2f} seconds")

    speedup = cpu_time / mps_time
    print(f"Speedup: {speedup:.1f}x")

Example output (exact timings vary with the chip, PyTorch version, and system load):

Training on CPU...
CPU time: 45.32 seconds
Training on MPS...
MPS time: 8.15 seconds
Speedup: 5.6x

Validation Loop

Evaluate model performance by measuring classification accuracy on held-out data:

def validate(model, device, val_loader):
    """Return the top-1 accuracy of ``model`` on ``val_loader`` as a percentage.

    The model is switched to eval mode (and left there), and gradients are
    disabled while evaluating.

    Args:
        model: classifier whose output has one score per class along dim 1.
        device: where each batch is moved before the forward pass.
        val_loader: iterable of ``(images, labels)`` batches.

    Returns:
        ``100 * correct / total`` as a float in ``[0, 100]``.

    Raises:
        ValueError: if ``val_loader`` yields no samples, since accuracy is
            undefined in that case.
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # Index of the highest score per sample is the predicted class
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("val_loader yielded no samples; cannot compute accuracy")

    accuracy = 100 * correct / total
    return accuracy

# Use in training loop
# Validate on CIFAR-10's held-out test split (train=False), preprocessed with
# the same `transform` as the training data. This snippet assumes `datasets`,
# `transform`, `train_loader`, `model`, `device`, `criterion`, `optimizer`,
# `num_epochs` and `train_epoch` from the earlier snippets are in scope.
val_dataset = datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)

for epoch in range(num_epochs):
    train_loss = train_epoch(model, device, train_loader, criterion, optimizer)
    val_accuracy = validate(model, device, val_loader)
    print(f"Epoch {epoch+1}: Loss={train_loss:.4f}, Accuracy={val_accuracy:.2f}%")

Saving and Loading Models

Save trained models for later use:

import torch

# Save model: only the learned parameters (state_dict), not the Python class
torch.save(model.state_dict(), 'cifar10_cnn.pth')
print("Model saved")

# Load model
# map_location remaps tensors that were saved from an MPS model, so the
# checkpoint also loads on machines where MPS is unavailable.
# weights_only=True limits unpickling to tensors and plain containers, which
# is safer for checkpoint files of unknown origin.
# Call model.eval() before running inference with the loaded model.
model = SimpleCNN()
model.load_state_dict(
    torch.load('cifar10_cnn.pth', map_location=device, weights_only=True)
)
model.to(device)
print("Model loaded")
← Previous Continue interactively → Next →