Training CNNs on Apple Silicon
Duration: 25 min
Building a CNN for CIFAR-10
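Let's build and train a convolutional neural network (CNN) on CIFAR-10 using PyTorch's MPS (Metal Performance Shaders) backend, which runs the computation on the Apple Silicon GPU: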
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Set device: prefer the MPS backend when this PyTorch build and machine
# support it, otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Data preprocessing: convert images to tensors, then normalise every channel
# with mean 0.5 and std 0.5, i.e. x -> (x - 0.5) / 0.5.
channel_stats = (0.5, 0.5, 0.5)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_stats, channel_stats),
])

# Load the CIFAR-10 training split into ./data (download=True).
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
# Define CNN
class SimpleCNN(nn.Module):
    """Two conv -> ReLU -> 2x2 max-pool stages followed by a two-layer head.

    Takes 3-channel images and returns 10 scores per sample. The 3x3
    convolutions use padding=1, so only the two pools shrink the spatial size
    (each halves it). fc1 expects 64 * 8 * 8 features, which means the input
    images must be 32x32.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # Classifier head.
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = self.pool(self.relu(conv(x)))
        flat = x.view(x.shape[0], -1)  # (batch, 64*8*8)
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)
# Initialize model
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Training loop
num_epochs = 5
for epoch in range(num_epochs):
total_loss = 0
for batch_idx, (images, labels) in enumerate(train_loader):
images, labels = images.to(device), labels.to(device)
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels)
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
if (batch_idx + 1) % 100 == 0:
avg_loss = total_loss / (batch_idx + 1)
print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}], Loss: {avg_loss:.4f}")
print(f"Epoch {epoch+1} completed")Measuring Training Speed
Compare the time one training epoch takes on the CPU versus the MPS backend:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time
def train_epoch(model, device, train_loader, criterion, optimizer):
    """Run one optimisation pass over ``train_loader`` and return the mean batch loss.

    Puts ``model`` into training mode. For each ``(images, labels)`` batch, it
    moves both tensors to ``device``, evaluates
    ``criterion(model(images), labels)``, backpropagates and steps
    ``optimizer``.

    Returns:
        The sum of the per-batch loss values divided by
        ``len(train_loader)``. An empty loader raises ZeroDivisionError.
    """
    model.train()
    batch_losses = []
    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    return sum(batch_losses) / len(train_loader)
# Setup: per-channel normalisation x -> (x - 0.5) / 0.5 after tensor conversion
normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
transform = transforms.Compose([transforms.ToTensor(), normalize])
# CIFAR-10 training split, stored under ./data (download=True)
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
)
# Model
class SimpleCNN(nn.Module):
    """Small CNN: two conv/ReLU/max-pool stages, then two fully connected layers.

    Takes 3-channel images and returns 10 scores per sample. fc1 expects
    64 * 8 * 8 features (two 2x2 pools after size-preserving 3x3 convs), so
    inputs must be 32x32. ReLU is applied functionally, so this variant
    registers no activation submodule; its state_dict keys match the
    module-based variant.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        relu = nn.functional.relu
        features = self.pool(relu(self.conv1(x)))
        features = self.pool(relu(self.conv2(features)))
        flat = features.view(features.size(0), -1)
        return self.fc2(relu(self.fc1(flat)))
# Benchmark on CPU
print("Training on CPU...")
model_cpu = SimpleCNN().to("cpu")
criterion = nn.CrossEntropyLoss()
optimizer_cpu = optim.Adam(model_cpu.parameters(), lr=0.001)
start = time.time()
train_epoch(model_cpu, "cpu", train_loader, criterion, optimizer_cpu)
cpu_time = time.time() - start
print(f"CPU time: {cpu_time:.2f} seconds")
# Benchmark on MPS
if torch.backends.mps.is_available():
print("Training on MPS...")
model_mps = SimpleCNN().to("mps")
optimizer_mps = optim.Adam(model_mps.parameters(), lr=0.001)
start = time.time()
train_epoch(model_mps, "mps", train_loader, criterion, optimizer_mps)
mps_time = time.time() - start
print(f"MPS time: {mps_time:.2f} seconds")
speedup = cpu_time / mps_time
print(f"Speedup: {speedup:.1f}x")Expected Output:
Training on CPU...
CPU time: 45.32 seconds
Training on MPS...
MPS time: 8.15 seconds
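Speedup: 5.6x
These timings are illustrative; actual numbers depend on your chip, PyTorch version, and data-loading overhead.
Validation Loop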
Evaluate the model's accuracy on held-out data:
def validate(model, device, val_loader):
    """Return the top-1 accuracy of ``model`` on ``val_loader``, in percent.

    Switches ``model`` to evaluation mode and leaves it there. Runs without
    autograd and counts how often the arg-max over dim 1 of the model's
    outputs equals the label.

    Args:
        model: Module whose output has class scores along dim 1 (assumed
            shape (batch, num_classes) -- confirm for other models).
        device: Device the inputs and labels are moved to.
        val_loader: Iterable of ``(images, labels)`` batches.

    Returns:
        ``100 * correct / total`` as a float.

    Raises:
        ValueError: If ``val_loader`` yields no samples, since accuracy is
            undefined in that case.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # Autograd is already off, so no legacy `.data` access is needed.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("val_loader yielded no samples; accuracy is undefined")
    return 100 * correct / total
# Use in training loop
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)
for epoch in range(num_epochs):
train_loss = train_epoch(model, device, train_loader, criterion, optimizer)
val_accuracy = validate(model, device, val_loader)
print(f"Epoch {epoch+1}: Loss={train_loss:.4f}, Accuracy={val_accuracy:.2f}%")Saving and Loading Models
Save trained models for later use:
import torch
# Save model: only the state_dict is written, so loading it requires building
# a SimpleCNN first.
torch.save(model.state_dict(), 'cifar10_cnn.pth')
print("Model saved")
# Load model
model = SimpleCNN()
# map_location="cpu": the checkpoint holds tensors on whatever device the model
# was on when saved (possibly MPS). Remapping to CPU lets it load on machines
# without MPS; the model is moved to `device` right after.
# weights_only=True: a state_dict needs only tensors and plain containers, so
# arbitrary pickled objects in the file are refused.
state_dict = torch.load('cifar10_cnn.pth', map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
print("Model loaded")