diff --git a/src/sagemaker/workflow/airflow.py b/src/sagemaker/workflow/airflow.py index 8663cc59c1..6d663bad94 100644 --- a/src/sagemaker/workflow/airflow.py +++ b/src/sagemaker/workflow/airflow.py @@ -236,6 +236,9 @@ def training_config(estimator, inputs=None, job_name=None, mini_batch_size=None) if estimator.tags is not None: train_config["Tags"] = estimator.tags + if estimator.metric_definitions is not None: + train_config["AlgorithmSpecification"]["MetricDefinitions"] = estimator.metric_definitions + return train_config diff --git a/tests/unit/test_airflow.py b/tests/unit/test_airflow.py index b878069c59..835929e750 100644 --- a/tests/unit/test_airflow.py +++ b/tests/unit/test_airflow.py @@ -284,6 +284,7 @@ def test_framework_training_config_all_args(ecr_prefix, sagemaker_session): tags=[{"{{ key }}": "{{ value }}"}], subnets=["{{ subnet }}"], security_group_ids=["{{ security_group_ids }}"], + metric_definitions=[{"Name": "{{ name }}", "Regex": "{{ regex }}"}], sagemaker_session=sagemaker_session, ) @@ -294,6 +295,7 @@ def test_framework_training_config_all_args(ecr_prefix, sagemaker_session): "AlgorithmSpecification": { "TrainingImage": "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:1.10.0-cpu-py2", "TrainingInputMode": "Pipe", + "MetricDefinitions": [{"Name": "{{ name }}", "Regex": "{{ regex }}"}], }, "OutputDataConfig": { "S3OutputPath": "{{ output_path }}",