Commit f9628f8

change: set _current_job_name and base_tuning_job_name in HyperparameterTuner.attach() (aws#1650)
1 parent a062c6a commit f9628f8
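
In effect, a tuner attached to an existing job now carries the same naming state as one that launched the job itself. A minimal sketch of the user-visible change (the job name below is hypothetical; base_from_name strips the timestamp suffix the SDK appends to generated names):

    from sagemaker.tuner import HyperparameterTuner

    # "mxnet-tuner-200622-1430" is a hypothetical generated tuning job name.
    tuner = HyperparameterTuner.attach("mxnet-tuner-200622-1430")

    print(tuner._current_job_name)     # "mxnet-tuner-200622-1430" (previously never set by attach())
    print(tuner.base_tuning_job_name)  # "mxnet-tuner" (previously left as None)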

File tree

2 files changed: +34 -27 lines changed

src/sagemaker/tuner.py
tests/unit/test_tuner.py


src/sagemaker/tuner.py

Lines changed: 15 additions & 23 deletions
@@ -37,7 +37,7 @@
 )
 from sagemaker.session import Session
 from sagemaker.session import s3_input
-from sagemaker.utils import base_name_from_image, name_from_base
+from sagemaker.utils import base_from_name, base_name_from_image, name_from_base
 
 AMAZON_ESTIMATOR_MODULE = "sagemaker"
 AMAZON_ESTIMATOR_CLS_NAMES = {
@@ -587,18 +587,21 @@ def attach(cls, tuning_job_name, sagemaker_session=None, job_details=None, estim
         )
 
         if "TrainingJobDefinition" in job_details:
-            return cls._attach_with_training_details(
-                tuning_job_name, sagemaker_session, estimator_cls, job_details
+            tuner = cls._attach_with_training_details(sagemaker_session, estimator_cls, job_details)
+        else:
+            tuner = cls._attach_with_training_details_list(
+                sagemaker_session, estimator_cls, job_details
             )
 
-        return cls._attach_with_training_details_list(
-            tuning_job_name, sagemaker_session, estimator_cls, job_details
+        tuner.latest_tuning_job = _TuningJob(
+            sagemaker_session=sagemaker_session, job_name=tuning_job_name
         )
+        tuner._current_job_name = tuning_job_name
+
+        return tuner
 
     @classmethod
-    def _attach_with_training_details(
-        cls, tuning_job_name, sagemaker_session, estimator_cls, job_details
-    ):
+    def _attach_with_training_details(cls, sagemaker_session, estimator_cls, job_details):
         """Create a HyperparameterTuner bound to an existing hyperparameter
         tuning job that has the ``TrainingJobDefinition`` field set."""
         estimator = cls._prepare_estimator(
@@ -609,17 +612,10 @@ def _attach_with_training_details(
         )
         init_params = cls._prepare_init_params_from_job_description(job_details)
 
-        tuner = cls(estimator=estimator, **init_params)
-        tuner.latest_tuning_job = _TuningJob(
-            sagemaker_session=sagemaker_session, job_name=tuning_job_name
-        )
-
-        return tuner
+        return cls(estimator=estimator, **init_params)
 
     @classmethod
-    def _attach_with_training_details_list(
-        cls, tuning_job_name, sagemaker_session, estimator_cls, job_details
-    ):
+    def _attach_with_training_details_list(cls, sagemaker_session, estimator_cls, job_details):
         """Create a HyperparameterTuner bound to an existing hyperparameter
         tuning job that has the ``TrainingJobDefinitions`` field set."""
         estimator_names = sorted(
@@ -664,18 +660,13 @@ def _attach_with_training_details_list(
 
         init_params = cls._prepare_init_params_from_job_description(job_details)
 
-        tuner = HyperparameterTuner.create(
+        return HyperparameterTuner.create(
             estimator_dict=estimator_dict,
             objective_metric_name_dict=objective_metric_name_dict,
             hyperparameter_ranges_dict=hyperparameter_ranges_dict,
             metric_definitions_dict=metric_definitions_dict,
             **init_params
         )
-        tuner.latest_tuning_job = _TuningJob(
-            sagemaker_session=sagemaker_session, job_name=tuning_job_name
-        )
-
-        return tuner
 
     def deploy(
         self,
@@ -941,6 +932,7 @@ def _prepare_init_params_from_job_description(cls, job_details):
                 job_details.get("WarmStartConfig", None)
             ),
             "early_stopping_type": tuning_config["TrainingJobEarlyStoppingType"],
+            "base_tuning_job_name": base_from_name(job_details["HyperParameterTuningJobName"]),
         }
 
         if "HyperParameterTuningJobObjective" in tuning_config:

tests/unit/test_tuner.py

Lines changed: 19 additions & 4 deletions
@@ -19,14 +19,12 @@
 import pytest
 from mock import Mock, patch
 
-from sagemaker import Predictor
+from sagemaker import Predictor, utils
 from sagemaker.amazon.amazon_estimator import RecordSet
 from sagemaker.estimator import Framework
 from sagemaker.mxnet import MXNet
-
-from sagemaker.session import s3_input
-
 from sagemaker.parameter import ParameterRange
+from sagemaker.session import s3_input
 from sagemaker.tuner import (
     _TuningJob,
     create_identical_dataset_and_algorithm_tuner,
@@ -498,6 +496,9 @@ def test_attach_tuning_job_with_estimator_from_hyperparameters(sagemaker_session
     tuner = HyperparameterTuner.attach(JOB_NAME, sagemaker_session=sagemaker_session)
 
     assert tuner.latest_tuning_job.name == JOB_NAME
+    assert tuner.base_tuning_job_name == JOB_NAME
+    assert tuner._current_job_name == JOB_NAME
+
     assert tuner.objective_metric_name == OBJECTIVE_METRIC_NAME
     assert tuner.max_jobs == 1
     assert tuner.max_parallel_jobs == 1
@@ -580,6 +581,20 @@ def test_attach_with_no_specified_estimator(sagemaker_session):
     assert isinstance(tuner.estimator, Estimator)
 
 
+def test_attach_with_generated_job_name(sagemaker_session):
+    job_name = utils.name_from_base(BASE_JOB_NAME, max_length=32, short=True)
+
+    job_details = copy.deepcopy(TUNING_JOB_DETAILS)
+    job_details["HyperParameterTuningJobName"] = job_name
+
+    sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(
+        name="describe_tuning_job", return_value=job_details
+    )
+
+    tuner = HyperparameterTuner.attach(job_name, sagemaker_session=sagemaker_session)
+    assert BASE_JOB_NAME == tuner.base_tuning_job_name
+
+
 def test_attach_with_warm_start_config(sagemaker_session):
     warm_start_config = WarmStartConfig(
         warm_start_type=WarmStartTypes.TRANSFER_LEARNING, parents={"p1", "p2"}
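
The pattern in the new test generalizes: attach() can be unit-tested entirely offline by stubbing the describe call on the session's SageMaker client. A condensed sketch of that setup (here job_details stands in for a full DescribeHyperParameterTuningJob response, like the TUNING_JOB_DETAILS fixture assumed to be defined earlier in the test module):

    from mock import Mock

    from sagemaker.tuner import HyperparameterTuner

    def attach_with_stubbed_describe(sagemaker_session, job_details):
        """Attach to a tuning job from a canned describe response (no AWS calls)."""
        sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(
            name="describe_tuning_job", return_value=job_details
        )
        return HyperparameterTuner.attach(
            job_details["HyperParameterTuningJobName"],
            sagemaker_session=sagemaker_session,
        )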
