@@ -81,6 +81,8 @@ def __init__(self, min_value, max_value):
81
81
82
82
83
83
class HyperparameterTuner (object ):
84
+ TUNING_JOB_NAME_MAX_LENGTH = 32
85
+
84
86
SAGEMAKER_ESTIMATOR_MODULE = 'sagemaker_estimator_module'
85
87
SAGEMAKER_ESTIMATOR_CLASS_NAME = 'sagemaker_estimator_class_name'
86
88
@@ -96,18 +98,25 @@ def __init__(self, estimator, objective_metric_name, hyperparameter_ranges, metr
96
98
self .estimator = estimator
97
99
self .objective_metric_name = objective_metric_name
98
100
self .metric_definitions = metric_definitions
101
+ self ._validate_parameter_ranges ()
99
102
100
103
self .strategy = strategy
101
104
self .objective_type = objective_type
102
105
103
106
self .max_jobs = max_jobs
104
107
self .max_parallel_jobs = max_parallel_jobs
108
+
105
109
self .base_tuning_job_name = base_tuning_job_name
106
- self .metric_definitions = metric_definitions
110
+ self ._current_job_name = None
107
111
self .latest_tuning_job = None
108
- self ._validate_parameter_ranges ()
109
112
110
- def prepare_for_training (self ):
113
+ def prepare_for_training (self , job_name = None ):
114
+ if job_name is not None :
115
+ self ._current_job_name = job_name
116
+ else :
117
+ base_name = self .base_tuning_job_name or base_name_from_image (self .estimator .train_image ())
118
+ self ._current_job_name = name_from_base (base_name , max_length = self .TUNING_JOB_NAME_MAX_LENGTH , short = True )
119
+
111
120
self .static_hyperparameters = {to_str (k ): to_str (v ) for (k , v ) in self .estimator .hyperparameters ().items ()}
112
121
for hyperparameter_name in self ._hyperparameter_ranges .keys ():
113
122
self .static_hyperparameters .pop (hyperparameter_name , None )
@@ -133,7 +142,7 @@ def fit(self, inputs, job_name=None, **kwargs):
133
142
else :
134
143
self .estimator .prepare_for_training (** kwargs )
135
144
136
- self .prepare_for_training ()
145
+ self .prepare_for_training (job_name = job_name )
137
146
self .latest_tuning_job = _TuningJob .start_new (self , inputs )
138
147
139
148
@classmethod
@@ -350,8 +359,6 @@ def _validate_parameter_ranges(self):
350
359
351
360
352
361
class _TuningJob (_Job ):
353
- TUNING_JOB_NAME_MAX_LENGTH = 32
354
-
355
362
def __init__ (self , sagemaker_session , tuning_job_name ):
356
363
super (_TuningJob , self ).__init__ (sagemaker_session , tuning_job_name )
357
364
@@ -368,10 +375,7 @@ def start_new(cls, tuner, inputs):
368
375
"""
369
376
config = _Job ._load_config (inputs , tuner .estimator )
370
377
371
- base_name = tuner .base_tuning_job_name or base_name_from_image (tuner .estimator .train_image ())
372
- tuning_job_name = name_from_base (base_name , max_length = cls .TUNING_JOB_NAME_MAX_LENGTH , short = True )
373
-
374
- tuner .estimator .sagemaker_session .tune (job_name = tuning_job_name , strategy = tuner .strategy ,
378
+ tuner .estimator .sagemaker_session .tune (job_name = tuner ._current_job_name , strategy = tuner .strategy ,
375
379
objective_type = tuner .objective_type ,
376
380
objective_metric_name = tuner .objective_metric_name ,
377
381
max_jobs = tuner .max_jobs , max_parallel_jobs = tuner .max_parallel_jobs ,
@@ -385,7 +389,7 @@ def start_new(cls, tuner, inputs):
385
389
resource_config = (config ['resource_config' ]),
386
390
stop_condition = (config ['stop_condition' ]))
387
391
388
- return cls (tuner .sagemaker_session , tuning_job_name )
392
+ return cls (tuner .sagemaker_session , tuner . _current_job_name )
389
393
390
394
def stop (self ):
391
395
self .sagemaker_session .stop_tuning_job (HyperParameterTuningJobName = self .name )
0 commit comments