Module 5 of 7 · PyTorch on Apple Silicon · Intermediate

Memory Optimization & Benchmarking

Duration: 20 min

Understanding Memory on Apple Silicon

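Apple Silicon uses unified memory shared by the CPU and GPU. Unlike discrete GPUs with dedicated VRAM, the MPS backend allocates from system RAM. This means that GPU tensors compete with macOS and your other applications for the same pool of memory. It also means the memory available to your model is bounded by your Mac's total RAM rather than by a separate VRAM budget, so it is worth monitoring both.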

Monitor memory usage:

import torch
import psutil

# The torch.mps memory queries below need a working MPS backend (Apple
# Silicon plus a PyTorch build with Metal support). Stop with a clear message
# instead of failing later with a less obvious error.
if not torch.backends.mps.is_available():
    raise SystemExit("MPS backend is not available on this machine.")

device = torch.device("mps")


def to_gb(num_bytes):
    """Convert a byte count to gigabytes (10**9 bytes)."""
    return num_bytes / 1e9


# Resident set size (RSS) of this Python process -- not system-wide usage.
process = psutil.Process()
print(f"Process memory (RSS): {to_gb(process.memory_info().rss):.2f} GB")

# current_allocated_memory(): bytes occupied by live MPS tensors.
# driver_allocated_memory(): total bytes the Metal driver has allocated for
# this process, including blocks PyTorch keeps cached for reuse.
print(f"MPS allocated: {to_gb(torch.mps.current_allocated_memory()):.2f} GB")
print(f"MPS reserved: {to_gb(torch.mps.driver_allocated_memory()):.2f} GB")

# 5000 x 5000 values at the default float32 = 100 MB, so expect about +0.10 GB.
large_tensor = torch.randn(5000, 5000, device=device)
print(f"After allocation: {to_gb(torch.mps.current_allocated_memory()):.2f} GB")

# `del` drops the tensor (lowering "allocated"); empty_cache() then hands
# PyTorch's cached blocks back to the system (lowering "reserved").
del large_tensor
torch.mps.empty_cache()
print(f"After cleanup: {to_gb(torch.mps.current_allocated_memory()):.2f} GB")
print(f"Reserved after cleanup: {to_gb(torch.mps.driver_allocated_memory()):.2f} GB")

Batch Size Optimization

Batch size directly impacts memory usage. Find the optimal batch size:

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time

device = torch.device("mps")

# CIFAR-10 pipeline: ToTensor() yields floats in [0, 1]; normalizing every RGB
# channel with mean 0.5 and std 0.5 maps that range onto [-1, 1].
channel_mean = (0.5, 0.5, 0.5)
channel_std = (0.5, 0.5, 0.5)
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]
)

# Training split, downloaded into ./data on first use.
train_dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)

# Model
class SimpleCNN(nn.Module):
    """Small CIFAR-10 classifier: two conv+ReLU+max-pool stages, then two linear layers.

    Input must be (N, 3, 32, 32): the two 2x2 poolings shrink 32x32 to 8x8,
    which is what ``fc1``'s ``64 * 8 * 8`` input size assumes. Returns raw
    logits of shape (N, 10).
    """

    def __init__(self):
        super().__init__()
        # Keep this creation order: it determines the initial weights drawn
        # under a fixed random seed.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)   # 3 -> 32 channels, H/W kept
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # 32 -> 64 channels, H/W kept
        self.pool = nn.MaxPool2d(2, 2)                            # halves H and W
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Each conv stage: convolution -> ReLU -> 2x2 max-pool.
        for conv in (self.conv1, self.conv2):
            x = self.pool(torch.relu(conv(x)))
        # Flatten everything except the batch dimension.
        features = x.view(x.shape[0], -1)
        hidden = torch.relu(self.fc1(features))
        return self.fc2(hidden)

model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()

# Test different batch sizes
batch_sizes = [32, 64, 128, 256, 512]


def _is_oom(err):
    """Return True if *err* is an out-of-memory error (MPS: "MPS backend out of memory")."""
    return "out of memory" in str(err).lower()


for batch_size in batch_sizes:
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    images, labels = next(iter(train_loader))
    outputs = loss = None
    try:
        images, labels = images.to(device), labels.to(device)

        # Untimed warm-up pass: the first run at a new input shape can include
        # one-time setup cost. Memory is sampled here, right after the forward
        # pass, while the activations saved for backward are still alive --
        # after backward() they are released and the batch-size effect vanishes.
        model.zero_grad(set_to_none=True)
        outputs = model(images)
        loss = criterion(outputs, labels)
        mem_used = torch.mps.current_allocated_memory() / 1e9
        loss.backward()
        torch.mps.synchronize()

        # Timed pass. Clear gradients so they don't accumulate across runs.
        model.zero_grad(set_to_none=True)
        start = time.perf_counter()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        # MPS executes asynchronously: wait for the GPU before reading the
        # clock, otherwise only the time to queue the work is measured.
        torch.mps.synchronize()
        elapsed = time.perf_counter() - start

        throughput = batch_size / elapsed
        print(f"Batch {batch_size}: {elapsed:.3f}s, {mem_used:.2f}GB, {throughput:.0f} samples/s")
    except RuntimeError as e:
        if not _is_oom(e):
            raise  # a real bug (e.g. a shape mismatch) must not be reported as OOM
        print(f"Batch {batch_size}: Out of memory")
        break
    finally:
        # Release this batch's tensors and gradients before trying a larger size.
        images = labels = outputs = loss = None
        model.zero_grad(set_to_none=True)
        torch.mps.empty_cache()

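Example Output (illustrative only — timings and memory figures depend on your chip, RAM, and PyTorch version. A model this small is unlikely to run out of memory at these batch sizes; the last line only shows what the script prints when a batch does not fit):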

Batch 32: 0.045s, 0.82GB, 711 samples/s
Batch 64: 0.052s, 1.24GB, 1231 samples/s
Batch 128: 0.068s, 2.15GB, 1882 samples/s
Batch 256: 0.142s, 4.32GB, 1803 samples/s
Batch 512: Out of memory

Mixed Precision Training

Use float16 to reduce memory and increase speed:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader

device = torch.device("mps")

# SimpleCNN and train_dataset are defined in the batch-size section above.
num_epochs = 5  # example value
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# torch.cuda.amp is CUDA-only (its autocast has no device_type argument); the
# device-generic torch.amp API takes the device type explicitly.
# NOTE(review): float16 autocast and GradScaler on MPS need a recent PyTorch
# 2.x release -- confirm against your installed version.
scaler = GradScaler("mps")

# Training with mixed precision
for epoch in range(num_epochs):
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass: eligible ops run in float16 inside the autocast region.
        with autocast(device_type="mps", dtype=torch.float16):
            outputs = model(images)
            loss = criterion(outputs, labels)

        # Backward pass. The loss is scaled up so small float16 gradients do
        # not underflow; scaler.step() unscales them and skips the update if
        # any gradient is inf/NaN, and scaler.update() adjusts the scale.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

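Benefits: float16 tensors take half the memory of float32, so activations shrink and larger batches fit. Float16 arithmetic is also often faster on the GPU. The gradient scaler protects against small float16 gradients underflowing to zero.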

Gradient Accumulation

Train with larger effective batch sizes without exceeding memory:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

device = torch.device("mps")

accumulation_steps = 4
actual_batch_size = 128
effective_batch_size = actual_batch_size * accumulation_steps
num_epochs = 5  # example value

# SimpleCNN and train_dataset are defined in the batch-size section above.
# The loader is built from actual_batch_size so the reported effective batch
# size is actually what the optimizer sees.
train_loader = DataLoader(train_dataset, batch_size=actual_batch_size, shuffle=True)

model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

print(f"Effective batch size: {effective_batch_size}")

num_batches = len(train_loader)
for epoch in range(num_epochs):
    optimizer.zero_grad()
    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        # Forward pass. Dividing by accumulation_steps makes the gradients
        # summed over the group an average rather than a sum.
        outputs = model(images)
        loss = criterion(outputs, labels) / accumulation_steps

        # Backward pass (accumulate gradients)
        loss.backward()

        # Update every accumulation_steps batches, and also on the epoch's last
        # batch so a leftover partial group is not carried into the next epoch.
        is_last_batch = batch_idx + 1 == num_batches
        if (batch_idx + 1) % accumulation_steps == 0 or is_last_batch:
            optimizer.step()
            optimizer.zero_grad()

Profiling and Benchmarking

Profile your code to find bottlenecks:

import torch
import torch.nn as nn
from torch.profiler import profile, record_function, ProfilerActivity

device = torch.device("mps")

model = SimpleCNN().to(device).eval()
images = torch.randn(128, 3, 32, 32, device=device)

with torch.no_grad():
    # Warm-up so one-time setup cost does not dominate the profile.
    model(images)
    torch.mps.synchronize()

    # torch.profiler has no MPS activity, so profile the CPU side and
    # synchronize inside the recorded region: MPS runs asynchronously, and
    # without the sync the GPU execution time would not be included.
    # For GPU-side detail, torch.mps.profiler.profile() emits signposts that
    # can be inspected in Xcode Instruments.
    with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
        with record_function("model_inference"):
            output = model(images)
            torch.mps.synchronize()

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

Optimization Checklist

✓ Use MPS device for GPU acceleration
✓ Move model and data to device
✓ Choose the batch size by benchmarking (128–256 is a reasonable starting point on M1)
✓ Enable mixed precision training
✓ Use gradient accumulation if needed
✓ Monitor memory with torch.mps.current_allocated_memory()
✓ Profile code to find bottlenecks
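✓ Use torch.mps.empty_cache() to release cached, unused memory back to the system (it cannot free memory still held by live tensors — delete those first)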
← Previous Continue interactively → Next →