Skip to content

Commit d2430a1

Browse files
authored
Estimator.attach should work with a job without hps (#665)
1 parent 15cc58f commit d2430a1

File tree

3 files changed

+18
-1
lines changed

3 files changed

+18
-1
lines changed

CHANGELOG.rst

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ CHANGELOG
88
* doc-fix: Remove incorrect parameter for EI TFS Python README
99
* feature: ``Predictor``: delete SageMaker model
1010
* feature: ``Pipeline``: delete SageMaker model
11+
* bug-fix: Estimator.attach works with training jobs without hyperparameters
1112

1213
1.18.3.post1
1314
============

src/sagemaker/estimator.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,9 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
425425
init_params['output_path'] = job_details['OutputDataConfig']['S3OutputPath']
426426
init_params['output_kms_key'] = job_details['OutputDataConfig']['KmsKeyId']
427427

428-
init_params['hyperparameters'] = job_details['HyperParameters']
428+
has_hps = 'HyperParameters' in job_details
429+
init_params['hyperparameters'] = job_details['HyperParameters'] if has_hps else {}
430+
429431
if 'TrainingImage' in job_details['AlgorithmSpecification']:
430432
init_params['image'] = job_details['AlgorithmSpecification']['TrainingImage']
431433
elif 'AlgorithmName' in job_details['AlgorithmSpecification']:

tests/unit/test_estimator.py

+14
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,20 @@ def test_attach_framework(sagemaker_session):
451451
assert framework_estimator.encrypt_inter_container_traffic is False
452452

453453

454+
def test_attach_without_hyperparameters(sagemaker_session):
455+
returned_job_description = RETURNED_JOB_DESCRIPTION.copy()
456+
del returned_job_description['HyperParameters']
457+
458+
mock_describe_training_job = Mock(name='describe_training_job',
459+
return_value=returned_job_description)
460+
sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job
461+
462+
estimator = Estimator.attach(training_job_name='job',
463+
sagemaker_session=sagemaker_session)
464+
465+
assert estimator.hyperparameters() == {}
466+
467+
454468
def test_attach_framework_with_tuning(sagemaker_session):
455469
returned_job_description = RETURNED_JOB_DESCRIPTION.copy()
456470
returned_job_description['HyperParameters']['_tuning_objective_metric'] = 'Validation-accuracy'

0 commit comments

Comments
 (0)