Skip to content

Commit eaf4612

Browse files
committed
Add integration test for estimator transformer
1 parent 8feff4a commit eaf4612

File tree

5 files changed

+74
-9
lines changed

5 files changed

+74
-9
lines changed

src/sagemaker/estimator.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -500,15 +500,16 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
500500
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
501501
compute instance (default: None).
502502
"""
503+
tags = tags or self.tags
504+
503505
if self.latest_training_job is not None:
504-
model_name = self.sagemaker_session.create_model_from_job(self.latest_training_job.name, role=role)
506+
model_name = self.sagemaker_session.create_model_from_job(self.latest_training_job.name, role=role,
507+
tags=tags)
505508
else:
506509
logging.warning('No finished training job found associated with this estimator. Please make sure'
507510
'this estimator is only used for building workflow config')
508511
model_name = self._current_job_name
509512

510-
tags = tags or self.tags
511-
512513
return Transformer(model_name, instance_count, instance_type, strategy=strategy, assemble_with=assemble_with,
513514
output_path=output_path, output_kms_key=output_kms_key, accept=accept,
514515
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
@@ -1061,6 +1062,7 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
10611062
container_def = model.prepare_container_def(instance_type)
10621063
model_name = model.name or name_from_image(container_def['Image'])
10631064
vpc_config = model.vpc_config
1065+
tags = tags or self.tags
10641066
self.sagemaker_session.create_model(model_name, role, container_def, vpc_config, tags=tags)
10651067
transform_env = model.env.copy()
10661068
if env is not None:

src/sagemaker/session.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,8 @@ def create_model(self, name, role, container_defs, vpc_config=None,
620620
return name
621621

622622
def create_model_from_job(self, training_job_name, name=None, role=None, primary_container_image=None,
623-
model_data_url=None, env=None, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT):
623+
model_data_url=None, env=None, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
624+
tags=None):
624625
"""Create an Amazon SageMaker ``Model`` from a SageMaker Training Job.
625626
626627
Args:
@@ -638,12 +639,13 @@ def create_model_from_job(self, training_job_name, name=None, role=None, primary
638639
Default: use VpcConfig from training job.
639640
* 'Subnets' (list[str]): List of subnet ids.
640641
* 'SecurityGroupIds' (list[str]): List of security group ids.
642+
tags(List[dict[str, str]]): Optional. The list of tags to add to the model. For more, see
643+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
641644
642645
Returns:
643646
str: The name of the created ``Model``.
644647
"""
645648
training_job = self.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)
646-
tags = self.sagemaker_client.list_tags(ResourceArn=training_job['TrainingJobArn'])['Tags']
647649
name = name or training_job_name
648650
role = role or training_job['RoleArn']
649651
env = env or {}

tests/integ/test_transformer.py

+51-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from sagemaker import KMeans
2323
from sagemaker.mxnet import MXNet
2424
from sagemaker.transformer import Transformer
25+
from sagemaker.estimator import Estimator
2526
from sagemaker.utils import unique_name_from_base
2627
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES, TRANSFORM_DEFAULT_TIMEOUT_MINUTES
2728
from tests.integ.kms_utils import get_or_create_kms_key
@@ -148,7 +149,7 @@ def test_transform_mxnet_vpc(sagemaker_session, mxnet_full_version):
148149
assert [security_group_id] == model_desc['VpcConfig']['SecurityGroupIds']
149150

150151

151-
def test_transform_mxnet_logs(sagemaker_session, mxnet_full_version):
152+
def test_transform_mxnet_tags(sagemaker_session, mxnet_full_version):
152153
data_path = os.path.join(DATA_DIR, 'mxnet_mnist')
153154
script_path = os.path.join(data_path, 'mnist.py')
154155
tags = [{'Key': 'some-tag', 'Value': 'value-for-tag'}]
@@ -182,6 +183,55 @@ def test_transform_mxnet_logs(sagemaker_session, mxnet_full_version):
182183
assert tags == model_tags
183184

184185

186+
def test_transform_byo_estimator(sagemaker_session):
187+
data_path = os.path.join(DATA_DIR, 'one_p_mnist')
188+
pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
189+
tags = [{'Key': 'some-tag', 'Value': 'value-for-tag'}]
190+
191+
# Load the data into memory as numpy arrays
192+
train_set_path = os.path.join(data_path, 'mnist.pkl.gz')
193+
with gzip.open(train_set_path, 'rb') as f:
194+
train_set, _, _ = pickle.load(f, **pickle_args)
195+
196+
kmeans = KMeans(role='SageMakerRole', train_instance_count=1,
197+
train_instance_type='ml.c4.xlarge', k=10, sagemaker_session=sagemaker_session,
198+
output_path='s3://{}/'.format(sagemaker_session.default_bucket()))
199+
200+
# set kmeans specific hp
201+
kmeans.init_method = 'random'
202+
kmeans.max_iterators = 1
203+
kmeans.tol = 1
204+
kmeans.num_trials = 1
205+
kmeans.local_init_method = 'kmeans++'
206+
kmeans.half_life_time_size = 1
207+
kmeans.epochs = 1
208+
209+
records = kmeans.record_set(train_set[0][:100])
210+
211+
job_name = unique_name_from_base('test-kmeans-attach')
212+
213+
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
214+
kmeans.fit(records, job_name=job_name)
215+
216+
transform_input_path = os.path.join(data_path, 'transform_input.csv')
217+
transform_input_key_prefix = 'integ-test-data/one_p_mnist/transform'
218+
transform_input = kmeans.sagemaker_session.upload_data(path=transform_input_path,
219+
key_prefix=transform_input_key_prefix)
220+
221+
estimator = Estimator.attach(training_job_name=job_name,
222+
sagemaker_session=sagemaker_session)
223+
224+
transformer = estimator.transformer(1, 'ml.m4.xlarge', tags=tags)
225+
transformer.transform(transform_input, content_type='text/csv')
226+
227+
with timeout_and_delete_model_with_transformer(transformer, sagemaker_session,
228+
minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
229+
transformer.wait()
230+
model_desc = sagemaker_session.sagemaker_client.describe_model(ModelName=transformer.model_name)
231+
model_tags = sagemaker_session.sagemaker_client.list_tags(ResourceArn=model_desc['ModelArn'])['Tags']
232+
assert tags == model_tags
233+
234+
185235
def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None):
186236
transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key)
187237
transformer.transform(transform_input, content_type='text/csv')

tests/unit/test_estimator.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -698,7 +698,7 @@ def test_estimator_transformer_creation(sagemaker_session):
698698

699699
transformer = estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE)
700700

701-
sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=None)
701+
sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=None, tags=None)
702702
assert isinstance(transformer, Transformer)
703703
assert transformer.sagemaker_session == sagemaker_session
704704
assert transformer.instance_count == INSTANCE_COUNT
@@ -728,7 +728,7 @@ def test_estimator_transformer_creation_with_optional_params(sagemaker_session):
728728
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
729729
env=env, role=ROLE)
730730

731-
sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=ROLE)
731+
sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=ROLE, tags=TAGS)
732732
assert transformer.strategy == strategy
733733
assert transformer.assemble_with == assemble_with
734734
assert transformer.output_path == OUTPUT_PATH

tests/unit/test_session.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -862,9 +862,20 @@ def test_create_model_failure(expand_container_def, sagemaker_session):
862862
def test_create_model_from_job(sagemaker_session):
863863
ims = sagemaker_session
864864
ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT
865-
ims.sagemaker_client.list_tags.return_value = {'Tags': TAGS}
866865
ims.create_model_from_job(JOB_NAME)
867866

867+
assert call(TrainingJobName=JOB_NAME) in ims.sagemaker_client.describe_training_job.call_args_list
868+
ims.sagemaker_client.create_model.assert_called_with(ExecutionRoleArn=EXPANDED_ROLE,
869+
ModelName=JOB_NAME,
870+
PrimaryContainer=PRIMARY_CONTAINER,
871+
VpcConfig=VPC_CONFIG)
872+
873+
874+
def test_create_model_from_job_with_tags(sagemaker_session):
875+
ims = sagemaker_session
876+
ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT
877+
ims.create_model_from_job(JOB_NAME, tags=TAGS)
878+
868879
assert call(TrainingJobName=JOB_NAME) in ims.sagemaker_client.describe_training_job.call_args_list
869880
ims.sagemaker_client.create_model.assert_called_with(ExecutionRoleArn=EXPANDED_ROLE,
870881
ModelName=JOB_NAME,

0 commit comments

Comments
 (0)