diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 6ff4d7b06e..12f7176907 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -500,15 +500,16 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML compute instance (default: None). """ + tags = tags or self.tags + if self.latest_training_job is not None: - model_name = self.sagemaker_session.create_model_from_job(self.latest_training_job.name, role=role) + model_name = self.sagemaker_session.create_model_from_job(self.latest_training_job.name, role=role, + tags=tags) else: logging.warning('No finished training job found associated with this estimator. Please make sure' 'this estimator is only used for building workflow config') model_name = self._current_job_name - tags = tags or self.tags - return Transformer(model_name, instance_count, instance_type, strategy=strategy, assemble_with=assemble_with, output_path=output_path, output_kms_key=output_kms_key, accept=accept, max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload, @@ -1061,7 +1062,8 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit container_def = model.prepare_container_def(instance_type) model_name = model.name or name_from_image(container_def['Image']) vpc_config = model.vpc_config - self.sagemaker_session.create_model(model_name, role, container_def, vpc_config) + tags = tags or self.tags + self.sagemaker_session.create_model(model_name, role, container_def, vpc_config, tags=tags) transform_env = model.env.copy() if env is not None: transform_env.update(env) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index f5ff344edc..9d958810d8 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -624,7 +624,8 @@ def create_model(self, name, role, container_defs, vpc_config=None, return name def create_model_from_job(self, training_job_name, name=None, role=None, primary_container_image=None, - model_data_url=None, env=None, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT): + model_data_url=None, env=None, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT, + tags=None): """Create an Amazon SageMaker ``Model`` from a SageMaker Training Job. Args: @@ -642,6 +643,8 @@ def create_model_from_job(self, training_job_name, name=None, role=None, primary Default: use VpcConfig from training job. * 'Subnets' (list[str]): List of subnet ids. * 'SecurityGroupIds' (list[str]): List of security group ids. + tags(List[dict[str, str]]): Optional. The list of tags to add to the model. For more, see + https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. Returns: str: The name of the created ``Model``. @@ -655,7 +658,7 @@ def create_model_from_job(self, training_job_name, name=None, role=None, primary model_data_url=model_data_url or training_job['ModelArtifacts']['S3ModelArtifacts'], env=env) vpc_config = _vpc_config_from_training_job(training_job, vpc_config_override) - return self.create_model(name, role, primary_container, vpc_config=vpc_config) + return self.create_model(name, role, primary_container, vpc_config=vpc_config, tags=tags) def create_model_package_from_algorithm(self, name, description, algorithm_arn, model_data): """Create a SageMaker Model Package from the results of training with an Algorithm Package diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py index 3d121fbb5a..d47e0f7373 100644 --- a/tests/integ/test_transformer.py +++ b/tests/integ/test_transformer.py @@ -22,6 +22,7 @@ from sagemaker import KMeans from sagemaker.mxnet import MXNet from sagemaker.transformer import Transformer +from sagemaker.estimator import Estimator from sagemaker.utils import unique_name_from_base from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES, TRANSFORM_DEFAULT_TIMEOUT_MINUTES from tests.integ.kms_utils import get_or_create_kms_key @@ -148,6 +149,89 @@ def test_transform_mxnet_vpc(sagemaker_session, mxnet_full_version): assert [security_group_id] == model_desc['VpcConfig']['SecurityGroupIds'] +def test_transform_mxnet_tags(sagemaker_session, mxnet_full_version): + data_path = os.path.join(DATA_DIR, 'mxnet_mnist') + script_path = os.path.join(data_path, 'mnist.py') + tags = [{'Key': 'some-tag', 'Value': 'value-for-tag'}] + + mx = MXNet(entry_point=script_path, role='SageMakerRole', train_instance_count=1, + train_instance_type='ml.c4.xlarge', sagemaker_session=sagemaker_session, + framework_version=mxnet_full_version) + + train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'), + key_prefix='integ-test-data/mxnet_mnist/train') + test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'), + key_prefix='integ-test-data/mxnet_mnist/test') + job_name = unique_name_from_base('test-mxnet-transform') + + with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): + mx.fit({'train': train_input, 'test': test_input}, job_name=job_name) + + transform_input_path = os.path.join(data_path, 'transform', 'data.csv') + transform_input_key_prefix = 'integ-test-data/mxnet_mnist/transform' + transform_input = mx.sagemaker_session.upload_data(path=transform_input_path, + key_prefix=transform_input_key_prefix) + + transformer = mx.transformer(1, 'ml.m4.xlarge', tags=tags) + transformer.transform(transform_input, content_type='text/csv') + + with timeout_and_delete_model_with_transformer(transformer, sagemaker_session, + minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES): + transformer.wait() + model_desc = sagemaker_session.sagemaker_client.describe_model(ModelName=transformer.model_name) + model_tags = sagemaker_session.sagemaker_client.list_tags(ResourceArn=model_desc['ModelArn'])['Tags'] + assert tags == model_tags + + +def test_transform_byo_estimator(sagemaker_session): + data_path = os.path.join(DATA_DIR, 'one_p_mnist') + pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'} + tags = [{'Key': 'some-tag', 'Value': 'value-for-tag'}] + + # Load the data into memory as numpy arrays + train_set_path = os.path.join(data_path, 'mnist.pkl.gz') + with gzip.open(train_set_path, 'rb') as f: + train_set, _, _ = pickle.load(f, **pickle_args) + + kmeans = KMeans(role='SageMakerRole', train_instance_count=1, + train_instance_type='ml.c4.xlarge', k=10, sagemaker_session=sagemaker_session, + output_path='s3://{}/'.format(sagemaker_session.default_bucket())) + + # set kmeans specific hp + kmeans.init_method = 'random' + kmeans.max_iterators = 1 + kmeans.tol = 1 + kmeans.num_trials = 1 + kmeans.local_init_method = 'kmeans++' + kmeans.half_life_time_size = 1 + kmeans.epochs = 1 + + records = kmeans.record_set(train_set[0][:100]) + + job_name = unique_name_from_base('test-kmeans-attach') + + with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): + kmeans.fit(records, job_name=job_name) + + transform_input_path = os.path.join(data_path, 'transform_input.csv') + transform_input_key_prefix = 'integ-test-data/one_p_mnist/transform' + transform_input = kmeans.sagemaker_session.upload_data(path=transform_input_path, + key_prefix=transform_input_key_prefix) + + estimator = Estimator.attach(training_job_name=job_name, + sagemaker_session=sagemaker_session) + + transformer = estimator.transformer(1, 'ml.m4.xlarge', tags=tags) + transformer.transform(transform_input, content_type='text/csv') + + with timeout_and_delete_model_with_transformer(transformer, sagemaker_session, + minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES): + transformer.wait() + model_desc = sagemaker_session.sagemaker_client.describe_model(ModelName=transformer.model_name) + model_tags = sagemaker_session.sagemaker_client.list_tags(ResourceArn=model_desc['ModelArn'])['Tags'] + assert tags == model_tags + + def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None): transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key) transformer.transform(transform_input, content_type='text/csv') diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index c5d7a9f426..1121659973 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -624,7 +624,7 @@ def test_framework_transformer_creation(name_from_image, sagemaker_session): transformer = fw.transformer(INSTANCE_COUNT, INSTANCE_TYPE) name_from_image.assert_called_with(MODEL_IMAGE) - sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, ROLE, MODEL_CONTAINER_DEF, None) + sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, ROLE, MODEL_CONTAINER_DEF, None, tags=None) assert isinstance(transformer, Transformer) assert transformer.sagemaker_session == sagemaker_session @@ -659,7 +659,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload, volume_kms_key=kms_key, env=env, role=new_role, model_server_workers=1) - sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, new_role, MODEL_CONTAINER_DEF, vpc_config) + sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, new_role, MODEL_CONTAINER_DEF, vpc_config, tags=TAGS) assert transformer.strategy == strategy assert transformer.assemble_with == assemble_with assert transformer.output_path == OUTPUT_PATH @@ -698,7 +698,7 @@ def test_estimator_transformer_creation(sagemaker_session): transformer = estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE) - sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=None) + sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=None, tags=None) assert isinstance(transformer, Transformer) assert transformer.sagemaker_session == sagemaker_session assert transformer.instance_count == INSTANCE_COUNT @@ -728,7 +728,7 @@ def test_estimator_transformer_creation_with_optional_params(sagemaker_session): max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload, env=env, role=ROLE) - sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=ROLE) + sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=ROLE, tags=TAGS) assert transformer.strategy == strategy assert transformer.assemble_with == assemble_with assert transformer.output_path == OUTPUT_PATH diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 4f34bde068..48250beda5 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -291,6 +291,7 @@ def test_s3_input_all_arguments(): } COMPLETED_DESCRIBE_JOB_RESULT = dict(DEFAULT_EXPECTED_TRAIN_JOB_ARGS) +COMPLETED_DESCRIBE_JOB_RESULT.update({'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/' + JOB_NAME}) COMPLETED_DESCRIBE_JOB_RESULT.update({'TrainingJobStatus': 'Completed'}) COMPLETED_DESCRIBE_JOB_RESULT.update( {'ModelArtifacts': { @@ -870,6 +871,19 @@ def test_create_model_from_job(sagemaker_session): VpcConfig=VPC_CONFIG) +def test_create_model_from_job_with_tags(sagemaker_session): + ims = sagemaker_session + ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT + ims.create_model_from_job(JOB_NAME, tags=TAGS) + + assert call(TrainingJobName=JOB_NAME) in ims.sagemaker_client.describe_training_job.call_args_list + ims.sagemaker_client.create_model.assert_called_with(ExecutionRoleArn=EXPANDED_ROLE, + ModelName=JOB_NAME, + PrimaryContainer=PRIMARY_CONTAINER, + VpcConfig=VPC_CONFIG, + Tags=TAGS) + + def test_create_model_from_job_with_image(sagemaker_session): ims = sagemaker_session ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT