diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 879d04c760..6a64ecffe7 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -8,6 +8,7 @@ CHANGELOG * bug-fix: Remove __all__ and add noqa in __init__ * bug-fix: Estimators: Change max_iterations hyperparameter key for KMeans * bug-fix: Estimators: Remove unused argument job_details for ``EstimatorBase.attach()`` +* feature: Estimators: add support for tagging training jobs 1.3.0 ===== diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index bd6d3a334e..ed4581aa06 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -44,7 +44,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): def __init__(self, role, train_instance_count, train_instance_type, train_volume_size=30, train_max_run=24 * 60 * 60, input_mode='File', - output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None): + output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None, tags=None): """Initialize an ``EstimatorBase`` instance. Args: @@ -73,6 +73,8 @@ def __init__(self, role, train_instance_count, train_instance_type, sagemaker_session (sagemaker.session.Session): Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain. + tags (list[dict]): List of tags for labeling a training job. For more, see + https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. """ self.role = role self.train_instance_count = train_instance_count @@ -80,6 +82,7 @@ def __init__(self, role, train_instance_count, train_instance_type, self.train_volume_size = train_volume_size self.train_max_run = train_max_run self.input_mode = input_mode + self.tags = tags if self.train_instance_type in ('local', 'local_gpu'): if self.train_instance_type == 'local_gpu' and self.train_instance_count > 1: @@ -345,7 +348,8 @@ def start_new(cls, estimator, inputs): estimator.sagemaker_session.train(image=estimator.train_image(), input_mode=estimator.input_mode, input_config=input_config, role=role, job_name=estimator._current_job_name, output_config=output_config, resource_config=resource_config, - hyperparameters=hyperparameters, stop_condition=stop_condition) + hyperparameters=hyperparameters, stop_condition=stop_condition, + tags=estimator.tags) return cls(estimator.sagemaker_session, estimator._current_job_name) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 404af595d5..752936bdc6 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -203,7 +203,7 @@ def default_bucket(self): return self._default_bucket def train(self, image, input_mode, input_config, role, job_name, output_config, - resource_config, hyperparameters, stop_condition): + resource_config, hyperparameters, stop_condition, tags): """Create an Amazon SageMaker training job. Args: @@ -232,6 +232,8 @@ def train(self, image, input_mode, input_config, role, job_name, output_config, keys and values, but ``str()`` will be called to convert them before training. stop_condition (dict): Defines when training shall finish. Contains entries that can be understood by the service like ``MaxRuntimeInSeconds``. + tags (list[dict]): List of tags for labeling a training job. For more, see + https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. Returns: str: ARN of the training job, if it is created. @@ -242,7 +244,6 @@ def train(self, image, input_mode, input_config, role, job_name, output_config, 'TrainingImage': image, 'TrainingInputMode': input_mode }, - # 'HyperParameters': hyperparameters, 'InputDataConfig': input_config, 'OutputDataConfig': output_config, 'TrainingJobName': job_name, @@ -253,6 +254,10 @@ def train(self, image, input_mode, input_config, role, job_name, output_config, if hyperparameters and len(hyperparameters) > 0: train_request['HyperParameters'] = hyperparameters + + if tags is not None: + train_request['Tags'] = tags + LOGGER.info('Creating training-job with name: {}'.format(job_name)) LOGGER.debug('train request: {}'.format(json.dumps(train_request, indent=4))) self.sagemaker_client.create_training_job(**train_request) diff --git a/tests/unit/test_chainer.py b/tests/unit/test_chainer.py index 82bb598753..49fe59a560 100644 --- a/tests/unit/test_chainer.py +++ b/tests/unit/test_chainer.py @@ -120,7 +120,8 @@ def _create_train_job(version): }, 'stop_condition': { 'MaxRuntimeInSeconds': 24 * 60 * 60 - } + }, + 'tags': None, } diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 53da9a97a8..551a1b1386 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -292,9 +292,10 @@ def test_fit_then_fit_again(sagemaker_session): @patch('time.strftime', return_value=TIMESTAMP) def test_fit_verify_job_name(strftime, sagemaker_session): + tags = [{'Name': 'some-tag', 'Value': 'value-for-tag'}] fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session, train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - enable_cloudwatch_metrics=True) + enable_cloudwatch_metrics=True, tags=tags) fw.fit(inputs=s3_input('s3://mybucket/train')) _, _, train_kwargs = sagemaker_session.train.mock_calls[0] @@ -302,6 +303,7 @@ def test_fit_verify_job_name(strftime, sagemaker_session): assert train_kwargs['hyperparameters']['sagemaker_enable_cloudwatch_metrics'] assert train_kwargs['image'] == IMAGE_NAME assert train_kwargs['input_mode'] == 'File' + assert train_kwargs['tags'] == tags assert train_kwargs['job_name'] == JOB_NAME assert fw.latest_training_job.name == JOB_NAME @@ -475,7 +477,8 @@ def test_unsupported_type_in_dict(): 'InstanceType': INSTANCE_TYPE, 'VolumeSizeInGB': 30 }, - 'stop_condition': {'MaxRuntimeInSeconds': 86400} + 'stop_condition': {'MaxRuntimeInSeconds': 86400}, + 'tags': None, } diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index 34f76cc227..f0cea93de6 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -93,7 +93,8 @@ def _create_train_job(version): }, 'stop_condition': { 'MaxRuntimeInSeconds': 24 * 60 * 60 - } + }, + 'tags': None, } diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index de3500aff3..8f7edf61dc 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -142,7 +142,6 @@ def test_s3_input_all_arguments(): JOB_NAME = 'jobname' DEFAULT_EXPECTED_TRAIN_JOB_ARGS = { - # 'HyperParameters': None, 'OutputDataConfig': { 'S3OutputPath': S3_OUTPUT }, @@ -224,12 +223,45 @@ def test_train_pack_to_request(sagemaker_session): sagemaker_session.train(image=IMAGE, input_mode='File', input_config=in_config, role=EXPANDED_ROLE, job_name=JOB_NAME, output_config=out_config, resource_config=resource_config, - hyperparameters=None, stop_condition=stop_cond) + hyperparameters=None, stop_condition=stop_cond, tags=None) assert sagemaker_session.sagemaker_client.method_calls[0] == ( 'create_training_job', (), DEFAULT_EXPECTED_TRAIN_JOB_ARGS) +def test_train_pack_to_request_with_optional_params(sagemaker_session): + in_config = [{ + 'ChannelName': 'training', + 'DataSource': { + 'S3DataSource': { + 'S3DataDistributionType': 'FullyReplicated', + 'S3DataType': 'S3Prefix', + 'S3Uri': S3_INPUT_URI + } + } + }] + + out_config = {'S3OutputPath': S3_OUTPUT} + + resource_config = {'InstanceCount': INSTANCE_COUNT, + 'InstanceType': INSTANCE_TYPE, + 'VolumeSizeInGB': MAX_SIZE} + + stop_cond = {'MaxRuntimeInSeconds': MAX_TIME} + + hyperparameters = {'foo': 'bar'} + tags = [{'Name': 'some-tag', 'Value': 'value-for-tag'}] + + sagemaker_session.train(image=IMAGE, input_mode='File', input_config=in_config, role=EXPANDED_ROLE, + job_name=JOB_NAME, output_config=out_config, resource_config=resource_config, + hyperparameters=hyperparameters, stop_condition=stop_cond, tags=tags) + + _, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0] + + assert actual_train_args['HyperParameters'] == hyperparameters + assert actual_train_args['Tags'] == tags + + @patch('sys.stdout', new_callable=io.BytesIO if six.PY2 else io.StringIO) def test_color_wrap(bio): color_wrap = sagemaker.logs.ColorWrap() diff --git a/tests/unit/test_tf_estimator.py b/tests/unit/test_tf_estimator.py index dc4192b06e..2123eee950 100644 --- a/tests/unit/test_tf_estimator.py +++ b/tests/unit/test_tf_estimator.py @@ -101,7 +101,8 @@ def _create_train_job(tf_version): }, 'stop_condition': { 'MaxRuntimeInSeconds': 24 * 60 * 60 - } + }, + 'tags': None, }