Module 3 of 7 · PyTorch on Apple Silicon · Intermediate

Tensors and GPU Operations

Duration: 22 min

Understanding Tensors on MPS

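Tensors are the fundamental data structure in PyTorch. On Apple Silicon, a tensor lives either on the CPU or on the GPU, which PyTorch exposes through the MPS (Metal Performance Shaders) backend. The tensor's device attribute tells you which one. Placement matters for performance: the operands of an operation must be on the same device, and moving a tensor with .to() creates a copy on the target device.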

import torch

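# Fail early with a clear message if the MPS backend is unavailable
if not torch.backends.mps.is_available():
    raise RuntimeError("MPS backend is not available on this machine")
device = torch.device("mps")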

# Create tensor on CPU
t_cpu = torch.randn(1000, 1000)
print(f"CPU tensor device: {t_cpu.device}")

# Create tensor on MPS
t_mps = torch.randn(1000, 1000, device=device)
print(f"MPS tensor device: {t_mps.device}")

# Move CPU tensor to MPS
t_moved = t_cpu.to(device)
print(f"Moved tensor device: {t_moved.device}")

# Check if tensor is on MPS
print(f"Is on MPS: {t_mps.is_mps}")
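Because operands must share a device, most code picks a device once and creates data there. The sketch below uses a hypothetical helper, pick_device, that falls back to the CPU when MPS is unavailable (an assumption added so the example also runs on other machines). It creates data directly on the chosen device, moves a CPU tensor explicitly before combining the two, and copies the result back with .cpu() for host-side use.

```python
import torch


def pick_device() -> torch.device:
    """Hypothetical helper: prefer MPS, fall back to CPU when unavailable."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = pick_device()

# Allocate directly on the target device instead of creating on CPU and copying.
x = torch.arange(6, dtype=torch.float32, device=device).reshape(2, 3)

# Operands must share a device. Mixing a (non-scalar) CPU tensor with an MPS
# tensor is expected to raise a RuntimeError, so move one of them explicitly.
y = torch.ones(2, 3)  # created on the CPU
if device.type == "mps":
    try:
        _ = x + y
    except RuntimeError as err:
        print(f"Device mismatch: {err}")
z = x + y.to(device)

# Copy back to the CPU explicitly before host-side work;
# .numpy(), for example, only accepts CPU tensors.
values = z.cpu().tolist()
print(device.type, values)
```

Choosing the device in one place keeps the rest of the code identical on Macs with and without MPS.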

Tensor Operations on MPS

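Most common PyTorch operations have native MPS implementations, and the tensor API is the same as on the CPU. An operation without an MPS implementation raises NotImplementedError unless the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 is set, in which case it runs on the CPU instead (with a warning, and more slowly). Common operations: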

import torch

device = torch.device("mps")

# Create tensors on MPS
a = torch.randn(100, 100, device=device)
b = torch.randn(100, 100, device=device)

# Matrix multiplication (highly optimized on MPS)
c = torch.matmul(a, b)
print(f"Matrix multiplication result shape: {c.shape}")

# Element-wise operations
d = a + b
e = a * b
f = torch.sin(a)

# Reductions
mean_val = a.mean()
sum_val = a.sum()
max_val = a.max()

print(f"Mean: {mean_val.item():.4f}")
print(f"Sum: {sum_val.item():.4f}")
print(f"Max: {max_val.item():.4f}")
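Each .item() call in the example above copies one scalar to the CPU and waits for queued GPU work. The sketch below (with a CPU fallback added as an assumption so it runs anywhere) shows reductions along a single dimension, which stay on the device, and gathers several scalar statistics into one tensor so they cross to the CPU in a single copy.

```python
import torch

# CPU fallback is an assumption added so the sketch runs without MPS.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
torch.manual_seed(0)
a = torch.randn(100, 100, device=device)

# Reduce along one dimension: results stay on the device.
col_means = a.mean(dim=0)          # one mean per column, shape (100,)
row_max = a.max(dim=1).values      # max along a dim returns (values, indices)

# Stack the scalar statistics on the device and copy them to the CPU once,
# instead of calling .item() three times (each call waits for the GPU).
stats = torch.stack([a.mean(), a.sum(), a.max()]).cpu()
mean_val, sum_val, max_val = stats.tolist()
print(f"Mean: {mean_val:.4f}  Sum: {sum_val:.4f}  Max: {max_val:.4f}")
```

For a handful of scalars the saving is small, but the same pattern matters inside a training loop that logs metrics every step.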

Performance: CPU vs MPS

Compare matrix multiplication speed on CPU and MPS:

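import torch
import time

def benchmark_operation(name, operation, device, iterations=100):
    """Time `operation` on `device` ("cpu" or "mps"); return the mean in ms."""

    def sync():
        # MPS runs kernels asynchronously: wait until queued work has finished,
        # otherwise the timer only measures how fast work is submitted.
        if device == "mps":
            torch.mps.synchronize()

    # Warmup (excludes one-time setup costs from the measurement)
    for _ in range(10):
        operation()
    sync()

    # Benchmark
    start = time.perf_counter()
    for _ in range(iterations):
        operation()
    sync()
    elapsed = time.perf_counter() - start

    avg_time = (elapsed / iterations) * 1000  # Convert to ms
    print(f"{name} ({device.upper()}): {avg_time:.2f} ms")
    return avg_time

# Test on CPU
device_cpu = torch.device("cpu")
a_cpu = torch.randn(1000, 1000, device=device_cpu)
b_cpu = torch.randn(1000, 1000, device=device_cpu)

cpu_ms = benchmark_operation(
    "Matrix Mult",
    lambda: torch.matmul(a_cpu, b_cpu),
    "cpu"
)

# Test on MPS
device_mps = torch.device("mps")
a_mps = torch.randn(1000, 1000, device=device_mps)
b_mps = torch.randn(1000, 1000, device=device_mps)

mps_ms = benchmark_operation(
    "Matrix Mult",
    lambda: torch.matmul(a_mps, b_mps),
    "mps"
)

print(f"Speedup: {cpu_ms / mps_ms:.1f}x")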

Example output (timings vary with the chip, PyTorch version, and system load):

Matrix Mult (CPU): 15.32 ms
Matrix Mult (MPS): 2.45 ms
Speedup: 6.3x
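Raw milliseconds are hard to compare across matrix sizes. A common normalization, added here and not part of the original benchmark, counts roughly 2·n³ floating-point operations for an n×n by n×n matrix multiply and reports throughput. The sketch below applies it to the example timings above; matmul_gflops is a hypothetical helper name.

```python
def matmul_gflops(n: int, avg_ms: float) -> float:
    """Approximate throughput (GFLOP/s) of an n x n @ n x n matmul.

    Each of the n*n outputs needs n multiplies and about n adds,
    so the total work is roughly 2 * n**3 floating-point operations.
    """
    flops = 2 * n ** 3
    seconds = avg_ms / 1000
    return flops / seconds / 1e9


cpu = matmul_gflops(1000, 15.32)   # about 130.5 GFLOP/s
mps = matmul_gflops(1000, 2.45)    # about 816.3 GFLOP/s
print(f"CPU: {cpu:.1f} GFLOP/s, MPS: {mps:.1f} GFLOP/s")
```

Throughput makes it easier to see whether a larger matrix keeps the GPU busier or simply takes proportionally longer.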

Autograd on MPS

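Automatic differentiation (autograd) works on MPS without code changes: when the forward pass runs on MPS, the backward pass also runs there and the gradients are stored on the same device: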

import torch

device = torch.device("mps")

# Create leaf tensors with gradient tracking
x = torch.randn(10, 5, device=device, requires_grad=True)
w = torch.randn(5, 3, device=device, requires_grad=True)

# Forward pass
y = torch.matmul(x, w)
loss = y.sum()

# Backward pass (computed on MPS)
loss.backward()

print(f"x.grad shape: {x.grad.shape}")
print(f"w.grad shape: {w.grad.shape}")
print(f"Gradients computed on: {x.grad.device}")
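A quick way to trust gradients on a new backend is to compare them with a closed form. For loss = sum(x @ w), the upstream gradient is a matrix of ones, so dloss/dx = ones @ wᵀ and dloss/dw = xᵀ @ ones. The sketch below checks this, with a CPU fallback added as an assumption so it runs without MPS, and then clears the gradients, since PyTorch accumulates them across backward() calls.

```python
import torch

# CPU fallback is an assumption added so the sketch runs without MPS.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
torch.manual_seed(0)

x = torch.randn(10, 5, device=device, requires_grad=True)
w = torch.randn(5, 3, device=device, requires_grad=True)
loss = torch.matmul(x, w).sum()
loss.backward()

# For loss = sum(x @ w), the upstream gradient is a (10, 3) matrix of ones:
#   dloss/dx = ones @ w.T   and   dloss/dw = x.T @ ones
ones = torch.ones(10, 3, device=device)
expected_x_grad = ones @ w.detach().T
expected_w_grad = x.detach().T @ ones

x_ok = torch.allclose(x.grad, expected_x_grad, atol=1e-5)
w_ok = torch.allclose(w.grad, expected_w_grad, atol=1e-5)
print(f"x.grad matches: {x_ok}, w.grad matches: {w_ok}")

# Gradients accumulate across backward() calls; reset before the next step.
x.grad = None
w.grad = None
```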

Memory Management on MPS

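Apple Silicon has unified memory: MPS tensors are allocated from the same physical RAM as CPU tensors and every other process, so a large GPU allocation reduces the memory available to the rest of the system. Monitor usage with the torch.mps memory functions: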

import torch

device = torch.device("mps")

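# Memory currently occupied by live MPS tensors (this is not free memory)
print(f"MPS memory allocated: {torch.mps.current_allocated_memory() / 1e9:.2f} GB")

# Create a large tensor: 10000 x 10000 float32 values = 4e8 bytes (about 0.40 GB)
large_tensor = torch.randn(10000, 10000, device=device)
print(f"After allocation: {torch.mps.current_allocated_memory() / 1e9:.2f} GB")

# Drop the last reference, then release cached blocks back to the system
del large_tensor
torch.mps.empty_cache()
print(f"After cleanup: {torch.mps.current_allocated_memory() / 1e9:.2f} GB")
# Total memory the Metal driver holds for this process, including caches
print(f"Driver allocated: {torch.mps.driver_allocated_memory() / 1e9:.2f} GB")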
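With unified memory, it helps to know how much a tensor will need before allocating it. The sketch below uses a hypothetical helper, planned_bytes, built on PyTorch's "meta" device, which records shape and dtype without allocating storage. On a Mac with MPS it also compares the estimate with the allocator's report; the allocator may round sizes up, so the two numbers need not match exactly.

```python
import torch


def planned_bytes(shape, dtype=torch.float32):
    """Hypothetical helper: bytes a dense tensor of this shape/dtype needs.

    The "meta" device stores only metadata, so this is safe for huge shapes.
    """
    t = torch.empty(shape, dtype=dtype, device="meta")
    return t.numel() * t.element_size()


f32 = planned_bytes((10000, 10000))                   # 4 bytes per element
f16 = planned_bytes((10000, 10000), torch.float16)    # 2 bytes per element
print(f"float32: {f32 / 1e9:.2f} GB, float16: {f16 / 1e9:.2f} GB")

# On a Mac with MPS, compare the estimate with what the allocator reports.
if torch.backends.mps.is_available():
    before = torch.mps.current_allocated_memory()
    t = torch.empty(10000, 10000, device="mps")
    delta = torch.mps.current_allocated_memory() - before
    print(f"Estimated {f32} bytes, allocator reports {delta} bytes")
    del t
    torch.mps.empty_cache()
```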

Data Type Considerations

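For floating-point data on MPS, use float32 (the default) or float16; float64 is not supported. Float16 halves memory use and is often faster, at the cost of precision and numeric range: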

import torch

device = torch.device("mps")

# Float32 (default)
x_f32 = torch.randn(1000, 1000, dtype=torch.float32, device=device)

# Float16 (half precision)
x_f16 = torch.randn(1000, 1000, dtype=torch.float16, device=device)

print(f"Float32 memory: {x_f32.element_size() * x_f32.nelement() / 1e6:.2f} MB")
print(f"Float16 memory: {x_f16.element_size() * x_f16.nelement() / 1e6:.2f} MB")

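# Automatic mixed precision: inside this context, eligible ops such as matmul
# run in float16 even though x_f32 is float32. Autocast with device_type="mps"
# requires a recent PyTorch release.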
with torch.autocast(device_type="mps", dtype=torch.float16):
    y = torch.matmul(x_f32, x_f32)
    print(f"Mixed precision output dtype: {y.dtype}")
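Two practical consequences of the data type rules: data that arrives as float64 (Python floats and NumPy arrays often do) must be cast before it moves to MPS, and float16 overflows to infinity much sooner than float32. The sketch below shows both, with a CPU fallback added as an assumption so it runs without MPS.

```python
import torch

# CPU fallback is an assumption added so the sketch runs without MPS.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# MPS has no float64 support: cast to float32 on the CPU, then move.
data = torch.linspace(0.0, 1.0, 5, dtype=torch.float64)
x = data.to(torch.float32).to(device)

# float16 has far less range and precision than float32.
f16 = torch.finfo(torch.float16)
print(f"float16 max: {f16.max}, eps: {f16.eps}")

# 70000 exceeds the largest finite float16 value, so it becomes inf.
overflow = torch.tensor(70000.0, dtype=torch.float32).to(torch.float16)
print(f"70000 as float16: {overflow.item()}")
```

Values that routinely exceed about 65504 (large logits, unnormalized sums) are a sign to keep that part of the computation in float32.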