Skip to content

Commit 53fe1dc

Browse files
authored
fix: training_config returns MetricDefinitions (#1453)
1 parent d2d1cdf commit 53fe1dc

File tree

2 files changed

+5
-0
lines changed

2 files changed

+5
-0
lines changed

src/sagemaker/workflow/airflow.py

+3
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,9 @@ def training_config(estimator, inputs=None, job_name=None, mini_batch_size=None)
236236
if estimator.tags is not None:
237237
train_config["Tags"] = estimator.tags
238238

239+
if estimator.metric_definitions is not None:
240+
train_config["AlgorithmSpecification"]["MetricDefinitions"] = estimator.metric_definitions
241+
239242
return train_config
240243

241244

tests/unit/test_airflow.py

+2
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ def test_framework_training_config_all_args(ecr_prefix, sagemaker_session):
284284
tags=[{"{{ key }}": "{{ value }}"}],
285285
subnets=["{{ subnet }}"],
286286
security_group_ids=["{{ security_group_ids }}"],
287+
metric_definitions=[{"Name": "{{ name }}", "Regex": "{{ regex }}"}],
287288
sagemaker_session=sagemaker_session,
288289
)
289290

@@ -294,6 +295,7 @@ def test_framework_training_config_all_args(ecr_prefix, sagemaker_session):
294295
"AlgorithmSpecification": {
295296
"TrainingImage": "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:1.10.0-cpu-py2",
296297
"TrainingInputMode": "Pipe",
298+
"MetricDefinitions": [{"Name": "{{ name }}", "Regex": "{{ regex }}"}],
297299
},
298300
"OutputDataConfig": {
299301
"S3OutputPath": "{{ output_path }}",

0 commit comments

Comments
 (0)