Skip to content

Commit 48c4e02

Browse files
committed
Add tags for hyperparameter tuning jobs
1 parent 1c22945 commit 48c4e02

File tree

3 files changed

+14
-5
lines changed

3 files changed

+14
-5
lines changed

src/sagemaker/session.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,6 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
241241
'TrainingImage': image,
242242
'TrainingInputMode': input_mode
243243
},
244-
# 'HyperParameters': hyperparameters,
245244
'InputDataConfig': input_config,
246245
'OutputDataConfig': output_config,
247246
'TrainingJobName': job_name,
@@ -259,7 +258,7 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
259258
def tune(self, job_name, strategy, objective_type, objective_metric_name,
260259
max_jobs, max_parallel_jobs, parameter_ranges,
261260
static_hyperparameters, image, input_mode, metric_definitions,
262-
role, input_config, output_config, resource_config, stop_condition):
261+
role, input_config, output_config, resource_config, stop_condition, tags):
263262
"""Create an Amazon SageMaker hyperparameter tuning job
264263
265264
Args:
@@ -292,6 +291,7 @@ def tune(self, job_name, strategy, objective_type, objective_metric_name,
292291
instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'.
293292
stop_condition (dict): Defines when training shall finish. Contains entries that can be understood by the
294293
service like ``MaxRuntimeInSeconds``.
294+
tags (list[dict]): List of tags for labeling the tuning job.
295295
"""
296296
tune_request = {
297297
'HyperParameterTuningJobName': job_name,
@@ -324,6 +324,9 @@ def tune(self, job_name, strategy, objective_type, objective_metric_name,
324324
if metric_definitions is not None:
325325
tune_request['TrainingJobDefinition']['AlgorithmSpecification']['MetricDefinitions'] = metric_definitions
326326

327+
if tags is not None:
328+
tune_request['Tags'] = tags
329+
327330
LOGGER.info('Creating hyperparameter tuning job with name: {}'.format(job_name))
328331
LOGGER.debug('tune request: {}'.format(json.dumps(tune_request, indent=4)))
329332
self.sagemaker_client.create_hyper_parameter_tuning_job(**tune_request)

src/sagemaker/tuner.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class HyperparameterTuner(object):
9191

9292
def __init__(self, estimator, objective_metric_name, hyperparameter_ranges, metric_definitions=None,
9393
strategy='Bayesian', objective_type='Maximize', max_jobs=1, max_parallel_jobs=1,
94-
base_tuning_job_name=None):
94+
tags=None, base_tuning_job_name=None):
9595
self._hyperparameter_ranges = hyperparameter_ranges
9696
if self._hyperparameter_ranges is None or len(self._hyperparameter_ranges) == 0:
9797
raise ValueError('Need to specify hyperparameter ranges')
@@ -106,6 +106,7 @@ def __init__(self, estimator, objective_metric_name, hyperparameter_ranges, metr
106106
self.max_jobs = max_jobs
107107
self.max_parallel_jobs = max_parallel_jobs
108108

109+
self.tags = tags
109110
self.base_tuning_job_name = base_tuning_job_name
110111
self._current_job_name = None
111112
self.latest_tuning_job = None
@@ -387,7 +388,7 @@ def start_new(cls, tuner, inputs):
387388
role=(config['role']), input_config=(config['input_config']),
388389
output_config=(config['output_config']),
389390
resource_config=(config['resource_config']),
390-
stop_condition=(config['stop_condition']))
391+
stop_condition=(config['stop_condition']), tags=tuner.tags)
391392

392393
return cls(tuner.sagemaker_session, tuner._current_job_name)
393394

tests/unit/test_tuner.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,13 @@ def test_fit_pca(sagemaker_session, tuner):
201201
pca.subtract_mean = True
202202
pca.extra_components = 5
203203

204+
tuner.estimator = pca
205+
206+
tags = [{'Name': 'some-tag-without-a-value'}]
207+
tuner.tags = tags
208+
204209
hyperparameter_ranges = {'num_components': IntegerParameter(2, 4),
205210
'algorithm_mode': CategoricalParameter(['regular', 'randomized'])}
206-
tuner.estimator = pca
207211
tuner._hyperparameter_ranges = hyperparameter_ranges
208212

209213
records = RecordSet(s3_data=INPUTS, num_records=1, feature_dim=1)
@@ -215,6 +219,7 @@ def test_fit_pca(sagemaker_session, tuner):
215219
assert tune_kwargs['static_hyperparameters']['extra_components'] == '5'
216220
assert len(tune_kwargs['parameter_ranges']['IntegerParameterRanges']) == 1
217221
assert tune_kwargs['job_name'].startswith('pca')
222+
assert tune_kwargs['tags'] == tags
218223
assert tuner.estimator.mini_batch_size == 9999
219224

220225

0 commit comments

Comments (0)