Skip to content

Commit 56f737c

Browse files
authored
Fix logic around setting job name for tuning jobs (aws#47)
We originally weren't honoring job names passed through `fit()`; this change fixes that.
1 parent 535435a commit 56f737c

File tree

2 files changed

+25
-11
lines changed

2 files changed

+25
-11
lines changed

src/sagemaker/tuner.py

+15-11
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ def __init__(self, min_value, max_value):
8181

8282

8383
class HyperparameterTuner(object):
84+
TUNING_JOB_NAME_MAX_LENGTH = 32
85+
8486
SAGEMAKER_ESTIMATOR_MODULE = 'sagemaker_estimator_module'
8587
SAGEMAKER_ESTIMATOR_CLASS_NAME = 'sagemaker_estimator_class_name'
8688

@@ -96,18 +98,25 @@ def __init__(self, estimator, objective_metric_name, hyperparameter_ranges, metr
9698
self.estimator = estimator
9799
self.objective_metric_name = objective_metric_name
98100
self.metric_definitions = metric_definitions
101+
self._validate_parameter_ranges()
99102

100103
self.strategy = strategy
101104
self.objective_type = objective_type
102105

103106
self.max_jobs = max_jobs
104107
self.max_parallel_jobs = max_parallel_jobs
108+
105109
self.base_tuning_job_name = base_tuning_job_name
106-
self.metric_definitions = metric_definitions
110+
self._current_job_name = None
107111
self.latest_tuning_job = None
108-
self._validate_parameter_ranges()
109112

110-
def prepare_for_training(self):
113+
def prepare_for_training(self, job_name=None):
114+
if job_name is not None:
115+
self._current_job_name = job_name
116+
else:
117+
base_name = self.base_tuning_job_name or base_name_from_image(self.estimator.train_image())
118+
self._current_job_name = name_from_base(base_name, max_length=self.TUNING_JOB_NAME_MAX_LENGTH, short=True)
119+
111120
self.static_hyperparameters = {to_str(k): to_str(v) for (k, v) in self.estimator.hyperparameters().items()}
112121
for hyperparameter_name in self._hyperparameter_ranges.keys():
113122
self.static_hyperparameters.pop(hyperparameter_name, None)
@@ -133,7 +142,7 @@ def fit(self, inputs, job_name=None, **kwargs):
133142
else:
134143
self.estimator.prepare_for_training(**kwargs)
135144

136-
self.prepare_for_training()
145+
self.prepare_for_training(job_name=job_name)
137146
self.latest_tuning_job = _TuningJob.start_new(self, inputs)
138147

139148
@classmethod
@@ -350,8 +359,6 @@ def _validate_parameter_ranges(self):
350359

351360

352361
class _TuningJob(_Job):
353-
TUNING_JOB_NAME_MAX_LENGTH = 32
354-
355362
def __init__(self, sagemaker_session, tuning_job_name):
356363
super(_TuningJob, self).__init__(sagemaker_session, tuning_job_name)
357364

@@ -368,10 +375,7 @@ def start_new(cls, tuner, inputs):
368375
"""
369376
config = _Job._load_config(inputs, tuner.estimator)
370377

371-
base_name = tuner.base_tuning_job_name or base_name_from_image(tuner.estimator.train_image())
372-
tuning_job_name = name_from_base(base_name, max_length=cls.TUNING_JOB_NAME_MAX_LENGTH, short=True)
373-
374-
tuner.estimator.sagemaker_session.tune(job_name=tuning_job_name, strategy=tuner.strategy,
378+
tuner.estimator.sagemaker_session.tune(job_name=tuner._current_job_name, strategy=tuner.strategy,
375379
objective_type=tuner.objective_type,
376380
objective_metric_name=tuner.objective_metric_name,
377381
max_jobs=tuner.max_jobs, max_parallel_jobs=tuner.max_parallel_jobs,
@@ -385,7 +389,7 @@ def start_new(cls, tuner, inputs):
385389
resource_config=(config['resource_config']),
386390
stop_condition=(config['stop_condition']))
387391

388-
return cls(tuner.sagemaker_session, tuning_job_name)
392+
return cls(tuner.sagemaker_session, tuner._current_job_name)
389393

390394
def stop(self):
391395
self.sagemaker_session.stop_tuning_job(HyperParameterTuningJobName=self.name)

tests/unit/test_tuner.py

+10
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ def test_prepare_for_training(tuner):
148148
tuner.estimator.set_hyperparameters(**static_hyperparameters)
149149
tuner.prepare_for_training()
150150

151+
assert tuner._current_job_name.startswith(IMAGE_NAME)
152+
151153
assert len(tuner.static_hyperparameters) == 3
152154
assert tuner.static_hyperparameters['another_one'] == '0'
153155

@@ -157,6 +159,14 @@ def test_prepare_for_training(tuner):
157159
assert tuner.static_hyperparameters['sagemaker_estimator_module'] == module
158160

159161

162+
def test_prepare_for_training_with_job_name(tuner):
163+
static_hyperparameters = {'validated': 1, 'another_one': 0}
164+
tuner.estimator.set_hyperparameters(**static_hyperparameters)
165+
166+
tuner.prepare_for_training(job_name='some-other-job-name')
167+
assert tuner._current_job_name == 'some-other-job-name'
168+
169+
160170
def test_validate_parameter_ranges_number_validation_error(sagemaker_session):
161171
pca = PCA(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, NUM_COMPONENTS,
162172
base_job_name='pca', sagemaker_session=sagemaker_session)

0 commit comments

Comments (0)