Skip to content

Commit 372a99a

Browse files
committed
Add tags for training jobs
1 parent afb3bbe commit 372a99a

File tree

4 files changed

+19
-6
lines changed

4 files changed

+19
-6
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ CHANGELOG
66
========
77

88
* bug-fix: Estimators: Change max_iterations hyperparameter key for KMeans
9+
* feature: Estimators: add support for tagging training jobs
910

1011
1.3.0
1112
=====

src/sagemaker/estimator.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)):
4444

4545
def __init__(self, role, train_instance_count, train_instance_type,
4646
train_volume_size=30, train_max_run=24 * 60 * 60, input_mode='File',
47-
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None):
47+
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None, tags=None):
4848
"""Initialize an ``EstimatorBase`` instance.
4949
5050
Args:
@@ -73,13 +73,16 @@ def __init__(self, role, train_instance_count, train_instance_type,
7373
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
7474
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
7575
using the default AWS configuration chain.
76+
tags (list[dict]): List of tags for labeling a training job. For more, see
77+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
7678
"""
7779
self.role = role
7880
self.train_instance_count = train_instance_count
7981
self.train_instance_type = train_instance_type
8082
self.train_volume_size = train_volume_size
8183
self.train_max_run = train_max_run
8284
self.input_mode = input_mode
85+
self.tags = tags
8386

8487
if self.train_instance_type in ('local', 'local_gpu'):
8588
if self.train_instance_type == 'local_gpu' and self.train_instance_count > 1:
@@ -345,7 +348,8 @@ def start_new(cls, estimator, inputs):
345348
estimator.sagemaker_session.train(image=estimator.train_image(), input_mode=estimator.input_mode,
346349
input_config=input_config, role=role, job_name=estimator._current_job_name,
347350
output_config=output_config, resource_config=resource_config,
348-
hyperparameters=hyperparameters, stop_condition=stop_condition)
351+
hyperparameters=hyperparameters, stop_condition=stop_condition,
352+
tags=estimator.tags)
349353

350354
return cls(estimator.sagemaker_session, estimator._current_job_name)
351355

src/sagemaker/session.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def default_bucket(self):
203203
return self._default_bucket
204204

205205
def train(self, image, input_mode, input_config, role, job_name, output_config,
206-
resource_config, hyperparameters, stop_condition):
206+
resource_config, hyperparameters, stop_condition, tags):
207207
"""Create an Amazon SageMaker training job.
208208
209209
Args:
@@ -232,6 +232,8 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
232232
keys and values, but ``str()`` will be called to convert them before training.
233233
stop_condition (dict): Defines when training shall finish. Contains entries that can be understood by the
234234
service like ``MaxRuntimeInSeconds``.
235+
tags (list[dict]): List of tags for labeling a training job. For more, see
236+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
235237
236238
Returns:
237239
str: ARN of the training job, if it is created.
@@ -242,7 +244,6 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
242244
'TrainingImage': image,
243245
'TrainingInputMode': input_mode
244246
},
245-
# 'HyperParameters': hyperparameters,
246247
'InputDataConfig': input_config,
247248
'OutputDataConfig': output_config,
248249
'TrainingJobName': job_name,
@@ -253,6 +254,10 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
253254

254255
if hyperparameters and len(hyperparameters) > 0:
255256
train_request['HyperParameters'] = hyperparameters
257+
258+
if tags is not None:
259+
train_request['Tags'] = tags
260+
256261
LOGGER.info('Creating training-job with name: {}'.format(job_name))
257262
LOGGER.debug('train request: {}'.format(json.dumps(train_request, indent=4)))
258263
self.sagemaker_client.create_training_job(**train_request)

tests/unit/test_estimator.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,16 +292,18 @@ def test_fit_then_fit_again(sagemaker_session):
292292

293293
@patch('time.strftime', return_value=TIMESTAMP)
294294
def test_fit_verify_job_name(strftime, sagemaker_session):
295+
tags = [{'Name': 'some-tag'}]
295296
fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session,
296297
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
297-
enable_cloudwatch_metrics=True)
298+
enable_cloudwatch_metrics=True, tags=tags)
298299
fw.fit(inputs=s3_input('s3://mybucket/train'))
299300

300301
_, _, train_kwargs = sagemaker_session.train.mock_calls[0]
301302

302303
assert train_kwargs['hyperparameters']['sagemaker_enable_cloudwatch_metrics']
303304
assert train_kwargs['image'] == IMAGE_NAME
304305
assert train_kwargs['input_mode'] == 'File'
306+
assert train_kwargs['tags'] == tags
305307
assert train_kwargs['job_name'] == JOB_NAME
306308
assert fw.latest_training_job.name == JOB_NAME
307309

@@ -475,7 +477,8 @@ def test_unsupported_type_in_dict():
475477
'InstanceType': INSTANCE_TYPE,
476478
'VolumeSizeInGB': 30
477479
},
478-
'stop_condition': {'MaxRuntimeInSeconds': 86400}
480+
'stop_condition': {'MaxRuntimeInSeconds': 86400},
481+
'tags': None,
479482
}
480483

481484

0 commit comments

Comments
 (0)