37
37
)
38
38
from sagemaker .session import Session
39
39
from sagemaker .session import s3_input
40
- from sagemaker .utils import base_name_from_image , name_from_base
40
+ from sagemaker .utils import base_from_name , base_name_from_image , name_from_base
41
41
42
42
AMAZON_ESTIMATOR_MODULE = "sagemaker"
43
43
AMAZON_ESTIMATOR_CLS_NAMES = {
@@ -587,18 +587,21 @@ def attach(cls, tuning_job_name, sagemaker_session=None, job_details=None, estim
587
587
)
588
588
589
589
if "TrainingJobDefinition" in job_details :
590
- return cls ._attach_with_training_details (
591
- tuning_job_name , sagemaker_session , estimator_cls , job_details
590
+ tuner = cls ._attach_with_training_details (sagemaker_session , estimator_cls , job_details )
591
+ else :
592
+ tuner = cls ._attach_with_training_details_list (
593
+ sagemaker_session , estimator_cls , job_details
592
594
)
593
595
594
- return cls . _attach_with_training_details_list (
595
- tuning_job_name , sagemaker_session , estimator_cls , job_details
596
+ tuner . latest_tuning_job = _TuningJob (
597
+ sagemaker_session = sagemaker_session , job_name = tuning_job_name
596
598
)
599
+ tuner ._current_job_name = tuning_job_name
600
+
601
+ return tuner
597
602
598
603
@classmethod
599
- def _attach_with_training_details (
600
- cls , tuning_job_name , sagemaker_session , estimator_cls , job_details
601
- ):
604
+ def _attach_with_training_details (cls , sagemaker_session , estimator_cls , job_details ):
602
605
"""Create a HyperparameterTuner bound to an existing hyperparameter
603
606
tuning job that has the ``TrainingJobDefinition`` field set."""
604
607
estimator = cls ._prepare_estimator (
@@ -609,17 +612,10 @@ def _attach_with_training_details(
609
612
)
610
613
init_params = cls ._prepare_init_params_from_job_description (job_details )
611
614
612
- tuner = cls (estimator = estimator , ** init_params )
613
- tuner .latest_tuning_job = _TuningJob (
614
- sagemaker_session = sagemaker_session , job_name = tuning_job_name
615
- )
616
-
617
- return tuner
615
+ return cls (estimator = estimator , ** init_params )
618
616
619
617
@classmethod
620
- def _attach_with_training_details_list (
621
- cls , tuning_job_name , sagemaker_session , estimator_cls , job_details
622
- ):
618
+ def _attach_with_training_details_list (cls , sagemaker_session , estimator_cls , job_details ):
623
619
"""Create a HyperparameterTuner bound to an existing hyperparameter
624
620
tuning job that has the ``TrainingJobDefinitions`` field set."""
625
621
estimator_names = sorted (
@@ -664,18 +660,13 @@ def _attach_with_training_details_list(
664
660
665
661
init_params = cls ._prepare_init_params_from_job_description (job_details )
666
662
667
- tuner = HyperparameterTuner .create (
663
+ return HyperparameterTuner .create (
668
664
estimator_dict = estimator_dict ,
669
665
objective_metric_name_dict = objective_metric_name_dict ,
670
666
hyperparameter_ranges_dict = hyperparameter_ranges_dict ,
671
667
metric_definitions_dict = metric_definitions_dict ,
672
668
** init_params
673
669
)
674
- tuner .latest_tuning_job = _TuningJob (
675
- sagemaker_session = sagemaker_session , job_name = tuning_job_name
676
- )
677
-
678
- return tuner
679
670
680
671
def deploy (
681
672
self ,
@@ -941,6 +932,7 @@ def _prepare_init_params_from_job_description(cls, job_details):
941
932
job_details .get ("WarmStartConfig" , None )
942
933
),
943
934
"early_stopping_type" : tuning_config ["TrainingJobEarlyStoppingType" ],
935
+ "base_tuning_job_name" : base_from_name (job_details ["HyperParameterTuningJobName" ]),
944
936
}
945
937
946
938
if "HyperParameterTuningJobObjective" in tuning_config :
0 commit comments