Skip to content

Commit d6729b3

Browse files
authored
Merge branch 'master' into remove_cw_metrics_arg
2 parents bfaafe3 + 3253466 commit d6729b3

File tree

12 files changed

+148
-24
lines changed

12 files changed

+148
-24
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ CHANGELOG
77

88
* bug-fix: get_execution_role no longer fails if user can't call get_role
99
* bug-fix: Session: use existing model instead of failing during ``create_model()``
10-
* deprecate enable_cloudwatch_metrics from Framework Estimators.
10+
* bug-fix: deprecate enable_cloudwatch_metrics from Framework Estimators.
11+
* enhancement: Estimator: allow for different role from the Estimator's when creating a Model or Transformer
1112

1213
1.7.0
1314
=====

README.rst

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,30 @@ instance type.
192192
mxnet_estimator.delete_endpoint()
193193
194194
195+
If you have an existing model and would like to deploy it locally you can do that as well. If you don't
196+
specify a sagemaker_session argument to the MXNetModel constructor, the right session will be generated
197+
when calling model.deploy()
198+
199+
Here is an end to end example:
200+
201+
.. code:: python
202+
203+
import numpy
204+
from sagemaker.mxnet import MXNetModel
205+
206+
model_location = 's3://mybucket/my_model.tar.gz'
207+
code_location = 's3://mybucket/sourcedir.tar.gz'
208+
s3_model = MXNetModel(model_data=model_location, role='SageMakerRole',
209+
entry_point='mnist.py', source_dir=code_location)
210+
211+
predictor = s3_model.deploy(initial_instance_count=1, instance_type='local')
212+
data = numpy.zeros(shape=(1, 1, 28, 28))
213+
predictor.predict(data)
214+
215+
# Tear down the endpoint container
216+
predictor.delete_endpoint()
217+
218+
195219
For detailed examples of running docker in local mode, see:
196220

197221
- `TensorFlow local mode example notebook <https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/tensorflow_distributed_mnist/tensorflow_local_mode_mnist.ipynb>`__.

src/sagemaker/chainer/estimator.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,18 +98,21 @@ def hyperparameters(self):
9898
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
9999
return hyperparameters
100100

101-
def create_model(self, model_server_workers=None):
101+
def create_model(self, model_server_workers=None, role=None):
102102
"""Create a SageMaker ``ChainerModel`` object that can be deployed to an ``Endpoint``.
103103
104104
Args:
105+
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
106+
transform jobs. If not specified, the role from the Estimator will be used.
105107
model_server_workers (int): Optional. The number of worker processes used by the inference server.
106108
If None, server will use one worker per vCPU.
107109
108110
Returns:
109111
sagemaker.chainer.model.ChainerModel: A SageMaker ``ChainerModel`` object.
110112
See :func:`~sagemaker.chainer.model.ChainerModel` for full details.
111113
"""
112-
return ChainerModel(self.model_data, self.role, self.entry_point, source_dir=self._model_source_dir(),
114+
role = role or self.role
115+
return ChainerModel(self.model_data, role, self.entry_point, source_dir=self._model_source_dir(),
113116
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
114117
container_log_level=self.container_log_level, code_location=self.code_location,
115118
py_version=self.py_version, framework_version=self.framework_version,

src/sagemaker/estimator.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def delete_endpoint(self):
320320

321321
def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None,
322322
output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None,
323-
max_payload=None, tags=None):
323+
max_payload=None, tags=None, role=None):
324324
"""Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the
325325
SageMaker Session and base job name used by the Estimator.
326326
@@ -340,10 +340,12 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
340340
max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB.
341341
tags (list[dict]): List of tags for labeling a transform job. If none specified, then the tags used for
342342
the training job are used for the transform job.
343+
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
344+
transform jobs. If not specified, the role from the Estimator will be used.
343345
"""
344346
self._ensure_latest_training_job()
345347

346-
model_name = self.sagemaker_session.create_model_from_job(self.latest_training_job.name)
348+
model_name = self.sagemaker_session.create_model_from_job(self.latest_training_job.name, role=role)
347349
tags = tags or self.tags
348350

349351
return Transformer(model_name, instance_count, instance_type, strategy=strategy, assemble_with=assemble_with,
@@ -477,12 +479,14 @@ def hyperparameters(self):
477479
"""
478480
return self.hyperparam_dict
479481

480-
def create_model(self, image=None, predictor_cls=None, serializer=None, deserializer=None,
482+
def create_model(self, role=None, image=None, predictor_cls=None, serializer=None, deserializer=None,
481483
content_type=None, accept=None, **kwargs):
482484
"""
483485
Create a model to deploy.
484486
485487
Args:
488+
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
489+
transform jobs. If not specified, the role from the Estimator will be used.
486490
image (str): An container image to use for deploying the model. Defaults to the image used for training.
487491
predictor_cls (RealTimePredictor): The predictor class to use when deploying the model.
488492
serializer (callable): Should accept a single argument, the input data, and return a sequence
@@ -504,7 +508,9 @@ def predict_wrapper(endpoint, session):
504508
return RealTimePredictor(endpoint, session, serializer, deserializer, content_type, accept)
505509
predictor_cls = predict_wrapper
506510

507-
return Model(self.model_data, image or self.train_image(), self.role, sagemaker_session=self.sagemaker_session,
511+
role = role or self.role
512+
513+
return Model(self.model_data, image or self.train_image(), role, sagemaker_session=self.sagemaker_session,
508514
predictor_cls=predictor_cls, **kwargs)
509515

510516
@classmethod
@@ -741,7 +747,7 @@ def _update_init_params(cls, hp, tf_arguments):
741747

742748
def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None,
743749
output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None,
744-
max_payload=None, tags=None, model_server_workers=None):
750+
max_payload=None, tags=None, role=None, model_server_workers=None):
745751
"""Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the
746752
SageMaker Session and base job name used by the Estimator.
747753
@@ -761,16 +767,19 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
761767
max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB.
762768
tags (list[dict]): List of tags for labeling a transform job. If none specified, then the tags used for
763769
the training job are used for the transform job.
770+
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
771+
transform jobs. If not specified, the role from the Estimator will be used.
764772
model_server_workers (int): Optional. The number of worker processes used by the inference server.
765773
If None, server will use one worker per vCPU.
766774
"""
767775
self._ensure_latest_training_job()
776+
role = role or self.role
768777

769-
model = self.create_model(model_server_workers=model_server_workers)
778+
model = self.create_model(role=role, model_server_workers=model_server_workers)
770779

771780
container_def = model.prepare_container_def(instance_type)
772781
model_name = model.name or name_from_image(container_def['Image'])
773-
self.sagemaker_session.create_model(model_name, self.role, container_def)
782+
self.sagemaker_session.create_model(model_name, role, container_def)
774783

775784
transform_env = model.env.copy()
776785
if env is not None:

src/sagemaker/mxnet/estimator.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,18 +65,21 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_versio
6565
self.py_version = py_version
6666
self.framework_version = framework_version
6767

68-
def create_model(self, model_server_workers=None):
68+
def create_model(self, model_server_workers=None, role=None):
6969
"""Create a SageMaker ``MXNetModel`` object that can be deployed to an ``Endpoint``.
7070
7171
Args:
72+
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
73+
transform jobs. If not specified, the role from the Estimator will be used.
7274
model_server_workers (int): Optional. The number of worker processes used by the inference server.
7375
If None, server will use one worker per vCPU.
7476
7577
Returns:
7678
sagemaker.mxnet.model.MXNetModel: A SageMaker ``MXNetModel`` object.
7779
See :func:`~sagemaker.mxnet.model.MXNetModel` for full details.
7880
"""
79-
return MXNetModel(self.model_data, self.role, self.entry_point, source_dir=self._model_source_dir(),
81+
role = role or self.role
82+
return MXNetModel(self.model_data, role, self.entry_point, source_dir=self._model_source_dir(),
8083
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
8184
container_log_level=self.container_log_level, code_location=self.code_location,
8285
py_version=self.py_version, framework_version=self.framework_version, image=self.image_name,

src/sagemaker/pytorch/estimator.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,21 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_versio
6363
self.py_version = py_version
6464
self.framework_version = framework_version
6565

66-
def create_model(self, model_server_workers=None):
66+
def create_model(self, model_server_workers=None, role=None):
6767
"""Create a SageMaker ``PyTorchModel`` object that can be deployed to an ``Endpoint``.
6868
6969
Args:
70+
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
71+
transform jobs. If not specified, the role from the Estimator will be used.
7072
model_server_workers (int): Optional. The number of worker processes used by the inference server.
7173
If None, server will use one worker per vCPU.
7274
7375
Returns:
7476
sagemaker.pytorch.model.PyTorchModel: A SageMaker ``PyTorchModel`` object.
7577
See :func:`~sagemaker.pytorch.model.PyTorchModel` for full details.
7678
"""
77-
return PyTorchModel(self.model_data, self.role, self.entry_point, source_dir=self._model_source_dir(),
79+
role = role or self.role
80+
return PyTorchModel(self.model_data, role, self.entry_point, source_dir=self._model_source_dir(),
7881
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
7982
container_log_level=self.container_log_level, code_location=self.code_location,
8083
py_version=self.py_version, framework_version=self.framework_version, image=self.image_name,

src/sagemaker/tensorflow/estimator.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,10 +289,12 @@ def _prepare_init_params_from_job_description(cls, job_details):
289289

290290
return init_params
291291

292-
def create_model(self, model_server_workers=None):
292+
def create_model(self, model_server_workers=None, role=None):
293293
"""Create a SageMaker ``TensorFlowModel`` object that can be deployed to an ``Endpoint``.
294294
295295
Args:
296+
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
297+
transform jobs. If not specified, the role from the Estimator will be used.
296298
model_server_workers (int): Optional. The number of worker processes used by the inference server.
297299
If None, server will use one worker per vCPU.
298300
@@ -301,7 +303,8 @@ def create_model(self, model_server_workers=None):
301303
See :func:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
302304
"""
303305
env = {'SAGEMAKER_REQUIREMENTS': self.requirements_file}
304-
return TensorFlowModel(self.model_data, self.role, self.entry_point, source_dir=self._model_source_dir(),
306+
role = role or self.role
307+
return TensorFlowModel(self.model_data, role, self.entry_point, source_dir=self._model_source_dir(),
305308
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, env=env, image=self.image_name,
306309
name=self._current_job_name, container_log_level=self.container_log_level,
307310
code_location=self.code_location, py_version=self.py_version,

tests/unit/test_chainer.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def test_create_model(sagemaker_session, chainer_version):
229229
py_version=PYTHON_VERSION, base_job_name='job', source_dir=source_dir)
230230

231231
job_name = 'new_name'
232-
chainer.fit(inputs='s3://mybucket/train', job_name='new_name')
232+
chainer.fit(inputs='s3://mybucket/train', job_name=job_name)
233233
model = chainer.create_model()
234234

235235
assert model.sagemaker_session == sagemaker_session
@@ -242,6 +242,25 @@ def test_create_model(sagemaker_session, chainer_version):
242242
assert model.source_dir == source_dir
243243

244244

245+
def test_create_model_with_optional_params(sagemaker_session):
246+
container_log_level = '"logging.INFO"'
247+
source_dir = 's3://mybucket/source'
248+
enable_cloudwatch_metrics = 'true'
249+
chainer = Chainer(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
250+
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
251+
container_log_level=container_log_level, py_version=PYTHON_VERSION, base_job_name='job',
252+
source_dir=source_dir, enable_cloudwatch_metrics=enable_cloudwatch_metrics)
253+
254+
chainer.fit(inputs='s3://mybucket/train', job_name='new_name')
255+
256+
new_role = 'role'
257+
model_server_workers = 2
258+
model = chainer.create_model(role=new_role, model_server_workers=model_server_workers)
259+
260+
assert model.role == new_role
261+
assert model.model_server_workers == model_server_workers
262+
263+
245264
def test_create_model_with_custom_image(sagemaker_session):
246265
container_log_level = '"logging.INFO"'
247266
source_dir = 's3://mybucket/source'

tests/unit/test_estimator.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ class DummyFramework(Framework):
8686
def train_image(self):
8787
return IMAGE_NAME
8888

89-
def create_model(self, model_server_workers=None):
89+
def create_model(self, role=None, model_server_workers=None):
9090
return DummyFrameworkModel(self.sagemaker_session)
9191

9292
@classmethod
@@ -476,21 +476,21 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
476476
base_job_name=base_name)
477477
fw.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)
478478

479-
transformer = fw.transformer(INSTANCE_COUNT, INSTANCE_TYPE)
480-
481479
strategy = 'MultiRecord'
482480
assemble_with = 'Line'
483481
kms_key = 'key'
484482
accept = 'text/csv'
485483
max_concurrent_transforms = 1
486484
max_payload = 6
487485
env = {'FOO': 'BAR'}
486+
new_role = 'dummy-model-role'
488487

489488
transformer = fw.transformer(INSTANCE_COUNT, INSTANCE_TYPE, strategy=strategy, assemble_with=assemble_with,
490489
output_path=OUTPUT_PATH, output_kms_key=kms_key, accept=accept, tags=TAGS,
491490
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
492-
env=env, model_server_workers=1)
491+
env=env, role=new_role, model_server_workers=1)
493492

493+
sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, new_role, MODEL_CONTAINER_DEF)
494494
assert transformer.strategy == strategy
495495
assert transformer.assemble_with == assemble_with
496496
assert transformer.output_path == OUTPUT_PATH
@@ -528,7 +528,7 @@ def test_estimator_transformer_creation(sagemaker_session):
528528

529529
transformer = estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE)
530530

531-
sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME)
531+
sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=None)
532532
assert isinstance(transformer, Transformer)
533533
assert transformer.sagemaker_session == sagemaker_session
534534
assert transformer.instance_count == INSTANCE_COUNT
@@ -556,8 +556,9 @@ def test_estimator_transformer_creation_with_optional_params(sagemaker_session):
556556
transformer = estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE, strategy=strategy, assemble_with=assemble_with,
557557
output_path=OUTPUT_PATH, output_kms_key=kms_key, accept=accept, tags=TAGS,
558558
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
559-
env=env)
559+
env=env, role=ROLE)
560560

561+
sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=ROLE)
561562
assert transformer.strategy == strategy
562563
assert transformer.assemble_with == assemble_with
563564
assert transformer.output_path == OUTPUT_PATH

tests/unit/test_mxnet.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def test_create_model(sagemaker_session, mxnet_version):
107107
base_job_name='job', source_dir=source_dir)
108108

109109
job_name = 'new_name'
110-
mx.fit(inputs='s3://mybucket/train', job_name='new_name')
110+
mx.fit(inputs='s3://mybucket/train', job_name=job_name)
111111
model = mx.create_model()
112112

113113
assert model.sagemaker_session == sagemaker_session
@@ -120,6 +120,25 @@ def test_create_model(sagemaker_session, mxnet_version):
120120
assert model.source_dir == source_dir
121121

122122

123+
def test_create_model_with_optional_params(sagemaker_session):
124+
container_log_level = '"logging.INFO"'
125+
source_dir = 's3://mybucket/source'
126+
enable_cloudwatch_metrics = 'true'
127+
mx = MXNet(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
128+
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
129+
container_log_level=container_log_level, base_job_name='job', source_dir=source_dir,
130+
enable_cloudwatch_metrics=enable_cloudwatch_metrics)
131+
132+
mx.fit(inputs='s3://mybucket/train', job_name='new_name')
133+
134+
new_role = 'role'
135+
model_server_workers = 2
136+
model = mx.create_model(role=new_role, model_server_workers=model_server_workers)
137+
138+
assert model.role == new_role
139+
assert model.model_server_workers == model_server_workers
140+
141+
123142
def test_create_model_with_custom_image(sagemaker_session):
124143
container_log_level = '"logging.INFO"'
125144
source_dir = 's3://mybucket/source'

tests/unit/test_pytorch.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,25 @@ def test_create_model(sagemaker_session, pytorch_version):
137137
assert model.source_dir == source_dir
138138

139139

140+
def test_create_model_with_optional_params(sagemaker_session):
141+
container_log_level = '"logging.INFO"'
142+
source_dir = 's3://mybucket/source'
143+
enable_cloudwatch_metrics = 'true'
144+
pytorch = PyTorch(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
145+
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
146+
container_log_level=container_log_level, base_job_name='job', source_dir=source_dir,
147+
enable_cloudwatch_metrics=enable_cloudwatch_metrics)
148+
149+
pytorch.fit(inputs='s3://mybucket/train', job_name='new_name')
150+
151+
new_role = 'role'
152+
model_server_workers = 2
153+
model = pytorch.create_model(role=new_role, model_server_workers=model_server_workers)
154+
155+
assert model.role == new_role
156+
assert model.model_server_workers == model_server_workers
157+
158+
140159
def test_create_model_with_custom_image(sagemaker_session):
141160
container_log_level = '"logging.INFO"'
142161
source_dir = 's3://mybucket/source'

0 commit comments

Comments
 (0)