diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a7757cad0f..4977dbad9b 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -8,6 +8,7 @@ CHANGELOG * doc-fix: Remove incorrect parameter for EI TFS Python README * feature: ``Predictor``: delete SageMaker model * feature: ``Pipeline``: delete SageMaker model +* bug-fix: Estimator.attach works with training jobs without hyperparameters 1.18.3.post1 ============ diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index a0af44c2d7..f840ce17fc 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -425,7 +425,9 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na init_params['output_path'] = job_details['OutputDataConfig']['S3OutputPath'] init_params['output_kms_key'] = job_details['OutputDataConfig']['KmsKeyId'] - init_params['hyperparameters'] = job_details['HyperParameters'] + has_hps = 'HyperParameters' in job_details + init_params['hyperparameters'] = job_details['HyperParameters'] if has_hps else {} + if 'TrainingImage' in job_details['AlgorithmSpecification']: init_params['image'] = job_details['AlgorithmSpecification']['TrainingImage'] elif 'AlgorithmName' in job_details['AlgorithmSpecification']: diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 59d803254e..e9c483b9b4 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -451,6 +451,20 @@ def test_attach_framework(sagemaker_session): assert framework_estimator.encrypt_inter_container_traffic is False +def test_attach_without_hyperparameters(sagemaker_session): + returned_job_description = RETURNED_JOB_DESCRIPTION.copy() + del returned_job_description['HyperParameters'] + + mock_describe_training_job = Mock(name='describe_training_job', + return_value=returned_job_description) + sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job + + estimator = Estimator.attach(training_job_name='job', + sagemaker_session=sagemaker_session) + + assert estimator.hyperparameters() == {} + + def test_attach_framework_with_tuning(sagemaker_session): returned_job_description = RETURNED_JOB_DESCRIPTION.copy() returned_job_description['HyperParameters']['_tuning_objective_metric'] = 'Validation-accuracy'