Module 5 of 11 · AWS SageMaker — End-to-End ML Platform · Intermediate

Training Jobs

Duration: 60 min

Training Jobs are the core of SageMaker's ML workflow. This module covers the Estimator API, hyperparameter configuration, spot training for cost savings, and distributed training strategies.

Understanding Training Jobs

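Training Jobs run your training script on managed infrastructure. For each job, SageMaker provisions the instances you request, makes the input data from S3 available to the training container, runs the training code, uploads the resulting model artifacts (and any checkpoints you configure) to S3, and then shuts the instances down. The instance type and count are fixed by the job configuration rather than scaled automatically; SageMaker reports job failures and can resume interrupted spot jobs when capacity returns.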

Using the Estimator API

from sagemaker.estimator import Estimator
import sagemaker

session = sagemaker.Session()
# Placeholder ARN: replace with an IAM role that SageMaker is allowed to assume
role = 'arn:aws:iam::123456789012:role/SageMakerRole'
bucket = session.default_bucket()

# Look up the built-in XGBoost image for the session's region instead of
# hard-coding a registry account, a region, and a floating ':latest' tag
xgboost_image_uri = sagemaker.image_uris.retrieve(
    framework='xgboost',
    region=session.boto_region_name,
    version='1.7-1'
)

# Create estimator for the built-in XGBoost algorithm container
estimator = Estimator(
    image_uri=xgboost_image_uri,
    role=role,
    instance_count=1,
    instance_type='ml.m5.xlarge',
    output_path=f's3://{bucket}/training-output',
    code_location=f's3://{bucket}/code',
    sagemaker_session=session
)

# Set hyperparameters. These must be names the XGBoost algorithm accepts
# (num_round is required); choose the objective that matches your task.
estimator.set_hyperparameters(
    objective='reg:squarederror',
    num_round=100,
    max_depth=5,
    eta=0.2,
    subsample=0.8
)

# Start training. The built-in XGBoost algorithm reads its training data from
# the 'train' channel. Job names must be unique within an account and region,
# so change job_name before re-running this cell.
# NOTE(review): if the data is CSV, pass a sagemaker.inputs.TrainingInput with
# content_type='text/csv' instead of a bare S3 URI; confirm the default format
# in the XGBoost algorithm documentation.
estimator.fit(
    {'train': f's3://{bucket}/train-data/'},
    job_name='training-job-2024-01-15',
    wait=True
)

Hyperparameter Configuration

from sagemaker.tensorflow import TensorFlow

# Hyperparameters handed to the TensorFlow estimator
tf_hyperparameters = dict(
    epochs=50,
    batch_size=64,
    learning_rate=0.001,
    dropout=0.5,
    activation='relu',
)

# TensorFlow estimator that runs train.py on a single ml.p3.2xlarge instance
tf_estimator = TensorFlow(
    entry_point='train.py',
    framework_version='2.8',
    py_version='py39',
    instance_type='ml.p3.2xlarge',
    instance_count=1,
    role=role,
    sagemaker_session=session,
    output_path=f's3://{bucket}/tf-output',
    hyperparameters=tf_hyperparameters,
)

# Input channels: a single 'training' channel read from the session bucket
tf_inputs = {'training': f's3://{bucket}/train-data/'}

# Fit the model
tf_estimator.fit(tf_inputs, job_name='tensorflow-training')

Spot Training for Cost Savings

from sagemaker.estimator import Estimator

# Look up the built-in XGBoost image for the session's region instead of
# hard-coding a registry account, a region, and a floating ':latest' tag
xgboost_image_uri = sagemaker.image_uris.retrieve(
    framework='xgboost',
    region=session.boto_region_name,
    version='1.7-1'
)

# Enable spot training
estimator = Estimator(
    image_uri=xgboost_image_uri,
    role=role,
    instance_count=1,
    instance_type='ml.m5.xlarge',
    output_path=f's3://{bucket}/training-output',
    sagemaker_session=session,
    # The XGBoost algorithm requires num_round; choose the objective that
    # matches your task
    hyperparameters={
        'objective': 'reg:squarederror',
        'num_round': 100
    },
    use_spot_instances=True,
    # max_run caps the training time in seconds; max_wait caps training time
    # plus time spent waiting for spot capacity, so it must be >= max_run
    max_run=3600,
    max_wait=5400,
    # Spot capacity can be reclaimed mid-job. Checkpoints synced to S3 let the
    # restarted job resume instead of starting over.
    checkpoint_s3_uri=f's3://{bucket}/checkpoints/spot-training-job'
)

# Spot training can save up to 90% on compute costs
# The built-in XGBoost algorithm reads its training data from the 'train'
# channel. Job names must be unique within an account and region.
estimator.fit(
    {'train': f's3://{bucket}/train-data/'},
    job_name='spot-training-job'
)

Distributed Training

from sagemaker.pytorch import PyTorch

# Distributed training with multiple instances.
# The 'torch_distributed' (torchrun) launcher is only accepted on GPU
# instances for PyTorch 1.13.1 and later, so the framework version is pinned
# to 1.13.1, whose training container uses Python 3.9.
pytorch_estimator = PyTorch(
    entry_point='train.py',
    role=role,
    instance_count=4,  # Multiple instances
    instance_type='ml.p3.8xlarge',
    framework_version='1.13.1',
    py_version='py39',
    output_path=f's3://{bucket}/pytorch-output',
    sagemaker_session=session,
    distribution={
        'torch_distributed': {
            'enabled': True
        }
    }
)

# Train with distributed strategy.
# Job names must be unique within an account and region.
pytorch_estimator.fit(
    {'training': f's3://{bucket}/train-data/'},
    job_name='distributed-training'
)

Training Job Configuration

{
  "training_job_config": {
    "job_name": "my-training-job",
    "role_arn": "arn:aws:iam::123456789012:role/SageMakerRole",
    "algorithm_specification": {
      "training_image": "683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-xgboost:1.7-1",
      "training_input_mode": "File"
    },
    "input_data_config": [
      {
        "channel_name": "training",
        "data_source": {
          "s3_data_source": {
            "s3_data_type": "S3Prefix",
            "s3_uri": "s3://my-bucket/train-data/",
            "s3_data_distribution_type": "FullyReplicated"
          }
        }
      }
    ],
    "output_data_config": {
      "s3_output_path": "s3://my-bucket/training-output/"
    },
    "resource_config": {
      "instance_type": "ml.m5.xlarge",
      "instance_count": 1,
      "volume_size_in_gb": 30
    },
    "stopping_condition": {
      "max_runtime_in_seconds": 86400
    }
  }
}

Monitoring Training Progress

# Check training job status
import boto3

sm_client = boto3.client('sagemaker')

# Fetch the current description of the job by name
response = sm_client.describe_training_job(
    TrainingJobName='my-training-job'
)

print(f"Status: {response['TrainingJobStatus']}")
print(f"Created: {response['CreationTime']}")
# The fields below are optional in the response, so read them with .get() to
# avoid a KeyError while the job is still pending or running.
print(f"Training started: {response.get('TrainingStartTime', 'Not started')}")
print(f"Training ended: {response.get('TrainingEndTime', 'In progress')}")
print(f"Billable seconds: {response.get('BillableTimeInSeconds', 'Not available yet')}")

Quiz 1

❓ What is the primary purpose of SageMaker Training Jobs?

Quiz 2

❓ How much can spot training save on compute costs?

Quiz 3

❓ What is required for distributed training?

Quiz 4

❓ What does the Estimator API provide?

Quiz 5

❓ Where does SageMaker store training output?

← Previous Continue interactively → Next →

Related Courses