Skip to content

Commit 2a04096

Browse files
Merge branch 'master' into hyperparameter-tuning-support
2 parents 5c52054 + 731641c commit 2a04096

File tree

8 files changed

+58
-10
lines changed

8 files changed

+58
-10
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ CHANGELOG
1111
* bug-fix: Local Mode: Show logs in Jupyter notebooks
1212
* feature: HyperparameterTuner: Add support for hyperparameter tuning jobs
1313
* feature: Analytics: Add functions for metrics in Training and Hyperparameter Tuning jobs
14+
* feature: Estimators: Add support for tagging training jobs
1415

1516
1.3.0
1617
=====

src/sagemaker/estimator.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)):
4444

4545
def __init__(self, role, train_instance_count, train_instance_type,
4646
train_volume_size=30, train_max_run=24 * 60 * 60, input_mode='File',
47-
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None):
47+
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None, tags=None):
4848
"""Initialize an ``EstimatorBase`` instance.
4949
5050
Args:
@@ -73,13 +73,16 @@ def __init__(self, role, train_instance_count, train_instance_type,
7373
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
7474
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
7575
using the default AWS configuration chain.
76+
tags (list[dict]): List of tags for labeling a training job. For more information, see
77+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
7678
"""
7779
self.role = role
7880
self.train_instance_count = train_instance_count
7981
self.train_instance_type = train_instance_type
8082
self.train_volume_size = train_volume_size
8183
self.train_max_run = train_max_run
8284
self.input_mode = input_mode
85+
self.tags = tags
8386

8487
if self.train_instance_type in ('local', 'local_gpu'):
8588
if self.train_instance_type == 'local_gpu' and self.train_instance_count > 1:
@@ -356,7 +359,7 @@ def start_new(cls, estimator, inputs):
356359
input_config=config['input_config'], role=config['role'],
357360
job_name=estimator._current_job_name, output_config=config['output_config'],
358361
resource_config=config['resource_config'], hyperparameters=hyperparameters,
359-
stop_condition=config['stop_condition'])
362+
stop_condition=config['stop_condition'], tags=estimator.tags)
360363

361364
return cls(estimator.sagemaker_session, estimator._current_job_name)
362365

src/sagemaker/session.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def default_bucket(self):
202202
return self._default_bucket
203203

204204
def train(self, image, input_mode, input_config, role, job_name, output_config,
205-
resource_config, hyperparameters, stop_condition):
205+
resource_config, hyperparameters, stop_condition, tags):
206206
"""Create an Amazon SageMaker training job.
207207
208208
Args:
@@ -231,6 +231,8 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
231231
keys and values, but ``str()`` will be called to convert them before training.
232232
stop_condition (dict): Defines when training shall finish. Contains entries that can be understood by the
233233
service like ``MaxRuntimeInSeconds``.
234+
tags (list[dict]): List of tags for labeling a training job. For more information, see
235+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
234236
235237
Returns:
236238
str: ARN of the training job, if it is created.
@@ -251,6 +253,10 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
251253

252254
if hyperparameters and len(hyperparameters) > 0:
253255
train_request['HyperParameters'] = hyperparameters
256+
257+
if tags is not None:
258+
train_request['Tags'] = tags
259+
254260
LOGGER.info('Creating training-job with name: {}'.format(job_name))
255261
LOGGER.debug('train request: {}'.format(json.dumps(train_request, indent=4)))
256262
self.sagemaker_client.create_training_job(**train_request)

tests/unit/test_chainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ def _create_train_job(version):
120120
},
121121
'stop_condition': {
122122
'MaxRuntimeInSeconds': 24 * 60 * 60
123-
}
123+
},
124+
'tags': None,
124125
}
125126

126127

tests/unit/test_estimator.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -345,16 +345,18 @@ def test_attach_framework_with_tuning(sagemaker_session):
345345

346346
@patch('time.strftime', return_value=TIMESTAMP)
347347
def test_fit_verify_job_name(strftime, sagemaker_session):
348+
tags = [{'Name': 'some-tag', 'Value': 'value-for-tag'}]
348349
fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session,
349350
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
350-
enable_cloudwatch_metrics=True)
351+
enable_cloudwatch_metrics=True, tags=tags)
351352
fw.fit(inputs=s3_input('s3://mybucket/train'))
352353

353354
_, _, train_kwargs = sagemaker_session.train.mock_calls[0]
354355

355356
assert train_kwargs['hyperparameters']['sagemaker_enable_cloudwatch_metrics']
356357
assert train_kwargs['image'] == IMAGE_NAME
357358
assert train_kwargs['input_mode'] == 'File'
359+
assert train_kwargs['tags'] == tags
358360
assert train_kwargs['job_name'] == JOB_NAME
359361
assert fw.latest_training_job.name == JOB_NAME
360362

@@ -494,7 +496,8 @@ def test_unsupported_type_in_dict():
494496
'InstanceType': INSTANCE_TYPE,
495497
'VolumeSizeInGB': 30
496498
},
497-
'stop_condition': {'MaxRuntimeInSeconds': 86400}
499+
'stop_condition': {'MaxRuntimeInSeconds': 86400},
500+
'tags': None,
498501
}
499502

500503
HYPERPARAMS = {'x': 1, 'y': 'hello'}

tests/unit/test_mxnet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ def _create_train_job(version):
9393
},
9494
'stop_condition': {
9595
'MaxRuntimeInSeconds': 24 * 60 * 60
96-
}
96+
},
97+
'tags': None,
9798
}
9899

99100

tests/unit/test_session.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,6 @@ def test_s3_input_all_arguments():
144144
JOB_NAME = 'jobname'
145145

146146
DEFAULT_EXPECTED_TRAIN_JOB_ARGS = {
147-
# 'HyperParameters': None,
148147
'OutputDataConfig': {
149148
'S3OutputPath': S3_OUTPUT
150149
},
@@ -226,7 +225,7 @@ def test_train_pack_to_request(sagemaker_session):
226225

227226
sagemaker_session.train(image=IMAGE, input_mode='File', input_config=in_config, role=EXPANDED_ROLE,
228227
job_name=JOB_NAME, output_config=out_config, resource_config=resource_config,
229-
hyperparameters=None, stop_condition=stop_cond)
228+
hyperparameters=None, stop_condition=stop_cond, tags=None)
230229

231230
assert sagemaker_session.sagemaker_client.method_calls[0] == (
232231
'create_training_job', (), DEFAULT_EXPECTED_TRAIN_JOB_ARGS)
@@ -266,6 +265,39 @@ def test_stop_tuning_job_client_error(sagemaker_session):
266265
assert 'An error occurred (MockException) when calling the Operation operation: MockMessage' in str(e)
267266

268267

268+
def test_train_pack_to_request_with_optional_params(sagemaker_session):
269+
in_config = [{
270+
'ChannelName': 'training',
271+
'DataSource': {
272+
'S3DataSource': {
273+
'S3DataDistributionType': 'FullyReplicated',
274+
'S3DataType': 'S3Prefix',
275+
'S3Uri': S3_INPUT_URI
276+
}
277+
}
278+
}]
279+
280+
out_config = {'S3OutputPath': S3_OUTPUT}
281+
282+
resource_config = {'InstanceCount': INSTANCE_COUNT,
283+
'InstanceType': INSTANCE_TYPE,
284+
'VolumeSizeInGB': MAX_SIZE}
285+
286+
stop_cond = {'MaxRuntimeInSeconds': MAX_TIME}
287+
288+
hyperparameters = {'foo': 'bar'}
289+
tags = [{'Name': 'some-tag', 'Value': 'value-for-tag'}]
290+
291+
sagemaker_session.train(image=IMAGE, input_mode='File', input_config=in_config, role=EXPANDED_ROLE,
292+
job_name=JOB_NAME, output_config=out_config, resource_config=resource_config,
293+
hyperparameters=hyperparameters, stop_condition=stop_cond, tags=tags)
294+
295+
_, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0]
296+
297+
assert actual_train_args['HyperParameters'] == hyperparameters
298+
assert actual_train_args['Tags'] == tags
299+
300+
269301
@patch('sys.stdout', new_callable=io.BytesIO if six.PY2 else io.StringIO)
270302
def test_color_wrap(bio):
271303
color_wrap = sagemaker.logs.ColorWrap()

tests/unit/test_tf_estimator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ def _create_train_job(tf_version):
101101
},
102102
'stop_condition': {
103103
'MaxRuntimeInSeconds': 24 * 60 * 60
104-
}
104+
},
105+
'tags': None,
105106
}
106107

107108

0 commit comments

Comments
 (0)