Skip to content

Commit e522948

Browse files
author
Chuyang Deng
committed
fix: workflow passing spot training param to training job
1 parent 16c1ed6 commit e522948

File tree

2 files changed

+5
-0
lines changed

2 files changed

+5
-0
lines changed

src/sagemaker/workflow/airflow.py

+3
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,9 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=
181181
if job_config["vpc_config"] is not None:
182182
train_config["VpcConfig"] = job_config["vpc_config"]
183183

184+
if estimator.train_use_spot_instances:
185+
train_config["EnableManagedSpotTraining"] = True
186+
184187
if estimator.hyperparameters() is not None:
185188
hyperparameters = {str(k): str(v) for (k, v) in estimator.hyperparameters().items()}
186189

tests/unit/test_airflow.py

+2
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def test_byo_training_config_all_args(sagemaker_session):
105105
model_uri="{{ model_uri }}",
106106
model_channel_name="{{ model_chanel }}",
107107
sagemaker_session=sagemaker_session,
108+
train_use_spot_instances=True,
108109
)
109110

110111
byo.set_hyperparameters(epochs=32, feature_dim=1024, mini_batch_size=256)
@@ -155,6 +156,7 @@ def test_byo_training_config_all_args(sagemaker_session):
155156
"Subnets": ["{{ subnet }}"],
156157
"SecurityGroupIds": ["{{ security_group_ids }}"],
157158
},
159+
"EnableManagedSpotTraining": True,
158160
"HyperParameters": {"epochs": "32", "feature_dim": "1024", "mini_batch_size": "256"},
159161
"Tags": [{"{{ key }}": "{{ value }}"}],
160162
}

0 commit comments

Comments
 (0)