Skip to content

Commit 18baa99

Browse files
jesterhazypengk19
authored andcommitted
fix: set _current_job_name in attach() (aws#808)
1 parent 4ca0d55 commit 18baa99

File tree

3 files changed

+3
-0
lines changed

3 files changed

+3
-0
lines changed

src/sagemaker/estimator.py

+1
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,7 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name='m
326326
estimator = cls(sagemaker_session=sagemaker_session, **init_params)
327327
estimator.latest_training_job = _TrainingJob(sagemaker_session=sagemaker_session,
328328
job_name=init_params['base_job_name'])
329+
estimator._current_job_name = estimator.latest_training_job.name
329330
estimator.latest_training_job.wait()
330331
return estimator
331332

tests/unit/test_estimator.py

+1
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,7 @@ def test_attach_framework(sagemaker_session):
452452
return_value=returned_job_description)
453453

454454
framework_estimator = DummyFramework.attach(training_job_name='neo', sagemaker_session=sagemaker_session)
455+
assert framework_estimator._current_job_name == 'neo'
455456
assert framework_estimator.latest_training_job.job_name == 'neo'
456457
assert framework_estimator.role == 'arn:aws:iam::366:role/SageMakerRole'
457458
assert framework_estimator.train_instance_count == 1

tests/unit/test_sklearn.py

+1
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,7 @@ def test_attach(sagemaker_session, sklearn_version):
324324
return_value=returned_job_description)
325325

326326
estimator = SKLearn.attach(training_job_name='neo', sagemaker_session=sagemaker_session)
327+
assert estimator._current_job_name == 'neo'
327328
assert estimator.latest_training_job.job_name == 'neo'
328329
assert estimator.py_version == PYTHON_VERSION
329330
assert estimator.framework_version == sklearn_version

0 commit comments

Comments
 (0)