Module 6 of 7 · PyTorch on Apple Silicon · Intermediate

Production Deployment & Best Practices

Duration: 15 min

Exporting Models for Production

Export trained models for deployment:

import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    """Small two-stage CNN classifier with 10 output classes.

    Each stage is conv3x3 -> ReLU -> 2x2 max-pool, so a 3x32x32 input
    becomes 64x8x8 before the fully connected layers. ``fc1`` is sized for
    exactly that, which means inputs must be 32x32 (e.g. CIFAR-10 images).
    The model returns raw logits (no softmax).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # 64 channels * 8 * 8: two 2x2 pools shrink a 32x32 input to 8x8.
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a (N, 3, 32, 32) batch to (N, 10) class logits."""
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # torch.flatten (unlike .view) also handles non-contiguous
        # activations, e.g. when the input is in channels_last format.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Load trained weights onto the CPU. This way the export does not depend on
# the device the checkpoint was saved from (e.g. MPS).
model = SimpleCNN()
model.load_state_dict(torch.load('cifar10_cnn.pth', map_location='cpu'))
model.eval()  # export in inference mode

# Export to TorchScript
scripted_model = torch.jit.script(model)
scripted_model.save('cifar10_cnn.pt')
print("Model exported to TorchScript")

# Load the exported model back and check it agrees with the eager model
loaded_model = torch.jit.load('cifar10_cnn.pt')
x = torch.randn(1, 3, 32, 32)
with torch.no_grad():  # inference only: skip autograd bookkeeping
    output = loaded_model(x)
    torch.testing.assert_close(output, model(x))
print(f"Output shape: {output.shape}")

Inference Optimization

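Optimize models for inference and benchmark them correctly. MPS runs asynchronously, so synchronize the device before reading the timer, and warm up each model first: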

import torch
import torch.nn as nn
import time

device = torch.device("mps")

model = SimpleCNN().to(device)
model.load_state_dict(torch.load('cifar10_cnn.pth'))
model.eval()

# Prepare input
x = torch.randn(1, 3, 32, 32, device=device)

# Standard inference
with torch.no_grad():
    start = time.time()
    for _ in range(1000):
        output = model(x)
    standard_time = time.time() - start

# Optimized inference with torch.jit
scripted_model = torch.jit.script(model).to(device)

with torch.no_grad():
    start = time.time()
    for _ in range(1000):
        output = scripted_model(x)
    jit_time = time.time() - start

print(f"Standard inference: {standard_time:.3f}s")
print(f"JIT inference: {jit_time:.3f}s")
print(f"Speedup: {standard_time / jit_time:.2f}x")

Batch Inference

Process multiple samples efficiently:

import torch
import torch.nn as nn

device = torch.device("mps")

model = SimpleCNN().to(device)
model.load_state_dict(torch.load('cifar10_cnn.pth'))
model.eval()

# Batch inference
batch_size = 32
num_samples = 1000

# Create dummy data
x_batch = torch.randn(batch_size, 3, 32, 32, device=device)

with torch.no_grad():
    predictions = []
    for i in range(0, num_samples, batch_size):
        batch = x_batch[:min(batch_size, num_samples - i)]
        output = model(batch)
        predictions.append(output)
    
    all_predictions = torch.cat(predictions, dim=0)
    print(f"Processed {all_predictions.shape[0]} samples")

Model Quantization

Reduce model size for deployment. Note that dynamically quantized models run on the CPU only, not on the MPS backend:

import torch
import torch.nn as nn
from torch.quantization import quantize_dynamic

device = torch.device("mps")

model = SimpleCNN()
model.load_state_dict(torch.load('cifar10_cnn.pth'))

# Dynamic quantization
quantized_model = quantize_dynamic(
    model,
    {nn.Linear},
    dtype=torch.qint8
)

# Save quantized model
torch.save(quantized_model.state_dict(), 'cifar10_cnn_quantized.pth')

# Compare sizes
import os
original_size = os.path.getsize('cifar10_cnn.pth') / 1e6
quantized_size = os.path.getsize('cifar10_cnn_quantized.pth') / 1e6

print(f"Original size: {original_size:.2f} MB")
print(f"Quantized size: {quantized_size:.2f} MB")
print(f"Reduction: {(1 - quantized_size/original_size) * 100:.1f}%")

Best Practices Checklist

Development:
✓ Use MPS for local training
✓ Monitor memory usage
✓ Use mixed precision for speed
✓ Validate on separate dataset
✓ Save checkpoints during training

Deployment:
✓ Export to TorchScript for production
✓ Use torch.no_grad() for inference
✓ Batch inference for throughput
✓ Consider quantization for size
✓ Profile inference latency
✓ Monitor memory on target device

Local Development:
✓ Train on M1/M2/M3 locally
✓ Iterate rapidly without cloud costs
✓ Use version control (Git)
✓ Document hyperparameters
✓ Save training logs

Deployment Strategies

Strategy 1: Local Inference

Deploy the model on the user's Mac so their data never leaves the device:

import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

device = torch.device("mps")

# Load model
model = SimpleCNN().to(device)
model.load_state_dict(torch.load('cifar10_cnn.pth'))
model.eval()

# Inference function
def predict(image_path):
    """Classify one image file with the module-level ``model``.

    Args:
        image_path: Path to an image file readable by PIL.

    Returns:
        The predicted class index as a Python int.
    """
    transform = transforms.Compose([
        # The model's fc1 layer only accepts 32x32 inputs.
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        # NOTE(review): presumably mirrors the training normalization — confirm.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Use a `with` block so the file handle is closed. Force RGB because PNGs
    # are often RGBA, grayscale or palette images, which would give conv1
    # the wrong number of channels.
    with Image.open(image_path) as image:
        x = transform(image.convert('RGB')).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(x)
        _, predicted = torch.max(output, 1)

    return predicted.item()

# Example: classify a single image from disk
predicted_class = predict('test_image.png')
print(f"Predicted class: {predicted_class}")

Strategy 2: Cloud Deployment

Deploy to AWS Lambda or a similar CPU-only serverless platform:

# Export model for cloud
model = SimpleCNN()
# map_location='cpu' lets this run on machines without MPS, even if the
# checkpoint was saved from an MPS model.
model.load_state_dict(torch.load('cifar10_cnn.pth', map_location='cpu'))
model.eval()  # export in inference mode

# Use CPU for cloud (no MPS available). The parameters live on the CPU, so the
# saved TorchScript file loads on CPU-only hosts.
scripted_model = torch.jit.script(model)
scripted_model.save('model.pt')

# Lambda handler

# Models loaded so far, keyed by file path. Module-level state survives across
# warm invocations of the same container, so each model is read from disk
# only once.
_MODEL_CACHE = {}


def _get_model(path='model.pt'):
    """Return the TorchScript model at ``path``, loading it on first use."""
    model = _MODEL_CACHE.get(path)
    if model is None:
        import torch

        model = torch.jit.load(path, map_location='cpu')
        model.eval()
        _MODEL_CACHE[path] = model
    return model


def lambda_handler(event, context):
    """AWS Lambda entry point: classify one image sent in the event payload.

    Args:
        event: Dict whose ``'image'`` key holds the pixel values as nested
            lists. Presumably shaped (3, 32, 32) and normalized like the
            training data — confirm against the client.
        context: Lambda context object (unused).

    Returns:
        ``{'prediction': <predicted class index as int>}``.
    """
    import torch

    model = _get_model()

    # Force float32: JSON integers would otherwise produce an int64 tensor,
    # which the model's float layers reject.
    x = torch.tensor(event['image'], dtype=torch.float32).unsqueeze(0)

    with torch.no_grad():
        output = model(x)
        prediction = output.argmax(1).item()

    return {'prediction': prediction}

Performance Comparison

Local vs. cloud deployment (rough, illustrative figures; measure your own workload):

Metric              Local (M1)      AWS Lambda
─────────────────────────────────────────────
Latency             5-10ms          50-100ms
Cost per 1M calls   $0              $20
Privacy             ✓ Local         ✗ Cloud
Setup time          5 min           30 min
Scalability         Single device   Unlimited
← Previous Continue interactively → Next →