Allow Model and Transformer to use a different role from the Estimator #308

Merged · 5 commits · Jul 20, 2018

1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -7,6 +7,7 @@ CHANGELOG

* bug-fix: get_execution_role no longer fails if user can't call get_role
* bug-fix: Session: use existing model instead of failing during ``create_model()``
* enhancement: Estimator: allow for a different role from the Estimator's when creating a Model or Transformer

1.7.0
=====
7 changes: 5 additions & 2 deletions src/sagemaker/chainer/estimator.py
@@ -98,18 +98,21 @@ def hyperparameters(self):
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
return hyperparameters

def create_model(self, model_server_workers=None):
def create_model(self, model_server_workers=None, role=None):
"""Create a SageMaker ``ChainerModel`` object that can be deployed to an ``Endpoint``.

Args:
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
transform jobs. If not specified, the role from the Estimator will be used.
model_server_workers (int): Optional. The number of worker processes used by the inference server.
If None, server will use one worker per vCPU.

Returns:
sagemaker.chainer.model.ChainerModel: A SageMaker ``ChainerModel`` object.
See :func:`~sagemaker.chainer.model.ChainerModel` for full details.
"""
return ChainerModel(self.model_data, self.role, self.entry_point, source_dir=self._model_source_dir(),
role = role or self.role
return ChainerModel(self.model_data, role, self.entry_point, source_dir=self._model_source_dir(),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
container_log_level=self.container_log_level, code_location=self.code_location,
py_version=self.py_version, framework_version=self.framework_version,
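
A minimal usage sketch of the new ``role`` argument on a framework estimator's ``create_model`` (the role ARNs, script path, bucket, and instance types below are placeholders, not values from this PR):

```python
from sagemaker.chainer import Chainer

# Train under one IAM role...
estimator = Chainer(entry_point='train.py',  # hypothetical training script
                    role='arn:aws:iam::123456789012:role/TrainingRole',
                    train_instance_count=1,
                    train_instance_type='ml.c4.xlarge',
                    framework_version='4.0.0')
estimator.fit('s3://mybucket/train')

# ...then create and host the resulting Model under a different execution role.
# If role is omitted, create_model() falls back to the Estimator's role.
model = estimator.create_model(role='arn:aws:iam::123456789012:role/HostingRole')
predictor = model.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')
```
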
23 changes: 16 additions & 7 deletions src/sagemaker/estimator.py
@@ -319,7 +319,7 @@ def delete_endpoint(self):

def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None,
output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None,
max_payload=None, tags=None):
max_payload=None, tags=None, role=None):
"""Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the
SageMaker Session and base job name used by the Estimator.

@@ -339,10 +339,12 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB.
tags (list[dict]): List of tags for labeling a transform job. If none specified, then the tags used for
the training job are used for the transform job.
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
transform jobs. If not specified, the role from the Estimator will be used.
"""
self._ensure_latest_training_job()

model_name = self.sagemaker_session.create_model_from_job(self.latest_training_job.name)
model_name = self.sagemaker_session.create_model_from_job(self.latest_training_job.name, role=role)
tags = tags or self.tags

return Transformer(model_name, instance_count, instance_type, strategy=strategy, assemble_with=assemble_with,
@@ -476,12 +478,14 @@ def hyperparameters(self):
"""
return self.hyperparam_dict

def create_model(self, image=None, predictor_cls=None, serializer=None, deserializer=None,
def create_model(self, role=None, image=None, predictor_cls=None, serializer=None, deserializer=None,
content_type=None, accept=None, **kwargs):
"""
Create a model to deploy.

Args:
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
transform jobs. If not specified, the role from the Estimator will be used.
image (str): A container image to use for deploying the model. Defaults to the image used for training.
predictor_cls (RealTimePredictor): The predictor class to use when deploying the model.
serializer (callable): Should accept a single argument, the input data, and return a sequence
@@ -503,7 +507,9 @@ def predict_wrapper(endpoint, session):
return RealTimePredictor(endpoint, session, serializer, deserializer, content_type, accept)
predictor_cls = predict_wrapper

return Model(self.model_data, image or self.train_image(), self.role, sagemaker_session=self.sagemaker_session,
role = role or self.role

return Model(self.model_data, image or self.train_image(), role, sagemaker_session=self.sagemaker_session,
predictor_cls=predictor_cls, **kwargs)

@classmethod
@@ -737,7 +743,7 @@ def _update_init_params(cls, hp, tf_arguments):

def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None,
output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None,
max_payload=None, tags=None, model_server_workers=None):
max_payload=None, tags=None, role=None, model_server_workers=None):
"""Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the
SageMaker Session and base job name used by the Estimator.

@@ -757,16 +763,19 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB.
tags (list[dict]): List of tags for labeling a transform job. If none specified, then the tags used for
the training job are used for the transform job.
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
transform jobs. If not specified, the role from the Estimator will be used.
model_server_workers (int): Optional. The number of worker processes used by the inference server.
If None, server will use one worker per vCPU.
"""
self._ensure_latest_training_job()
role = role or self.role

model = self.create_model(model_server_workers=model_server_workers)
model = self.create_model(role=role, model_server_workers=model_server_workers)

container_def = model.prepare_container_def(instance_type)
model_name = model.name or name_from_image(container_def['Image'])
self.sagemaker_session.create_model(model_name, self.role, container_def)
self.sagemaker_session.create_model(model_name, role, container_def)

transform_env = model.env.copy()
if env is not None:
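
The batch-transform path works the same way; a sketch, assuming ``estimator`` already has a completed training job (bucket names and the role ARN are again placeholders):

```python
# The Transformer's backing Model is created under the role passed here;
# when no role is given, the Estimator's role is used instead.
transformer = estimator.transformer(
    instance_count=1,
    instance_type='ml.m4.xlarge',
    output_path='s3://mybucket/transform-output',
    role='arn:aws:iam::123456789012:role/BatchTransformRole')

transformer.transform('s3://mybucket/transform-input', content_type='text/csv')
transformer.wait()
```
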
7 changes: 5 additions & 2 deletions src/sagemaker/mxnet/estimator.py
@@ -65,18 +65,21 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_versio
self.py_version = py_version
self.framework_version = framework_version

def create_model(self, model_server_workers=None):
def create_model(self, model_server_workers=None, role=None):
"""Create a SageMaker ``MXNetModel`` object that can be deployed to an ``Endpoint``.

Args:
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
transform jobs. If not specified, the role from the Estimator will be used.
model_server_workers (int): Optional. The number of worker processes used by the inference server.
If None, server will use one worker per vCPU.

Returns:
sagemaker.mxnet.model.MXNetModel: A SageMaker ``MXNetModel`` object.
See :func:`~sagemaker.mxnet.model.MXNetModel` for full details.
"""
return MXNetModel(self.model_data, self.role, self.entry_point, source_dir=self._model_source_dir(),
role = role or self.role
return MXNetModel(self.model_data, role, self.entry_point, source_dir=self._model_source_dir(),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
container_log_level=self.container_log_level, code_location=self.code_location,
py_version=self.py_version, framework_version=self.framework_version, image=self.image_name,
7 changes: 5 additions & 2 deletions src/sagemaker/pytorch/estimator.py
@@ -63,18 +63,21 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_versio
self.py_version = py_version
self.framework_version = framework_version

def create_model(self, model_server_workers=None):
def create_model(self, model_server_workers=None, role=None):
"""Create a SageMaker ``PyTorchModel`` object that can be deployed to an ``Endpoint``.

Args:
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
transform jobs. If not specified, the role from the Estimator will be used.
model_server_workers (int): Optional. The number of worker processes used by the inference server.
If None, server will use one worker per vCPU.

Returns:
sagemaker.pytorch.model.PyTorchModel: A SageMaker ``PyTorchModel`` object.
See :func:`~sagemaker.pytorch.model.PyTorchModel` for full details.
"""
return PyTorchModel(self.model_data, self.role, self.entry_point, source_dir=self._model_source_dir(),
role = role or self.role
return PyTorchModel(self.model_data, role, self.entry_point, source_dir=self._model_source_dir(),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
container_log_level=self.container_log_level, code_location=self.code_location,
py_version=self.py_version, framework_version=self.framework_version, image=self.image_name,
7 changes: 5 additions & 2 deletions src/sagemaker/tensorflow/estimator.py
@@ -289,10 +289,12 @@ def _prepare_init_params_from_job_description(cls, job_details):

return init_params

def create_model(self, model_server_workers=None):
def create_model(self, model_server_workers=None, role=None):
"""Create a SageMaker ``TensorFlowModel`` object that can be deployed to an ``Endpoint``.

Args:
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
transform jobs. If not specified, the role from the Estimator will be used.
model_server_workers (int): Optional. The number of worker processes used by the inference server.
If None, server will use one worker per vCPU.

@@ -301,7 +303,8 @@ def create_model(self, model_server_workers=None):
See :func:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
"""
env = {'SAGEMAKER_REQUIREMENTS': self.requirements_file}
return TensorFlowModel(self.model_data, self.role, self.entry_point, source_dir=self._model_source_dir(),
role = role or self.role
return TensorFlowModel(self.model_data, role, self.entry_point, source_dir=self._model_source_dir(),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, env=env, image=self.image_name,
name=self._current_job_name, container_log_level=self.container_log_level,
code_location=self.code_location, py_version=self.py_version,
21 changes: 20 additions & 1 deletion tests/unit/test_chainer.py
@@ -233,7 +233,7 @@ def test_create_model(sagemaker_session, chainer_version):
enable_cloudwatch_metrics=enable_cloudwatch_metrics)

job_name = 'new_name'
chainer.fit(inputs='s3://mybucket/train', job_name='new_name')
chainer.fit(inputs='s3://mybucket/train', job_name=job_name)
model = chainer.create_model()

assert model.sagemaker_session == sagemaker_session
@@ -247,6 +247,25 @@ def test_create_model(sagemaker_session, chainer_version):
assert model.enable_cloudwatch_metrics == enable_cloudwatch_metrics


def test_create_model_with_optional_params(sagemaker_session):
container_log_level = '"logging.INFO"'
source_dir = 's3://mybucket/source'
enable_cloudwatch_metrics = 'true'
chainer = Chainer(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
container_log_level=container_log_level, py_version=PYTHON_VERSION, base_job_name='job',
source_dir=source_dir, enable_cloudwatch_metrics=enable_cloudwatch_metrics)

chainer.fit(inputs='s3://mybucket/train', job_name='new_name')

new_role = 'role'
model_server_workers = 2
model = chainer.create_model(role=new_role, model_server_workers=model_server_workers)

assert model.role == new_role
assert model.model_server_workers == model_server_workers


def test_create_model_with_custom_image(sagemaker_session):
container_log_level = '"logging.INFO"'
source_dir = 's3://mybucket/source'
13 changes: 7 additions & 6 deletions tests/unit/test_estimator.py
@@ -86,7 +86,7 @@ class DummyFramework(Framework):
def train_image(self):
return IMAGE_NAME

def create_model(self, model_server_workers=None):
def create_model(self, role=None, model_server_workers=None):
return DummyFrameworkModel(self.sagemaker_session)

@classmethod
@@ -476,21 +476,21 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
base_job_name=base_name)
fw.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)

transformer = fw.transformer(INSTANCE_COUNT, INSTANCE_TYPE)

strategy = 'MultiRecord'
assemble_with = 'Line'
kms_key = 'key'
accept = 'text/csv'
max_concurrent_transforms = 1
max_payload = 6
env = {'FOO': 'BAR'}
new_role = 'dummy-model-role'

transformer = fw.transformer(INSTANCE_COUNT, INSTANCE_TYPE, strategy=strategy, assemble_with=assemble_with,
output_path=OUTPUT_PATH, output_kms_key=kms_key, accept=accept, tags=TAGS,
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
env=env, model_server_workers=1)
env=env, role=new_role, model_server_workers=1)

sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, new_role, MODEL_CONTAINER_DEF)
assert transformer.strategy == strategy
assert transformer.assemble_with == assemble_with
assert transformer.output_path == OUTPUT_PATH
@@ -528,7 +528,7 @@ def test_estimator_transformer_creation(sagemaker_session):

transformer = estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE)

sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME)
sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=None)
assert isinstance(transformer, Transformer)
assert transformer.sagemaker_session == sagemaker_session
assert transformer.instance_count == INSTANCE_COUNT
@@ -556,8 +556,9 @@ def test_estimator_transformer_creation_with_optional_params(sagemaker_session):
transformer = estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE, strategy=strategy, assemble_with=assemble_with,
output_path=OUTPUT_PATH, output_kms_key=kms_key, accept=accept, tags=TAGS,
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
env=env)
env=env, role=ROLE)

sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=ROLE)
assert transformer.strategy == strategy
assert transformer.assemble_with == assemble_with
assert transformer.output_path == OUTPUT_PATH
21 changes: 20 additions & 1 deletion tests/unit/test_mxnet.py
@@ -108,7 +108,7 @@ def test_create_model(sagemaker_session, mxnet_version):
base_job_name='job', source_dir=source_dir, enable_cloudwatch_metrics=enable_cloudwatch_metrics)

job_name = 'new_name'
mx.fit(inputs='s3://mybucket/train', job_name='new_name')
mx.fit(inputs='s3://mybucket/train', job_name=job_name)
model = mx.create_model()

assert model.sagemaker_session == sagemaker_session
@@ -122,6 +122,25 @@ def test_create_model(sagemaker_session, mxnet_version):
assert model.enable_cloudwatch_metrics == enable_cloudwatch_metrics


def test_create_model_with_optional_params(sagemaker_session):
container_log_level = '"logging.INFO"'
source_dir = 's3://mybucket/source'
enable_cloudwatch_metrics = 'true'
mx = MXNet(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
container_log_level=container_log_level, base_job_name='job', source_dir=source_dir,
enable_cloudwatch_metrics=enable_cloudwatch_metrics)

mx.fit(inputs='s3://mybucket/train', job_name='new_name')

new_role = 'role'
model_server_workers = 2
model = mx.create_model(role=new_role, model_server_workers=model_server_workers)

assert model.role == new_role
assert model.model_server_workers == model_server_workers


def test_create_model_with_custom_image(sagemaker_session):
container_log_level = '"logging.INFO"'
source_dir = 's3://mybucket/source'
19 changes: 19 additions & 0 deletions tests/unit/test_pytorch.py
@@ -140,6 +140,25 @@ def test_create_model(sagemaker_session, pytorch_version):
assert model.enable_cloudwatch_metrics == enable_cloudwatch_metrics


def test_create_model_with_optional_params(sagemaker_session):
container_log_level = '"logging.INFO"'
source_dir = 's3://mybucket/source'
enable_cloudwatch_metrics = 'true'
pytorch = PyTorch(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
container_log_level=container_log_level, base_job_name='job', source_dir=source_dir,
enable_cloudwatch_metrics=enable_cloudwatch_metrics)

pytorch.fit(inputs='s3://mybucket/train', job_name='new_name')

new_role = 'role'
model_server_workers = 2
model = pytorch.create_model(role=new_role, model_server_workers=model_server_workers)

assert model.role == new_role
assert model.model_server_workers == model_server_workers


def test_create_model_with_custom_image(sagemaker_session):
container_log_level = '"logging.INFO"'
source_dir = 's3://mybucket/source'
20 changes: 20 additions & 0 deletions tests/unit/test_tf_estimator.py
@@ -205,6 +205,26 @@ def test_create_model(sagemaker_session, tf_version):
assert model.enable_cloudwatch_metrics == enable_cloudwatch_metrics


def test_create_model_with_optional_params(sagemaker_session):
container_log_level = '"logging.INFO"'
source_dir = 's3://mybucket/source'
enable_cloudwatch_metrics = 'true'
tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
training_steps=1000, evaluation_steps=10, train_instance_count=INSTANCE_COUNT,
train_instance_type=INSTANCE_TYPE, container_log_level=container_log_level, base_job_name='job',
source_dir=source_dir, enable_cloudwatch_metrics=enable_cloudwatch_metrics)

job_name = 'doing something'
tf.fit(inputs='s3://mybucket/train', job_name=job_name)

new_role = 'role'
model_server_workers = 2
model = tf.create_model(role=new_role, model_server_workers=model_server_workers)

assert model.role == new_role
assert model.model_server_workers == model_server_workers


def test_create_model_with_custom_image(sagemaker_session):
container_log_level = '"logging.INFO"'
source_dir = 's3://mybucket/source'