Skip to content

Commit 3253466

Browse files
laurenyuandremoeller
authored andcommitted
Allow Model and Transformer to use a different role from the Estimator (#308)
* Allow Model and Transformer to use a different role from Estimator * update changelog * Add new arg to the end to be non-breaking
1 parent b943baa commit 3253466

File tree

11 files changed

+123
-23
lines changed

11 files changed

+123
-23
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ CHANGELOG
77

88
* bug-fix: get_execution_role no longer fails if user can't call get_role
99
* bug-fix: Session: use existing model instead of failing during ``create_model()``
10+
* enhancement: Estimator: allow for different role from the Estimator's when creating a Model or Transformer
1011

1112
1.7.0
1213
=====

src/sagemaker/chainer/estimator.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,18 +98,21 @@ def hyperparameters(self):
9898
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
9999
return hyperparameters
100100

101-
def create_model(self, model_server_workers=None):
101+
def create_model(self, model_server_workers=None, role=None):
102102
"""Create a SageMaker ``ChainerModel`` object that can be deployed to an ``Endpoint``.
103103
104104
Args:
105+
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
106+
transform jobs. If not specified, the role from the Estimator will be used.
105107
model_server_workers (int): Optional. The number of worker processes used by the inference server.
106108
If None, server will use one worker per vCPU.
107109
108110
Returns:
109111
sagemaker.chainer.model.ChainerModel: A SageMaker ``ChainerModel`` object.
110112
See :func:`~sagemaker.chainer.model.ChainerModel` for full details.
111113
"""
112-
return ChainerModel(self.model_data, self.role, self.entry_point, source_dir=self._model_source_dir(),
114+
role = role or self.role
115+
return ChainerModel(self.model_data, role, self.entry_point, source_dir=self._model_source_dir(),
113116
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
114117
container_log_level=self.container_log_level, code_location=self.code_location,
115118
py_version=self.py_version, framework_version=self.framework_version,

src/sagemaker/estimator.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def delete_endpoint(self):
319319

320320
def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None,
321321
output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None,
322-
max_payload=None, tags=None):
322+
max_payload=None, tags=None, role=None):
323323
"""Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the
324324
SageMaker Session and base job name used by the Estimator.
325325
@@ -339,10 +339,12 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
339339
max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB.
340340
tags (list[dict]): List of tags for labeling a transform job. If none specified, then the tags used for
341341
the training job are used for the transform job.
342+
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
343+
transform jobs. If not specified, the role from the Estimator will be used.
342344
"""
343345
self._ensure_latest_training_job()
344346

345-
model_name = self.sagemaker_session.create_model_from_job(self.latest_training_job.name)
347+
model_name = self.sagemaker_session.create_model_from_job(self.latest_training_job.name, role=role)
346348
tags = tags or self.tags
347349

348350
return Transformer(model_name, instance_count, instance_type, strategy=strategy, assemble_with=assemble_with,
@@ -476,12 +478,14 @@ def hyperparameters(self):
476478
"""
477479
return self.hyperparam_dict
478480

479-
def create_model(self, image=None, predictor_cls=None, serializer=None, deserializer=None,
481+
def create_model(self, role=None, image=None, predictor_cls=None, serializer=None, deserializer=None,
480482
content_type=None, accept=None, **kwargs):
481483
"""
482484
Create a model to deploy.
483485
484486
Args:
487+
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
488+
transform jobs. If not specified, the role from the Estimator will be used.
485489
image (str): An container image to use for deploying the model. Defaults to the image used for training.
486490
predictor_cls (RealTimePredictor): The predictor class to use when deploying the model.
487491
serializer (callable): Should accept a single argument, the input data, and return a sequence
@@ -503,7 +507,9 @@ def predict_wrapper(endpoint, session):
503507
return RealTimePredictor(endpoint, session, serializer, deserializer, content_type, accept)
504508
predictor_cls = predict_wrapper
505509

506-
return Model(self.model_data, image or self.train_image(), self.role, sagemaker_session=self.sagemaker_session,
510+
role = role or self.role
511+
512+
return Model(self.model_data, image or self.train_image(), role, sagemaker_session=self.sagemaker_session,
507513
predictor_cls=predictor_cls, **kwargs)
508514

509515
@classmethod
@@ -737,7 +743,7 @@ def _update_init_params(cls, hp, tf_arguments):
737743

738744
def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None,
739745
output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None,
740-
max_payload=None, tags=None, model_server_workers=None):
746+
max_payload=None, tags=None, role=None, model_server_workers=None):
741747
"""Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the
742748
SageMaker Session and base job name used by the Estimator.
743749
@@ -757,16 +763,19 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
757763
max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB.
758764
tags (list[dict]): List of tags for labeling a transform job. If none specified, then the tags used for
759765
the training job are used for the transform job.
766+
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
767+
transform jobs. If not specified, the role from the Estimator will be used.
760768
model_server_workers (int): Optional. The number of worker processes used by the inference server.
761769
If None, server will use one worker per vCPU.
762770
"""
763771
self._ensure_latest_training_job()
772+
role = role or self.role
764773

765-
model = self.create_model(model_server_workers=model_server_workers)
774+
model = self.create_model(role=role, model_server_workers=model_server_workers)
766775

767776
container_def = model.prepare_container_def(instance_type)
768777
model_name = model.name or name_from_image(container_def['Image'])
769-
self.sagemaker_session.create_model(model_name, self.role, container_def)
778+
self.sagemaker_session.create_model(model_name, role, container_def)
770779

771780
transform_env = model.env.copy()
772781
if env is not None:

src/sagemaker/mxnet/estimator.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,18 +65,21 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_versio
6565
self.py_version = py_version
6666
self.framework_version = framework_version
6767

68-
def create_model(self, model_server_workers=None):
68+
def create_model(self, model_server_workers=None, role=None):
6969
"""Create a SageMaker ``MXNetModel`` object that can be deployed to an ``Endpoint``.
7070
7171
Args:
72+
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
73+
transform jobs. If not specified, the role from the Estimator will be used.
7274
model_server_workers (int): Optional. The number of worker processes used by the inference server.
7375
If None, server will use one worker per vCPU.
7476
7577
Returns:
7678
sagemaker.mxnet.model.MXNetModel: A SageMaker ``MXNetModel`` object.
7779
See :func:`~sagemaker.mxnet.model.MXNetModel` for full details.
7880
"""
79-
return MXNetModel(self.model_data, self.role, self.entry_point, source_dir=self._model_source_dir(),
81+
role = role or self.role
82+
return MXNetModel(self.model_data, role, self.entry_point, source_dir=self._model_source_dir(),
8083
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
8184
container_log_level=self.container_log_level, code_location=self.code_location,
8285
py_version=self.py_version, framework_version=self.framework_version, image=self.image_name,

src/sagemaker/pytorch/estimator.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,21 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_versio
6363
self.py_version = py_version
6464
self.framework_version = framework_version
6565

66-
def create_model(self, model_server_workers=None):
66+
def create_model(self, model_server_workers=None, role=None):
6767
"""Create a SageMaker ``PyTorchModel`` object that can be deployed to an ``Endpoint``.
6868
6969
Args:
70+
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
71+
transform jobs. If not specified, the role from the Estimator will be used.
7072
model_server_workers (int): Optional. The number of worker processes used by the inference server.
7173
If None, server will use one worker per vCPU.
7274
7375
Returns:
7476
sagemaker.pytorch.model.PyTorchModel: A SageMaker ``PyTorchModel`` object.
7577
See :func:`~sagemaker.pytorch.model.PyTorchModel` for full details.
7678
"""
77-
return PyTorchModel(self.model_data, self.role, self.entry_point, source_dir=self._model_source_dir(),
79+
role = role or self.role
80+
return PyTorchModel(self.model_data, role, self.entry_point, source_dir=self._model_source_dir(),
7881
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
7982
container_log_level=self.container_log_level, code_location=self.code_location,
8083
py_version=self.py_version, framework_version=self.framework_version, image=self.image_name,

src/sagemaker/tensorflow/estimator.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,10 +289,12 @@ def _prepare_init_params_from_job_description(cls, job_details):
289289

290290
return init_params
291291

292-
def create_model(self, model_server_workers=None):
292+
def create_model(self, model_server_workers=None, role=None):
293293
"""Create a SageMaker ``TensorFlowModel`` object that can be deployed to an ``Endpoint``.
294294
295295
Args:
296+
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
297+
transform jobs. If not specified, the role from the Estimator will be used.
296298
model_server_workers (int): Optional. The number of worker processes used by the inference server.
297299
If None, server will use one worker per vCPU.
298300
@@ -301,7 +303,8 @@ def create_model(self, model_server_workers=None):
301303
See :func:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
302304
"""
303305
env = {'SAGEMAKER_REQUIREMENTS': self.requirements_file}
304-
return TensorFlowModel(self.model_data, self.role, self.entry_point, source_dir=self._model_source_dir(),
306+
role = role or self.role
307+
return TensorFlowModel(self.model_data, role, self.entry_point, source_dir=self._model_source_dir(),
305308
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, env=env, image=self.image_name,
306309
name=self._current_job_name, container_log_level=self.container_log_level,
307310
code_location=self.code_location, py_version=self.py_version,

tests/unit/test_chainer.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def test_create_model(sagemaker_session, chainer_version):
233233
enable_cloudwatch_metrics=enable_cloudwatch_metrics)
234234

235235
job_name = 'new_name'
236-
chainer.fit(inputs='s3://mybucket/train', job_name='new_name')
236+
chainer.fit(inputs='s3://mybucket/train', job_name=job_name)
237237
model = chainer.create_model()
238238

239239
assert model.sagemaker_session == sagemaker_session
@@ -247,6 +247,25 @@ def test_create_model(sagemaker_session, chainer_version):
247247
assert model.enable_cloudwatch_metrics == enable_cloudwatch_metrics
248248

249249

250+
def test_create_model_with_optional_params(sagemaker_session):
251+
container_log_level = '"logging.INFO"'
252+
source_dir = 's3://mybucket/source'
253+
enable_cloudwatch_metrics = 'true'
254+
chainer = Chainer(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
255+
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
256+
container_log_level=container_log_level, py_version=PYTHON_VERSION, base_job_name='job',
257+
source_dir=source_dir, enable_cloudwatch_metrics=enable_cloudwatch_metrics)
258+
259+
chainer.fit(inputs='s3://mybucket/train', job_name='new_name')
260+
261+
new_role = 'role'
262+
model_server_workers = 2
263+
model = chainer.create_model(role=new_role, model_server_workers=model_server_workers)
264+
265+
assert model.role == new_role
266+
assert model.model_server_workers == model_server_workers
267+
268+
250269
def test_create_model_with_custom_image(sagemaker_session):
251270
container_log_level = '"logging.INFO"'
252271
source_dir = 's3://mybucket/source'

tests/unit/test_estimator.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ class DummyFramework(Framework):
8686
def train_image(self):
8787
return IMAGE_NAME
8888

89-
def create_model(self, model_server_workers=None):
89+
def create_model(self, role=None, model_server_workers=None):
9090
return DummyFrameworkModel(self.sagemaker_session)
9191

9292
@classmethod
@@ -476,21 +476,21 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
476476
base_job_name=base_name)
477477
fw.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)
478478

479-
transformer = fw.transformer(INSTANCE_COUNT, INSTANCE_TYPE)
480-
481479
strategy = 'MultiRecord'
482480
assemble_with = 'Line'
483481
kms_key = 'key'
484482
accept = 'text/csv'
485483
max_concurrent_transforms = 1
486484
max_payload = 6
487485
env = {'FOO': 'BAR'}
486+
new_role = 'dummy-model-role'
488487

489488
transformer = fw.transformer(INSTANCE_COUNT, INSTANCE_TYPE, strategy=strategy, assemble_with=assemble_with,
490489
output_path=OUTPUT_PATH, output_kms_key=kms_key, accept=accept, tags=TAGS,
491490
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
492-
env=env, model_server_workers=1)
491+
env=env, role=new_role, model_server_workers=1)
493492

493+
sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, new_role, MODEL_CONTAINER_DEF)
494494
assert transformer.strategy == strategy
495495
assert transformer.assemble_with == assemble_with
496496
assert transformer.output_path == OUTPUT_PATH
@@ -528,7 +528,7 @@ def test_estimator_transformer_creation(sagemaker_session):
528528

529529
transformer = estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE)
530530

531-
sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME)
531+
sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=None)
532532
assert isinstance(transformer, Transformer)
533533
assert transformer.sagemaker_session == sagemaker_session
534534
assert transformer.instance_count == INSTANCE_COUNT
@@ -556,8 +556,9 @@ def test_estimator_transformer_creation_with_optional_params(sagemaker_session):
556556
transformer = estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE, strategy=strategy, assemble_with=assemble_with,
557557
output_path=OUTPUT_PATH, output_kms_key=kms_key, accept=accept, tags=TAGS,
558558
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
559-
env=env)
559+
env=env, role=ROLE)
560560

561+
sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=ROLE)
561562
assert transformer.strategy == strategy
562563
assert transformer.assemble_with == assemble_with
563564
assert transformer.output_path == OUTPUT_PATH

tests/unit/test_mxnet.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def test_create_model(sagemaker_session, mxnet_version):
108108
base_job_name='job', source_dir=source_dir, enable_cloudwatch_metrics=enable_cloudwatch_metrics)
109109

110110
job_name = 'new_name'
111-
mx.fit(inputs='s3://mybucket/train', job_name='new_name')
111+
mx.fit(inputs='s3://mybucket/train', job_name=job_name)
112112
model = mx.create_model()
113113

114114
assert model.sagemaker_session == sagemaker_session
@@ -122,6 +122,25 @@ def test_create_model(sagemaker_session, mxnet_version):
122122
assert model.enable_cloudwatch_metrics == enable_cloudwatch_metrics
123123

124124

125+
def test_create_model_with_optional_params(sagemaker_session):
126+
container_log_level = '"logging.INFO"'
127+
source_dir = 's3://mybucket/source'
128+
enable_cloudwatch_metrics = 'true'
129+
mx = MXNet(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
130+
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
131+
container_log_level=container_log_level, base_job_name='job', source_dir=source_dir,
132+
enable_cloudwatch_metrics=enable_cloudwatch_metrics)
133+
134+
mx.fit(inputs='s3://mybucket/train', job_name='new_name')
135+
136+
new_role = 'role'
137+
model_server_workers = 2
138+
model = mx.create_model(role=new_role, model_server_workers=model_server_workers)
139+
140+
assert model.role == new_role
141+
assert model.model_server_workers == model_server_workers
142+
143+
125144
def test_create_model_with_custom_image(sagemaker_session):
126145
container_log_level = '"logging.INFO"'
127146
source_dir = 's3://mybucket/source'

tests/unit/test_pytorch.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,25 @@ def test_create_model(sagemaker_session, pytorch_version):
140140
assert model.enable_cloudwatch_metrics == enable_cloudwatch_metrics
141141

142142

143+
def test_create_model_with_optional_params(sagemaker_session):
144+
container_log_level = '"logging.INFO"'
145+
source_dir = 's3://mybucket/source'
146+
enable_cloudwatch_metrics = 'true'
147+
pytorch = PyTorch(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
148+
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
149+
container_log_level=container_log_level, base_job_name='job', source_dir=source_dir,
150+
enable_cloudwatch_metrics=enable_cloudwatch_metrics)
151+
152+
pytorch.fit(inputs='s3://mybucket/train', job_name='new_name')
153+
154+
new_role = 'role'
155+
model_server_workers = 2
156+
model = pytorch.create_model(role=new_role, model_server_workers=model_server_workers)
157+
158+
assert model.role == new_role
159+
assert model.model_server_workers == model_server_workers
160+
161+
143162
def test_create_model_with_custom_image(sagemaker_session):
144163
container_log_level = '"logging.INFO"'
145164
source_dir = 's3://mybucket/source'

tests/unit/test_tf_estimator.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,26 @@ def test_create_model(sagemaker_session, tf_version):
205205
assert model.enable_cloudwatch_metrics == enable_cloudwatch_metrics
206206

207207

208+
def test_create_model_with_optional_params(sagemaker_session):
209+
container_log_level = '"logging.INFO"'
210+
source_dir = 's3://mybucket/source'
211+
enable_cloudwatch_metrics = 'true'
212+
tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
213+
training_steps=1000, evaluation_steps=10, train_instance_count=INSTANCE_COUNT,
214+
train_instance_type=INSTANCE_TYPE, container_log_level=container_log_level, base_job_name='job',
215+
source_dir=source_dir, enable_cloudwatch_metrics=enable_cloudwatch_metrics)
216+
217+
job_name = 'doing something'
218+
tf.fit(inputs='s3://mybucket/train', job_name=job_name)
219+
220+
new_role = 'role'
221+
model_server_workers = 2
222+
model = tf.create_model(role=new_role, model_server_workers=2)
223+
224+
assert model.role == new_role
225+
assert model.model_server_workers == model_server_workers
226+
227+
208228
def test_create_model_with_custom_image(sagemaker_session):
209229
container_log_level = '"logging.INFO"'
210230
source_dir = 's3://mybucket/source'

0 commit comments

Comments
 (0)