Skip to content

feature: emit estimator transformer tags to model #815

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 9, 2019
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,15 +500,16 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
compute instance (default: None).
"""
tags = tags or self.tags

if self.latest_training_job is not None:
model_name = self.sagemaker_session.create_model_from_job(self.latest_training_job.name, role=role)
model_name = self.sagemaker_session.create_model_from_job(self.latest_training_job.name, role=role,
tags=tags)
else:
logging.warning('No finished training job found associated with this estimator. Please make sure'
'this estimator is only used for building workflow config')
model_name = self._current_job_name

tags = tags or self.tags

return Transformer(model_name, instance_count, instance_type, strategy=strategy, assemble_with=assemble_with,
output_path=output_path, output_kms_key=output_kms_key, accept=accept,
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
Expand Down Expand Up @@ -1061,7 +1062,8 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
container_def = model.prepare_container_def(instance_type)
model_name = model.name or name_from_image(container_def['Image'])
vpc_config = model.vpc_config
self.sagemaker_session.create_model(model_name, role, container_def, vpc_config)
tags = tags or self.tags
self.sagemaker_session.create_model(model_name, role, container_def, vpc_config, tags=tags)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about defaulting to self.tags if tags is None? I think both transformer() implementations should have the same behavior around which tags are propagated.

transform_env = model.env.copy()
if env is not None:
transform_env.update(env)
Expand Down
7 changes: 5 additions & 2 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,8 @@ def create_model(self, name, role, container_defs, vpc_config=None,
return name

def create_model_from_job(self, training_job_name, name=None, role=None, primary_container_image=None,
model_data_url=None, env=None, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT):
model_data_url=None, env=None, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
tags=None):
"""Create an Amazon SageMaker ``Model`` from a SageMaker Training Job.

Args:
Expand All @@ -638,6 +639,8 @@ def create_model_from_job(self, training_job_name, name=None, role=None, primary
Default: use VpcConfig from training job.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
tags (List[dict[str, str]]): Optional. The list of tags to add to the model. For more, see
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add that this defaults to the job's tags. (This might also need to be updated in other docstrings.) Alternatively, you could default to self.tags before this method is called in the estimator.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the second option 👌🏽


Returns:
str: The name of the created ``Model``.
Expand All @@ -651,7 +654,7 @@ def create_model_from_job(self, training_job_name, name=None, role=None, primary
model_data_url=model_data_url or training_job['ModelArtifacts']['S3ModelArtifacts'],
env=env)
vpc_config = _vpc_config_from_training_job(training_job, vpc_config_override)
return self.create_model(name, role, primary_container, vpc_config=vpc_config)
return self.create_model(name, role, primary_container, vpc_config=vpc_config, tags=tags)

def create_model_package_from_algorithm(self, name, description, algorithm_arn, model_data):
"""Create a SageMaker Model Package from the results of training with an Algorithm Package
Expand Down
84 changes: 84 additions & 0 deletions tests/integ/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from sagemaker import KMeans
from sagemaker.mxnet import MXNet
from sagemaker.transformer import Transformer
from sagemaker.estimator import Estimator
from sagemaker.utils import unique_name_from_base
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES, TRANSFORM_DEFAULT_TIMEOUT_MINUTES
from tests.integ.kms_utils import get_or_create_kms_key
Expand Down Expand Up @@ -148,6 +149,89 @@ def test_transform_mxnet_vpc(sagemaker_session, mxnet_full_version):
assert [security_group_id] == model_desc['VpcConfig']['SecurityGroupIds']


def test_transform_mxnet_tags(sagemaker_session, mxnet_full_version):
    """Check that tags given to ``transformer()`` end up on the created model.

    Trains an MXNet estimator on the ``mxnet_mnist`` data, creates a
    transformer with explicit tags, starts a transform job, and compares the
    tags listed for the resulting model with the requested ones.
    """
    expected_tags = [{'Key': 'some-tag', 'Value': 'value-for-tag'}]
    mnist_dir = os.path.join(DATA_DIR, 'mxnet_mnist')

    estimator = MXNet(
        entry_point=os.path.join(mnist_dir, 'mnist.py'),
        role='SageMakerRole',
        train_instance_count=1,
        train_instance_type='ml.c4.xlarge',
        sagemaker_session=sagemaker_session,
        framework_version=mxnet_full_version,
    )

    # Upload both training channels before the job is started.
    train_s3 = estimator.sagemaker_session.upload_data(
        path=os.path.join(mnist_dir, 'train'),
        key_prefix='integ-test-data/mxnet_mnist/train')
    test_s3 = estimator.sagemaker_session.upload_data(
        path=os.path.join(mnist_dir, 'test'),
        key_prefix='integ-test-data/mxnet_mnist/test')
    training_job_name = unique_name_from_base('test-mxnet-transform')

    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
        estimator.fit({'train': train_s3, 'test': test_s3}, job_name=training_job_name)

    transform_s3 = estimator.sagemaker_session.upload_data(
        path=os.path.join(mnist_dir, 'transform', 'data.csv'),
        key_prefix='integ-test-data/mxnet_mnist/transform')

    transformer = estimator.transformer(1, 'ml.m4.xlarge', tags=expected_tags)
    transformer.transform(transform_s3, content_type='text/csv')

    with timeout_and_delete_model_with_transformer(transformer, sagemaker_session,
                                                   minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
        transformer.wait()
        # Inspect the model while still inside the context manager, which
        # handles deleting the transformer's model.
        client = sagemaker_session.sagemaker_client
        model_arn = client.describe_model(ModelName=transformer.model_name)['ModelArn']
        listed_tags = client.list_tags(ResourceArn=model_arn)['Tags']
        assert expected_tags == listed_tags


def test_transform_byo_estimator(sagemaker_session):
    """Check tag propagation for a transformer built from an attached ``Estimator``.

    Trains a KMeans job on a slice of the ``one_p_mnist`` data, re-attaches to
    that job through the generic ``Estimator``, creates a transformer with
    explicit tags, and compares the tags listed for the resulting model with
    the requested ones.
    """
    data_path = os.path.join(DATA_DIR, 'one_p_mnist')
    # The fixture is unpickled with latin1 encoding on every Python other than 2.
    pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
    tags = [{'Key': 'some-tag', 'Value': 'value-for-tag'}]

    # Load the data into memory as numpy arrays
    train_set_path = os.path.join(data_path, 'mnist.pkl.gz')
    with gzip.open(train_set_path, 'rb') as f:
        train_set, _, _ = pickle.load(f, **pickle_args)

    kmeans = KMeans(role='SageMakerRole', train_instance_count=1,
                    train_instance_type='ml.c4.xlarge', k=10, sagemaker_session=sagemaker_session,
                    output_path='s3://{}/'.format(sagemaker_session.default_bucket()))

    # set kmeans specific hp
    kmeans.init_method = 'random'
    # Was ``max_iterators``, which only set a stray attribute and never reached
    # the ``max_iterations`` hyperparameter.
    kmeans.max_iterations = 1
    kmeans.tol = 1
    kmeans.num_trials = 1
    kmeans.local_init_method = 'kmeans++'
    kmeans.half_life_time_size = 1
    kmeans.epochs = 1

    # Only the first 100 training examples are used.
    records = kmeans.record_set(train_set[0][:100])

    job_name = unique_name_from_base('test-kmeans-attach')

    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
        kmeans.fit(records, job_name=job_name)

    transform_input_path = os.path.join(data_path, 'transform_input.csv')
    transform_input_key_prefix = 'integ-test-data/one_p_mnist/transform'
    transform_input = kmeans.sagemaker_session.upload_data(path=transform_input_path,
                                                           key_prefix=transform_input_key_prefix)

    # Re-attach through the generic Estimator so the transformer is built from
    # the training job name.
    estimator = Estimator.attach(training_job_name=job_name,
                                 sagemaker_session=sagemaker_session)

    transformer = estimator.transformer(1, 'ml.m4.xlarge', tags=tags)
    transformer.transform(transform_input, content_type='text/csv')

    with timeout_and_delete_model_with_transformer(transformer, sagemaker_session,
                                                   minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
        transformer.wait()
        # Inspect the model while still inside the context manager, which
        # handles deleting the transformer's model.
        model_desc = sagemaker_session.sagemaker_client.describe_model(ModelName=transformer.model_name)
        model_tags = sagemaker_session.sagemaker_client.list_tags(ResourceArn=model_desc['ModelArn'])['Tags']
        assert tags == model_tags


def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None):
transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key)
transformer.transform(transform_input, content_type='text/csv')
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ def test_framework_transformer_creation(name_from_image, sagemaker_session):
transformer = fw.transformer(INSTANCE_COUNT, INSTANCE_TYPE)

name_from_image.assert_called_with(MODEL_IMAGE)
sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, ROLE, MODEL_CONTAINER_DEF, None)
sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, ROLE, MODEL_CONTAINER_DEF, None, tags=None)

assert isinstance(transformer, Transformer)
assert transformer.sagemaker_session == sagemaker_session
Expand Down Expand Up @@ -659,7 +659,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
volume_kms_key=kms_key, env=env, role=new_role, model_server_workers=1)

sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, new_role, MODEL_CONTAINER_DEF, vpc_config)
sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, new_role, MODEL_CONTAINER_DEF, vpc_config, tags=TAGS)
assert transformer.strategy == strategy
assert transformer.assemble_with == assemble_with
assert transformer.output_path == OUTPUT_PATH
Expand Down Expand Up @@ -698,7 +698,7 @@ def test_estimator_transformer_creation(sagemaker_session):

transformer = estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE)

sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=None)
sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=None, tags=None)
assert isinstance(transformer, Transformer)
assert transformer.sagemaker_session == sagemaker_session
assert transformer.instance_count == INSTANCE_COUNT
Expand Down Expand Up @@ -728,7 +728,7 @@ def test_estimator_transformer_creation_with_optional_params(sagemaker_session):
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
env=env, role=ROLE)

sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=ROLE)
sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=ROLE, tags=TAGS)
assert transformer.strategy == strategy
assert transformer.assemble_with == assemble_with
assert transformer.output_path == OUTPUT_PATH
Expand Down
14 changes: 14 additions & 0 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def test_s3_input_all_arguments():
}

COMPLETED_DESCRIBE_JOB_RESULT = dict(DEFAULT_EXPECTED_TRAIN_JOB_ARGS)
COMPLETED_DESCRIBE_JOB_RESULT.update({'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/' + JOB_NAME})
COMPLETED_DESCRIBE_JOB_RESULT.update({'TrainingJobStatus': 'Completed'})
COMPLETED_DESCRIBE_JOB_RESULT.update(
{'ModelArtifacts': {
Expand Down Expand Up @@ -870,6 +871,19 @@ def test_create_model_from_job(sagemaker_session):
VpcConfig=VPC_CONFIG)


def test_create_model_from_job_with_tags(sagemaker_session):
    """Verify ``create_model_from_job`` passes ``tags`` through to the ``create_model`` client call."""
    client = sagemaker_session.sagemaker_client
    client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT

    sagemaker_session.create_model_from_job(JOB_NAME, tags=TAGS)

    # The training job must have been described by name.
    describe_calls = client.describe_training_job.call_args_list
    assert call(TrainingJobName=JOB_NAME) in describe_calls

    # The model request must carry the tags alongside the usual fields.
    expected_request = {
        'ExecutionRoleArn': EXPANDED_ROLE,
        'ModelName': JOB_NAME,
        'PrimaryContainer': PRIMARY_CONTAINER,
        'VpcConfig': VPC_CONFIG,
        'Tags': TAGS,
    }
    client.create_model.assert_called_with(**expected_request)


def test_create_model_from_job_with_image(sagemaker_session):
ims = sagemaker_session
ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT
Expand Down