Training Jobs
Duration: 60 min
Training Jobs are the core of SageMaker's ML workflow. This module covers the Estimator API, hyperparameter configuration, spot training for cost savings, and distributed training strategies.
Understanding Training Jobs
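Training Jobs run your training code on infrastructure that SageMaker manages. When a job starts, SageMaker provisions the instance type and count you request, copies each input channel from S3 into the container (in the default File input mode), and runs your training image or script. When training finishes, SageMaker packages whatever the code saved to /opt/ml/model as model.tar.gz, uploads it to the job's S3 output path, and shuts the instances down, so you pay only while the job runs. If you set a checkpoint S3 location, SageMaker also syncs the local checkpoint directory to S3 during training. It stops any job that exceeds its maximum runtime.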
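Inside the container, a training script finds its input and output locations through environment variables that the SageMaker training toolkit sets. These include `SM_MODEL_DIR` and one `SM_CHANNEL_<NAME>` per input channel. The following minimal sketch resolves those paths and falls back to the conventional `/opt/ml` layout when the variables are absent, for example when running locally. The helper name `resolve_paths` is hypothetical.

```python
import os


def resolve_paths(channel: str = "training") -> dict:
    """Return the directories a SageMaker training script typically uses.

    Hypothetical helper: reads the toolkit's environment variables and falls
    back to the standard /opt/ml layout when they are not set.
    """
    env_name = f"SM_CHANNEL_{channel.upper()}"
    return {
        # Where the input channel's files were copied (File mode)
        "input": os.environ.get(env_name, f"/opt/ml/input/data/{channel}"),
        # Anything saved here is uploaded to the output path as model.tar.gz
        "model": os.environ.get("SM_MODEL_DIR", "/opt/ml/model"),
        # Default local checkpoint directory synced to the checkpoint S3 URI
        "checkpoints": "/opt/ml/checkpoints",
    }


if __name__ == "__main__":
    for name, path in resolve_paths("training").items():
        print(f"{name}: {path}")
```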
Using the Estimator API
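import sagemaker
from sagemaker import image_uris
from sagemaker.estimator import Estimator
from sagemaker.inputs import TrainingInput

session = sagemaker.Session()
role = 'arn:aws:iam::123456789012:role/SageMakerRole'  # replace with your SageMaker execution role
bucket = session.default_bucket()

# Look up the built-in XGBoost container for the session's region
# instead of hard-coding a registry account ID
xgb_image = image_uris.retrieve('xgboost', region=session.boto_region_name, version='1.7-1')

# Generic estimator: runs a prebuilt container image as-is (no training script)
estimator = Estimator(
    image_uri=xgb_image,
    role=role,
    instance_count=1,
    instance_type='ml.m5.xlarge',
    output_path=f's3://{bucket}/training-output',
    base_job_name='xgboost-training',  # the SDK appends a timestamp so each job name is unique
    sagemaker_session=session
)

# Hyperparameters must be ones the container understands (XGBoost names shown)
estimator.set_hyperparameters(
    objective='binary:logistic',
    num_round=100,
    max_depth=5,
    eta=0.2
)

# Start training. Built-in XGBoost reads the 'train' channel and needs
# content_type='text/csv' for CSV input (its default is libsvm)
estimator.fit(
    {'train': TrainingInput(f's3://{bucket}/train-data/', content_type='text/csv')},
    wait=True
)
Hyperparameter Configuration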
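With a script-mode estimator such as the TensorFlow one in this section, SageMaker passes each `hyperparameters` entry to `train.py` as a command-line argument (for example `--epochs 50`). The following minimal sketch shows the receiving side and assumes the script uses `argparse`. It calls `parse_known_args` so that extra flags the framework may add, such as `--model_dir` from the TensorFlow estimator, do not cause an error. The function name `parse_hyperparameters` is hypothetical.

```python
import argparse


def parse_hyperparameters(argv=None):
    """Parse the hyperparameters SageMaker passes to train.py as CLI flags."""
    parser = argparse.ArgumentParser()
    # Defaults apply when a hyperparameter is not set on the estimator
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--learning_rate", type=float, default=0.001)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument("--activation", type=str, default="relu")
    # Ignore flags this script does not declare (e.g. --model_dir)
    args, _unknown = parser.parse_known_args(argv)
    return args


if __name__ == "__main__":
    # Simulates a command line built from the estimator's hyperparameters
    args = parse_hyperparameters(
        ["--epochs", "50", "--batch_size", "64", "--model_dir", "s3://example-bucket/model"]
    )
    print(args.epochs, args.batch_size, args.learning_rate)  # 50 64 0.001
```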
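from sagemaker.tensorflow import TensorFlow

# Script-mode TensorFlow estimator: SageMaker uploads train.py and runs it
# inside the prebuilt TensorFlow 2.8 container
tf_estimator = TensorFlow(
    entry_point='train.py',
    role=role,
    instance_count=1,
    instance_type='ml.p3.2xlarge',
    framework_version='2.8',
    py_version='py39',
    output_path=f's3://{bucket}/tf-output',
    base_job_name='tensorflow-training',  # a timestamp is appended, so reruns do not collide
    sagemaker_session=session,
    # Each entry reaches train.py as a command-line argument,
    # e.g. --epochs 50 --batch_size 64
    hyperparameters={
        'epochs': 50,
        'batch_size': 64,
        'learning_rate': 0.001,
        'dropout': 0.5,
        'activation': 'relu'
    }
)

# Fit the model. Inside the container the 'training' channel is available at
# /opt/ml/input/data/training (environment variable SM_CHANNEL_TRAINING)
tf_estimator.fit({'training': f's3://{bucket}/train-data/'})
Spot Training for Cost Savings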
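EC2 can reclaim a spot instance in the middle of a job. If `checkpoint_s3_uri` is set, SageMaker copies the saved checkpoints back into the local checkpoint directory (`/opt/ml/checkpoints` by default) when the job restarts. A script-mode training loop can then resume from the newest checkpoint instead of starting over at epoch 0. The following minimal sketch covers only the file handling. The `checkpoint-<epoch>.json` naming scheme and the JSON payload are hypothetical stand-ins for real model weights.

```python
import json
import os
import re

# Hypothetical naming scheme: one file per completed epoch
CKPT_PATTERN = re.compile(r"checkpoint-(\d+)\.json$")


def save_checkpoint(ckpt_dir: str, epoch: int, state: dict) -> None:
    """Write one checkpoint per epoch; SageMaker syncs this directory to S3."""
    os.makedirs(ckpt_dir, exist_ok=True)
    with open(os.path.join(ckpt_dir, f"checkpoint-{epoch}.json"), "w") as f:
        json.dump({"epoch": epoch, "state": state}, f)


def load_latest_checkpoint(ckpt_dir: str):
    """Return the newest checkpoint, or None when starting from scratch."""
    if not os.path.isdir(ckpt_dir):
        return None
    epochs = [
        int(m.group(1))
        for name in os.listdir(ckpt_dir)
        if (m := CKPT_PATTERN.match(name))
    ]
    if not epochs:
        return None
    with open(os.path.join(ckpt_dir, f"checkpoint-{max(epochs)}.json")) as f:
        return json.load(f)


def train(ckpt_dir: str = "/opt/ml/checkpoints", total_epochs: int = 10) -> int:
    """Run (placeholder) epochs, resuming after the last saved one; returns the start epoch."""
    latest = load_latest_checkpoint(ckpt_dir)
    start = latest["epoch"] + 1 if latest else 0
    for epoch in range(start, total_epochs):
        state = {"loss": 1.0 / (epoch + 1)}  # placeholder for real training work
        save_checkpoint(ckpt_dir, epoch, state)
    return start
```

The sketch compares epoch numbers numerically rather than sorting file names, so `checkpoint-10` correctly ranks ahead of `checkpoint-9`.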
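from sagemaker import image_uris
from sagemaker.estimator import Estimator
from sagemaker.inputs import TrainingInput

# Managed spot training: the same job, run on spare EC2 capacity
estimator = Estimator(
    image_uri=image_uris.retrieve('xgboost', region=session.boto_region_name, version='1.7-1'),
    role=role,
    instance_count=1,
    instance_type='ml.m5.xlarge',
    output_path=f's3://{bucket}/training-output',
    base_job_name='spot-training-job',
    sagemaker_session=session,
    use_spot_instances=True,
    max_run=3600,   # limit on actual training time, in seconds
    max_wait=5400,  # limit on training time plus time waiting for spot capacity; must be >= max_run
    checkpoint_s3_uri=f's3://{bucket}/checkpoints'  # lets an interrupted job resume instead of restarting
)
estimator.set_hyperparameters(objective='binary:logistic', num_round=100)

# Spot capacity can cut training compute cost by up to 90% compared with
# On-Demand instances; actual savings depend on spot availability
estimator.fit(
    {'train': TrainingInput(f's3://{bucket}/train-data/', content_type='text/csv')}
)
Distributed Training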
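In data-parallel training, every process runs the same script on a different slice of the data. The configuration in this section gives 4 instances × 4 GPUs per ml.p3.8xlarge = 16 processes, so each process should see about 1/16 of the examples. In a real script, PyTorch's `DistributedSampler` handles this split. The following pure-Python sketch mirrors its unshuffled, strided assignment: pad the index list to a multiple of the world size, then give each rank every `world_size`-th index starting at its own rank. This keeps the arithmetic easy to inspect. `shard_indices` is a hypothetical helper, not a SageMaker or PyTorch API.

```python
import math


def shard_indices(num_examples: int, world_size: int, rank: int) -> list:
    """Indices one process reads, mimicking DistributedSampler(shuffle=False)."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    indices = list(range(num_examples))
    # Pad by wrapping around so every rank gets the same number of samples
    total_size = math.ceil(num_examples / world_size) * world_size
    padding = total_size - num_examples
    if padding:
        indices += (indices * math.ceil(padding / num_examples))[:padding]
    # Strided split: rank r takes r, r + world_size, r + 2 * world_size, ...
    return indices[rank:total_size:world_size]


if __name__ == "__main__":
    world_size = 4 * 4  # instance_count x GPUs per ml.p3.8xlarge
    shards = [shard_indices(1000, world_size, r) for r in range(world_size)]
    print(len(shards[0]))  # 63, i.e. ceil(1000 / 16)
    print(shards[15][-1])  # 7: a wrapped-around index that pads the last rank
```

Because of the padding, a few examples are seen twice per epoch when the dataset size is not a multiple of the world size. This matches the default (`drop_last=False`) behavior the sketch imitates.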
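from sagemaker.pytorch import PyTorch

# Data-parallel training across several GPU instances
pytorch_estimator = PyTorch(
    entry_point='train.py',  # must initialize torch.distributed and wrap the model in DistributedDataParallel
    role=role,
    instance_count=4,               # 4 instances
    instance_type='ml.p3.8xlarge',  # 4 NVIDIA V100 GPUs each, 16 GPUs in total
    framework_version='1.12',
    py_version='py38',
    output_path=f's3://{bucket}/pytorch-output',
    base_job_name='distributed-training',
    sagemaker_session=session,
    # Use SageMaker's PyTorch DDP launcher for GPU instances
    distribution={
        'pytorchddp': {
            'enabled': True
        }
    }
)

# By default every instance receives a full copy of the channel (FullyReplicated),
# so each process should read only its own shard of the data
pytorch_estimator.fit({'training': f's3://{bucket}/train-data/'})
Training Job Configuration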
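Training job names must be unique within an AWS account and Region. They are limited to 63 characters of letters, digits, and hyphens. Resubmitting a request with a fixed name such as `my-training-job` therefore fails. The SDK's `base_job_name` avoids this by appending a timestamp. When you build requests by hand, a helper along the following lines does the same job. `unique_job_name` is a hypothetical sketch. Its regular expression follows the constraint documented for `TrainingJobName`.

```python
import re
from datetime import datetime, timezone
from typing import Optional

# Naming constraint documented for TrainingJobName in CreateTrainingJob
JOB_NAME_RE = re.compile(r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$")
MAX_LEN = 63


def unique_job_name(base: str, now: Optional[datetime] = None) -> str:
    """Return `base` plus a UTC timestamp, trimmed to fit the 63-character limit."""
    now = now or datetime.now(timezone.utc)
    # Millisecond-resolution suffix, e.g. 2024-01-15-10-30-00-123
    suffix = now.strftime("%Y-%m-%d-%H-%M-%S-") + f"{now.microsecond // 1000:03d}"
    # Replace characters the API rejects, then leave room for "-" + suffix
    clean = re.sub(r"[^a-zA-Z0-9-]", "-", base).strip("-")
    clean = clean[: MAX_LEN - len(suffix) - 1].rstrip("-")
    name = f"{clean}-{suffix}"
    if len(name) > MAX_LEN or not JOB_NAME_RE.match(name):
        raise ValueError(f"cannot build a valid job name from {base!r}")
    return name


if __name__ == "__main__":
    print(unique_job_name("my-training-job"))
```

The resulting string can go into the request's training job name field before the request is submitted.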
The same kind of job can be written as a raw CreateTrainingJob request and submitted with `aws sagemaker create-training-job --cli-input-json file://job.json`. The XGBoost image URI shown is for us-east-1; `sagemaker.image_uris.retrieve` returns the correct one for other Regions.
{
  "TrainingJobName": "my-training-job",
  "RoleArn": "arn:aws:iam::123456789012:role/SageMakerRole",
  "AlgorithmSpecification": {
    "TrainingImage": "683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-xgboost:1.7-1",
    "TrainingInputMode": "File"
  },
  "HyperParameters": {
    "objective": "binary:logistic",
    "num_round": "100"
  },
  "InputDataConfig": [
    {
      "ChannelName": "train",
      "ContentType": "text/csv",
      "DataSource": {
        "S3DataSource": {
          "S3DataType": "S3Prefix",
          "S3Uri": "s3://my-bucket/train-data/",
          "S3DataDistributionType": "FullyReplicated"
        }
      }
    }
  ],
  "OutputDataConfig": {
    "S3OutputPath": "s3://my-bucket/training-output/"
  },
  "ResourceConfig": {
    "InstanceType": "ml.m5.xlarge",
    "InstanceCount": 1,
    "VolumeSizeInGB": 30
  },
  "StoppingCondition": {
    "MaxRuntimeInSeconds": 86400
  }
}
Monitoring Training Progress
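After a job finishes, `describe_training_job` reports both `TrainingTimeInSeconds` (wall-clock training time) and `BillableTimeInSeconds` (the time you are charged for). The two are equal for On-Demand jobs. For a managed spot job, the savings can be computed as `(1 - billable / training) * 100`. Total billed compute is the billable time multiplied by the instance count. The following minimal sketch applies both formulas to a response dictionary. The helper names are hypothetical, and the sample numbers are illustrative, not real job results.

```python
def spot_savings_percent(response: dict):
    """Managed spot savings for a finished job, or None if not yet available."""
    training = response.get("TrainingTimeInSeconds")
    billable = response.get("BillableTimeInSeconds")
    if not training or billable is None:
        return None  # these fields appear only once the job has finished
    return (1 - billable / training) * 100


def billed_instance_seconds(response: dict):
    """Total compute time billed: billable wall-clock seconds x instance count."""
    billable = response.get("BillableTimeInSeconds")
    if billable is None:
        return None
    return billable * response["ResourceConfig"]["InstanceCount"]


if __name__ == "__main__":
    # Illustrative values shaped like a describe_training_job response
    sample = {
        "TrainingTimeInSeconds": 500,
        "BillableTimeInSeconds": 100,
        "ResourceConfig": {"InstanceCount": 1, "InstanceType": "ml.m5.xlarge"},
    }
    print(f"{spot_savings_percent(sample):.0f}% saved")  # 80% saved
    print(billed_instance_seconds(sample))  # 100
```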
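# Check training job status
import boto3

sm_client = boto3.client('sagemaker')
response = sm_client.describe_training_job(TrainingJobName='my-training-job')

print(f"Status: {response['TrainingJobStatus']} ({response['SecondaryStatus']})")
print(f"Created: {response['CreationTime']}")
print(f"Training started: {response.get('TrainingStartTime', 'Not started')}")
print(f"Training ended: {response.get('TrainingEndTime', 'In progress')}")
# Duration fields are present only once the job has finished
print(f"Training seconds: {response.get('TrainingTimeInSeconds', 'n/a')}")
print(f"Billable seconds: {response.get('BillableTimeInSeconds', 'n/a')}")

# Optionally block until the job completes or stops
sm_client.get_waiter('training_job_completed_or_stopped').wait(TrainingJobName='my-training-job')
Quiz 1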
❓ What is the primary purpose of SageMaker Training Jobs?
Quiz 2
❓ How much can spot training save on compute costs?
Quiz 3
❓ What is required for distributed training?
Quiz 4
❓ What does the Estimator API provide?
Quiz 5
❓ Where does SageMaker store training output?