Skip to content

Commit 7d72221

Browse files
Add best tuning job method (aws#34)
1 parent d63deea commit 7d72221

File tree

2 files changed

+53
-2
lines changed

2 files changed

+53
-2
lines changed

src/sagemaker/tuner.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def __init__(self, estimator, objective_metric_name, hyperparameter_ranges, metr
8383

8484
self.max_jobs = max_jobs
8585
self.max_parallel_jobs = max_parallel_jobs
86-
self.tuning_job_name = base_tuning_job_name
86+
self.base_tuning_job_name = base_tuning_job_name
8787
self.metric_definitions = metric_definitions
8888
self.latest_tuning_job = None
8989
self._validate_parameter_ranges()
@@ -125,8 +125,22 @@ def stop_tuning_job(self):
125125
self._ensure_last_tuning_job()
126126
self.latest_tuning_job.stop()
127127

128+
def best_training_job(self):
129+
"""Return name of the best training job for the latest tuning job.
130+
"""
131+
self._ensure_last_tuning_job()
132+
133+
tuning_job_describe_result = \
134+
self.estimator.sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job(
135+
HyperParameterTuningJobName=self.latest_tuning_job.name)
136+
137+
try:
138+
return tuning_job_describe_result['BestTrainingJob']['TrainingJobName']
139+
except KeyError:
140+
raise Exception('Best training job not available for tuning job: {}'.format(self.latest_tuning_job.name))
141+
128142
def _ensure_last_tuning_job(self):
129-
if 'latest_tuning_job' not in dir(self) or self.latest_tuning_job is None:
143+
if self.latest_tuning_job is None:
130144
raise ValueError('No tuning job available')
131145

132146
def hyperparameter_ranges(self):

tests/unit/test_tuner.py

+37
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,43 @@ def test_stop_tuning_job_no_tuning_job(tuner):
174174
assert 'No tuning job available' in str(e)
175175

176176

177+
def test_best_tuning_job(tuner):
178+
tuning_job_description = {'BestTrainingJob': {'TrainingJobName': JOB_NAME}}
179+
180+
tuner.estimator.sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(
181+
name='describe_hyper_parameter_tuning_job', return_value=tuning_job_description)
182+
183+
tuner.latest_tuning_job = _TuningJob(tuner.estimator.sagemaker_session, JOB_NAME)
184+
best_training_job = tuner.best_training_job()
185+
186+
assert best_training_job == JOB_NAME
187+
tuner.estimator.sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job.assert_called_once_with(
188+
HyperParameterTuningJobName=JOB_NAME)
189+
190+
191+
def test_best_tuning_job_no_latest_job(tuner):
192+
with pytest.raises(Exception) as e:
193+
tuner.best_training_job()
194+
195+
assert 'No tuning job available' in str(e)
196+
197+
198+
def test_best_tuning_job_no_best_job(tuner):
199+
tuning_job_description = {'BestTrainingJob': {'Mock': None}}
200+
201+
tuner.estimator.sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(
202+
name='describe_hyper_parameter_tuning_job', return_value=tuning_job_description)
203+
204+
tuner.latest_tuning_job = _TuningJob(tuner.estimator.sagemaker_session, JOB_NAME)
205+
206+
with pytest.raises(Exception) as e:
207+
tuner.best_training_job()
208+
209+
tuner.estimator.sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job.assert_called_once_with(
210+
HyperParameterTuningJobName=JOB_NAME)
211+
assert 'Best training job not available for tuning job:' in str(e)
212+
213+
177214
#################################################################################
178215
# _ParameterRange Tests
179216

0 commit comments

Comments
 (0)