Skip to content

Add tags for training jobs #209

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 4, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ CHANGELOG
* bug-fix: Remove __all__ and add noqa in __init__
* bug-fix: Estimators: Change max_iterations hyperparameter key for KMeans
* bug-fix: Estimators: Remove unused argument job_details for ``EstimatorBase.attach()``
* feature: Estimators: add support for tagging training jobs

1.3.0
=====
Expand Down
8 changes: 6 additions & 2 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)):

def __init__(self, role, train_instance_count, train_instance_type,
train_volume_size=30, train_max_run=24 * 60 * 60, input_mode='File',
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None):
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None, tags=None):
"""Initialize an ``EstimatorBase`` instance.

Args:
Expand Down Expand Up @@ -73,13 +73,16 @@ def __init__(self, role, train_instance_count, train_instance_type,
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
using the default AWS configuration chain.
tags (list[dict]): List of tags for labeling a training job. For more, see
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
"""
self.role = role
self.train_instance_count = train_instance_count
self.train_instance_type = train_instance_type
self.train_volume_size = train_volume_size
self.train_max_run = train_max_run
self.input_mode = input_mode
self.tags = tags

if self.train_instance_type in ('local', 'local_gpu'):
if self.train_instance_type == 'local_gpu' and self.train_instance_count > 1:
Expand Down Expand Up @@ -345,7 +348,8 @@ def start_new(cls, estimator, inputs):
estimator.sagemaker_session.train(image=estimator.train_image(), input_mode=estimator.input_mode,
input_config=input_config, role=role, job_name=estimator._current_job_name,
output_config=output_config, resource_config=resource_config,
hyperparameters=hyperparameters, stop_condition=stop_condition)
hyperparameters=hyperparameters, stop_condition=stop_condition,
tags=estimator.tags)

return cls(estimator.sagemaker_session, estimator._current_job_name)

Expand Down
9 changes: 7 additions & 2 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def default_bucket(self):
return self._default_bucket

def train(self, image, input_mode, input_config, role, job_name, output_config,
resource_config, hyperparameters, stop_condition):
resource_config, hyperparameters, stop_condition, tags):
"""Create an Amazon SageMaker training job.

Args:
Expand Down Expand Up @@ -232,6 +232,8 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
keys and values, but ``str()`` will be called to convert them before training.
stop_condition (dict): Defines when training shall finish. Contains entries that can be understood by the
service like ``MaxRuntimeInSeconds``.
tags (list[dict]): List of tags for labeling a training job. For more, see
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.

Returns:
str: ARN of the training job, if it is created.
Expand All @@ -242,7 +244,6 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
'TrainingImage': image,
'TrainingInputMode': input_mode
},
# 'HyperParameters': hyperparameters,
'InputDataConfig': input_config,
'OutputDataConfig': output_config,
'TrainingJobName': job_name,
Expand All @@ -253,6 +254,10 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,

if hyperparameters and len(hyperparameters) > 0:
train_request['HyperParameters'] = hyperparameters

if tags is not None:
train_request['Tags'] = tags

LOGGER.info('Creating training-job with name: {}'.format(job_name))
LOGGER.debug('train request: {}'.format(json.dumps(train_request, indent=4)))
self.sagemaker_client.create_training_job(**train_request)
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_chainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ def _create_train_job(version):
},
'stop_condition': {
'MaxRuntimeInSeconds': 24 * 60 * 60
}
},
'tags': None,
}


Expand Down
7 changes: 5 additions & 2 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,16 +292,18 @@ def test_fit_then_fit_again(sagemaker_session):

@patch('time.strftime', return_value=TIMESTAMP)
def test_fit_verify_job_name(strftime, sagemaker_session):
tags = [{'Name': 'some-tag', 'Value': 'value-for-tag'}]
fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session,
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
enable_cloudwatch_metrics=True)
enable_cloudwatch_metrics=True, tags=tags)
fw.fit(inputs=s3_input('s3://mybucket/train'))

_, _, train_kwargs = sagemaker_session.train.mock_calls[0]

assert train_kwargs['hyperparameters']['sagemaker_enable_cloudwatch_metrics']
assert train_kwargs['image'] == IMAGE_NAME
assert train_kwargs['input_mode'] == 'File'
assert train_kwargs['tags'] == tags
assert train_kwargs['job_name'] == JOB_NAME
assert fw.latest_training_job.name == JOB_NAME

Expand Down Expand Up @@ -475,7 +477,8 @@ def test_unsupported_type_in_dict():
'InstanceType': INSTANCE_TYPE,
'VolumeSizeInGB': 30
},
'stop_condition': {'MaxRuntimeInSeconds': 86400}
'stop_condition': {'MaxRuntimeInSeconds': 86400},
'tags': None,
}


Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ def _create_train_job(version):
},
'stop_condition': {
'MaxRuntimeInSeconds': 24 * 60 * 60
}
},
'tags': None,
}


Expand Down
36 changes: 34 additions & 2 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ def test_s3_input_all_arguments():
JOB_NAME = 'jobname'

DEFAULT_EXPECTED_TRAIN_JOB_ARGS = {
# 'HyperParameters': None,
'OutputDataConfig': {
'S3OutputPath': S3_OUTPUT
},
Expand Down Expand Up @@ -224,12 +223,45 @@ def test_train_pack_to_request(sagemaker_session):

sagemaker_session.train(image=IMAGE, input_mode='File', input_config=in_config, role=EXPANDED_ROLE,
job_name=JOB_NAME, output_config=out_config, resource_config=resource_config,
hyperparameters=None, stop_condition=stop_cond)
hyperparameters=None, stop_condition=stop_cond, tags=None)

assert sagemaker_session.sagemaker_client.method_calls[0] == (
'create_training_job', (), DEFAULT_EXPECTED_TRAIN_JOB_ARGS)


def test_train_pack_to_request_with_optional_params(sagemaker_session):
in_config = [{
'ChannelName': 'training',
'DataSource': {
'S3DataSource': {
'S3DataDistributionType': 'FullyReplicated',
'S3DataType': 'S3Prefix',
'S3Uri': S3_INPUT_URI
}
}
}]

out_config = {'S3OutputPath': S3_OUTPUT}

resource_config = {'InstanceCount': INSTANCE_COUNT,
'InstanceType': INSTANCE_TYPE,
'VolumeSizeInGB': MAX_SIZE}

stop_cond = {'MaxRuntimeInSeconds': MAX_TIME}

hyperparameters = {'foo': 'bar'}
tags = [{'Name': 'some-tag', 'Value': 'value-for-tag'}]

sagemaker_session.train(image=IMAGE, input_mode='File', input_config=in_config, role=EXPANDED_ROLE,
job_name=JOB_NAME, output_config=out_config, resource_config=resource_config,
hyperparameters=hyperparameters, stop_condition=stop_cond, tags=tags)

_, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0]

assert actual_train_args['HyperParameters'] == hyperparameters
assert actual_train_args['Tags'] == tags


@patch('sys.stdout', new_callable=io.BytesIO if six.PY2 else io.StringIO)
def test_color_wrap(bio):
color_wrap = sagemaker.logs.ColorWrap()
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_tf_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ def _create_train_job(tf_version):
},
'stop_condition': {
'MaxRuntimeInSeconds': 24 * 60 * 60
}
},
'tags': None,
}


Expand Down