diff --git a/src/sagemaker/automl/automl.py b/src/sagemaker/automl/automl.py index 334c1d5c88..ce71d50977 100644 --- a/src/sagemaker/automl/automl.py +++ b/src/sagemaker/automl/automl.py @@ -332,7 +332,7 @@ def attach(cls, auto_ml_job_name, sagemaker_session=None): total_job_runtime_in_seconds=auto_ml_job_desc.get("AutoMLJobConfig", {}) .get("CompletionCriteria", {}) .get("MaxAutoMLJobRuntimeInSeconds"), - job_objective=auto_ml_job_desc.get("AutoMLJobObjective", {}).get("MetricName"), + job_objective=auto_ml_job_desc.get("AutoMLJobObjective", {}), generate_candidate_definitions_only=auto_ml_job_desc.get( "GenerateCandidateDefinitionsOnly", False ), diff --git a/tests/unit/sagemaker/workflow/test_automl_step.py b/tests/unit/sagemaker/workflow/test_automl_step.py index 6f02eccf4a..d831729241 100644 --- a/tests/unit/sagemaker/workflow/test_automl_step.py +++ b/tests/unit/sagemaker/workflow/test_automl_step.py @@ -41,7 +41,7 @@ def test_single_automl_step(pipeline_session): max_candidates=1, max_runtime_per_training_job_in_seconds=3600, total_job_runtime_in_seconds=36000, - job_objective="fake job objective", + job_objective={"MetricName": "F1"}, generate_candidate_definitions_only=False, tags=[{"Name": "some-tag", "Value": "value-for-tag"}], content_type="x-application/vnd.amazon+parquet", @@ -111,7 +111,7 @@ def test_single_automl_step(pipeline_session): "VpcConfig": {"SecurityGroupIds": ["group"], "Subnets": ["subnet"]}, }, }, - "AutoMLJobObjective": "fake job objective", + "AutoMLJobObjective": {"MetricName": "F1"}, "InputDataConfig": [ { "ChannelType": "training", @@ -165,7 +165,7 @@ def test_single_automl_step_with_parameter(pipeline_session): max_candidates=1, max_runtime_per_training_job_in_seconds=3600, total_job_runtime_in_seconds=36000, - job_objective="fake job objective", + job_objective={"MetricName": "F1"}, generate_candidate_definitions_only=False, tags=[{"Name": "some-tag", "Value": "value-for-tag"}], content_type="x-application/vnd.amazon+parquet", @@ -239,7 +239,7 @@ def test_single_automl_step_with_parameter(pipeline_session): "VpcConfig": {"SecurityGroupIds": ["group"], "Subnets": ["subnet"]}, }, }, - "AutoMLJobObjective": "fake job objective", + "AutoMLJobObjective": {"MetricName": "F1"}, "InputDataConfig": [ { "ChannelType": "training", @@ -290,7 +290,7 @@ def test_get_best_auto_ml_model(pipeline_session): max_candidates=1, max_runtime_per_training_job_in_seconds=3600, total_job_runtime_in_seconds=36000, - job_objective="fake job objective", + job_objective={"MetricName": "F1"}, generate_candidate_definitions_only=False, tags=[{"Name": "some-tag", "Value": "value-for-tag"}], content_type="x-application/vnd.amazon+parquet", @@ -399,7 +399,7 @@ def test_automl_step_with_invalid_mode(pipeline_session): max_candidates=1, max_runtime_per_training_job_in_seconds=3600, total_job_runtime_in_seconds=36000, - job_objective="fake job objective", + job_objective={"MetricName": "F1"}, generate_candidate_definitions_only=False, tags=[{"Name": "some-tag", "Value": "value-for-tag"}], content_type="x-application/vnd.amazon+parquet", @@ -455,7 +455,7 @@ def test_automl_step_with_no_mode(pipeline_session): max_candidates=1, max_runtime_per_training_job_in_seconds=3600, total_job_runtime_in_seconds=36000, - job_objective="fake job objective", + job_objective={"MetricName": "F1"}, generate_candidate_definitions_only=False, tags=[{"Name": "some-tag", "Value": "value-for-tag"}], content_type="x-application/vnd.amazon+parquet",