diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 119d05ab80..82d961e1ec 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -7,6 +7,7 @@ CHANGELOG
 
 * bug-fix: get_execution_role no longer fails if user can't call get_role
 * bug-fix: Session: use existing model instead of failing during ``create_model()``
+* enhancement: Estimator: allow a role different from the Estimator's to be used when creating a Model or Transformer
 
 1.7.0
 =====
diff --git a/src/sagemaker/chainer/estimator.py b/src/sagemaker/chainer/estimator.py
index 54d8b446b2..e2c1efc487 100644
--- a/src/sagemaker/chainer/estimator.py
+++ b/src/sagemaker/chainer/estimator.py
@@ -98,10 +98,12 @@ def hyperparameters(self):
         hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
         return hyperparameters
 
-    def create_model(self, model_server_workers=None):
+    def create_model(self, model_server_workers=None, role=None):
         """Create a SageMaker ``ChainerModel`` object that can be deployed to an ``Endpoint``.
 
         Args:
+            role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
+                transform jobs. If not specified, the role from the Estimator will be used.
             model_server_workers (int): Optional. The number of worker processes used by the inference server.
                 If None, server will use one worker per vCPU.
 
@@ -109,7 +111,8 @@ def create_model(self, model_server_workers=None):
             sagemaker.chainer.model.ChainerModel: A SageMaker ``ChainerModel`` object.
                 See :func:`~sagemaker.chainer.model.ChainerModel` for full details.
         """
-        return ChainerModel(self.model_data, self.role, self.entry_point, source_dir=self._model_source_dir(),
+        role = role or self.role
+        return ChainerModel(self.model_data, role, self.entry_point, source_dir=self._model_source_dir(),
                             enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
                             container_log_level=self.container_log_level, code_location=self.code_location,
                             py_version=self.py_version, framework_version=self.framework_version,
diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py
index 182c00796b..7959656535 100644
--- a/src/sagemaker/estimator.py
+++ b/src/sagemaker/estimator.py
@@ -319,7 +319,7 @@ def delete_endpoint(self):
 
     def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None,
                     output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None,
-                    max_payload=None, tags=None):
+                    max_payload=None, tags=None, role=None):
         """Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the
         SageMaker Session and base job name used by the Estimator.
 
@@ -339,10 +339,12 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
             max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB.
             tags (list[dict]): List of tags for labeling a transform job. If none specified, then the tags used for
                 the training job are used for the transform job.
+            role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
+                transform jobs. If not specified, the role from the Estimator will be used.
""" self._ensure_latest_training_job() - model_name = self.sagemaker_session.create_model_from_job(self.latest_training_job.name) + model_name = self.sagemaker_session.create_model_from_job(self.latest_training_job.name, role=role) tags = tags or self.tags return Transformer(model_name, instance_count, instance_type, strategy=strategy, assemble_with=assemble_with, @@ -476,12 +478,14 @@ def hyperparameters(self): """ return self.hyperparam_dict - def create_model(self, image=None, predictor_cls=None, serializer=None, deserializer=None, + def create_model(self, role=None, image=None, predictor_cls=None, serializer=None, deserializer=None, content_type=None, accept=None, **kwargs): """ Create a model to deploy. Args: + role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during + transform jobs. If not specified, the role from the Estimator will be used. image (str): An container image to use for deploying the model. Defaults to the image used for training. predictor_cls (RealTimePredictor): The predictor class to use when deploying the model. serializer (callable): Should accept a single argument, the input data, and return a sequence @@ -503,7 +507,9 @@ def predict_wrapper(endpoint, session): return RealTimePredictor(endpoint, session, serializer, deserializer, content_type, accept) predictor_cls = predict_wrapper - return Model(self.model_data, image or self.train_image(), self.role, sagemaker_session=self.sagemaker_session, + role = role or self.role + + return Model(self.model_data, image or self.train_image(), role, sagemaker_session=self.sagemaker_session, predictor_cls=predictor_cls, **kwargs) @classmethod @@ -737,7 +743,7 @@ def _update_init_params(cls, hp, tf_arguments): def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None, output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None, - max_payload=None, tags=None, model_server_workers=None): + max_payload=None, tags=None, role=None, model_server_workers=None): """Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the SageMaker Session and base job name used by the Estimator. @@ -757,16 +763,19 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB. tags (list[dict]): List of tags for labeling a transform job. If none specified, then the tags used for the training job are used for the transform job. + role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during + transform jobs. If not specified, the role from the Estimator will be used. model_server_workers (int): Optional. The number of worker processes used by the inference server. If None, server will use one worker per vCPU. 
""" self._ensure_latest_training_job() + role = role or self.role - model = self.create_model(model_server_workers=model_server_workers) + model = self.create_model(role=role, model_server_workers=model_server_workers) container_def = model.prepare_container_def(instance_type) model_name = model.name or name_from_image(container_def['Image']) - self.sagemaker_session.create_model(model_name, self.role, container_def) + self.sagemaker_session.create_model(model_name, role, container_def) transform_env = model.env.copy() if env is not None: diff --git a/src/sagemaker/mxnet/estimator.py b/src/sagemaker/mxnet/estimator.py index dc226de199..ca12440234 100644 --- a/src/sagemaker/mxnet/estimator.py +++ b/src/sagemaker/mxnet/estimator.py @@ -65,10 +65,12 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_versio self.py_version = py_version self.framework_version = framework_version - def create_model(self, model_server_workers=None): + def create_model(self, model_server_workers=None, role=None): """Create a SageMaker ``MXNetModel`` object that can be deployed to an ``Endpoint``. Args: + role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during + transform jobs. If not specified, the role from the Estimator will be used. model_server_workers (int): Optional. The number of worker processes used by the inference server. If None, server will use one worker per vCPU. @@ -76,7 +78,8 @@ def create_model(self, model_server_workers=None): sagemaker.mxnet.model.MXNetModel: A SageMaker ``MXNetModel`` object. See :func:`~sagemaker.mxnet.model.MXNetModel` for full details. """ - return MXNetModel(self.model_data, self.role, self.entry_point, source_dir=self._model_source_dir(), + role = role or self.role + return MXNetModel(self.model_data, role, self.entry_point, source_dir=self._model_source_dir(), enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name, container_log_level=self.container_log_level, code_location=self.code_location, py_version=self.py_version, framework_version=self.framework_version, image=self.image_name, diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index b5d0120ea3..26e3a0c8d6 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -63,10 +63,12 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_versio self.py_version = py_version self.framework_version = framework_version - def create_model(self, model_server_workers=None): + def create_model(self, model_server_workers=None, role=None): """Create a SageMaker ``PyTorchModel`` object that can be deployed to an ``Endpoint``. Args: + role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during + transform jobs. If not specified, the role from the Estimator will be used. model_server_workers (int): Optional. The number of worker processes used by the inference server. If None, server will use one worker per vCPU. @@ -74,7 +76,8 @@ def create_model(self, model_server_workers=None): sagemaker.pytorch.model.PyTorchModel: A SageMaker ``PyTorchModel`` object. See :func:`~sagemaker.pytorch.model.PyTorchModel` for full details. 
""" - return PyTorchModel(self.model_data, self.role, self.entry_point, source_dir=self._model_source_dir(), + role = role or self.role + return PyTorchModel(self.model_data, role, self.entry_point, source_dir=self._model_source_dir(), enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name, container_log_level=self.container_log_level, code_location=self.code_location, py_version=self.py_version, framework_version=self.framework_version, image=self.image_name, diff --git a/src/sagemaker/tensorflow/estimator.py b/src/sagemaker/tensorflow/estimator.py index 88254e6d31..b489be57bc 100644 --- a/src/sagemaker/tensorflow/estimator.py +++ b/src/sagemaker/tensorflow/estimator.py @@ -289,10 +289,12 @@ def _prepare_init_params_from_job_description(cls, job_details): return init_params - def create_model(self, model_server_workers=None): + def create_model(self, model_server_workers=None, role=None): """Create a SageMaker ``TensorFlowModel`` object that can be deployed to an ``Endpoint``. Args: + role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during + transform jobs. If not specified, the role from the Estimator will be used. model_server_workers (int): Optional. The number of worker processes used by the inference server. If None, server will use one worker per vCPU. @@ -301,7 +303,8 @@ def create_model(self, model_server_workers=None): See :func:`~sagemaker.tensorflow.model.TensorFlowModel` for full details. """ env = {'SAGEMAKER_REQUIREMENTS': self.requirements_file} - return TensorFlowModel(self.model_data, self.role, self.entry_point, source_dir=self._model_source_dir(), + role = role or self.role + return TensorFlowModel(self.model_data, role, self.entry_point, source_dir=self._model_source_dir(), enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, env=env, image=self.image_name, name=self._current_job_name, container_log_level=self.container_log_level, code_location=self.code_location, py_version=self.py_version, diff --git a/tests/unit/test_chainer.py b/tests/unit/test_chainer.py index 532c254334..e9a24ab78c 100644 --- a/tests/unit/test_chainer.py +++ b/tests/unit/test_chainer.py @@ -233,7 +233,7 @@ def test_create_model(sagemaker_session, chainer_version): enable_cloudwatch_metrics=enable_cloudwatch_metrics) job_name = 'new_name' - chainer.fit(inputs='s3://mybucket/train', job_name='new_name') + chainer.fit(inputs='s3://mybucket/train', job_name=job_name) model = chainer.create_model() assert model.sagemaker_session == sagemaker_session @@ -247,6 +247,25 @@ def test_create_model(sagemaker_session, chainer_version): assert model.enable_cloudwatch_metrics == enable_cloudwatch_metrics +def test_create_model_with_optional_params(sagemaker_session): + container_log_level = '"logging.INFO"' + source_dir = 's3://mybucket/source' + enable_cloudwatch_metrics = 'true' + chainer = Chainer(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, + container_log_level=container_log_level, py_version=PYTHON_VERSION, base_job_name='job', + source_dir=source_dir, enable_cloudwatch_metrics=enable_cloudwatch_metrics) + + chainer.fit(inputs='s3://mybucket/train', job_name='new_name') + + new_role = 'role' + model_server_workers = 2 + model = chainer.create_model(role=new_role, model_server_workers=model_server_workers) + + assert model.role == new_role + assert model.model_server_workers == model_server_workers + + def 
 def test_create_model_with_custom_image(sagemaker_session):
     container_log_level = '"logging.INFO"'
     source_dir = 's3://mybucket/source'
diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py
index 8831d35b76..574dbed629 100644
--- a/tests/unit/test_estimator.py
+++ b/tests/unit/test_estimator.py
@@ -86,7 +86,7 @@ class DummyFramework(Framework):
     def train_image(self):
         return IMAGE_NAME
 
-    def create_model(self, model_server_workers=None):
+    def create_model(self, role=None, model_server_workers=None):
         return DummyFrameworkModel(self.sagemaker_session)
 
     @classmethod
@@ -476,8 +476,6 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
                         base_job_name=base_name)
     fw.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)
 
-    transformer = fw.transformer(INSTANCE_COUNT, INSTANCE_TYPE)
-
     strategy = 'MultiRecord'
     assemble_with = 'Line'
     kms_key = 'key'
@@ -485,12 +483,14 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
     max_concurrent_transforms = 1
     max_payload = 6
     env = {'FOO': 'BAR'}
+    new_role = 'dummy-model-role'
 
     transformer = fw.transformer(INSTANCE_COUNT, INSTANCE_TYPE, strategy=strategy, assemble_with=assemble_with,
                                  output_path=OUTPUT_PATH, output_kms_key=kms_key, accept=accept, tags=TAGS,
                                  max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
-                                 env=env, model_server_workers=1)
+                                 env=env, role=new_role, model_server_workers=1)
+    sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, new_role, MODEL_CONTAINER_DEF)
 
     assert transformer.strategy == strategy
     assert transformer.assemble_with == assemble_with
     assert transformer.output_path == OUTPUT_PATH
@@ -528,7 +528,7 @@ def test_estimator_transformer_creation(sagemaker_session):
     transformer = estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE)
 
-    sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME)
+    sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=None)
 
     assert isinstance(transformer, Transformer)
     assert transformer.sagemaker_session == sagemaker_session
     assert transformer.instance_count == INSTANCE_COUNT
@@ -556,8 +556,9 @@ def test_estimator_transformer_creation_with_optional_params(sagemaker_session):
     transformer = estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE, strategy=strategy, assemble_with=assemble_with,
                                          output_path=OUTPUT_PATH, output_kms_key=kms_key, accept=accept, tags=TAGS,
                                          max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
-                                         env=env)
+                                         env=env, role=ROLE)
+    sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=ROLE)
 
     assert transformer.strategy == strategy
     assert transformer.assemble_with == assemble_with
     assert transformer.output_path == OUTPUT_PATH
diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py
index b07057f62b..16d7e49f25 100644
--- a/tests/unit/test_mxnet.py
+++ b/tests/unit/test_mxnet.py
@@ -108,7 +108,7 @@ def test_create_model(sagemaker_session, mxnet_version):
                base_job_name='job', source_dir=source_dir, enable_cloudwatch_metrics=enable_cloudwatch_metrics)
 
     job_name = 'new_name'
-    mx.fit(inputs='s3://mybucket/train', job_name='new_name')
+    mx.fit(inputs='s3://mybucket/train', job_name=job_name)
     model = mx.create_model()
 
     assert model.sagemaker_session == sagemaker_session
@@ -122,6 +122,25 @@ def test_create_model(sagemaker_session, mxnet_version):
     assert model.enable_cloudwatch_metrics == enable_cloudwatch_metrics
 
 
+def test_create_model_with_optional_params(sagemaker_session):
'"logging.INFO"' + source_dir = 's3://mybucket/source' + enable_cloudwatch_metrics = 'true' + mx = MXNet(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, + container_log_level=container_log_level, base_job_name='job', source_dir=source_dir, + enable_cloudwatch_metrics=enable_cloudwatch_metrics) + + mx.fit(inputs='s3://mybucket/train', job_name='new_name') + + new_role = 'role' + model_server_workers = 2 + model = mx.create_model(role=new_role, model_server_workers=model_server_workers) + + assert model.role == new_role + assert model.model_server_workers == model_server_workers + + def test_create_model_with_custom_image(sagemaker_session): container_log_level = '"logging.INFO"' source_dir = 's3://mybucket/source' diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 92015ce103..0b11dcee25 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -140,6 +140,25 @@ def test_create_model(sagemaker_session, pytorch_version): assert model.enable_cloudwatch_metrics == enable_cloudwatch_metrics +def test_create_model_with_optional_params(sagemaker_session): + container_log_level = '"logging.INFO"' + source_dir = 's3://mybucket/source' + enable_cloudwatch_metrics = 'true' + pytorch = PyTorch(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, + container_log_level=container_log_level, base_job_name='job', source_dir=source_dir, + enable_cloudwatch_metrics=enable_cloudwatch_metrics) + + pytorch.fit(inputs='s3://mybucket/train', job_name='new_name') + + new_role = 'role' + model_server_workers = 2 + model = pytorch.create_model(role=new_role, model_server_workers=model_server_workers) + + assert model.role == new_role + assert model.model_server_workers == model_server_workers + + def test_create_model_with_custom_image(sagemaker_session): container_log_level = '"logging.INFO"' source_dir = 's3://mybucket/source' diff --git a/tests/unit/test_tf_estimator.py b/tests/unit/test_tf_estimator.py index 0a155ce341..a6969b840b 100644 --- a/tests/unit/test_tf_estimator.py +++ b/tests/unit/test_tf_estimator.py @@ -205,6 +205,26 @@ def test_create_model(sagemaker_session, tf_version): assert model.enable_cloudwatch_metrics == enable_cloudwatch_metrics +def test_create_model_with_optional_params(sagemaker_session): + container_log_level = '"logging.INFO"' + source_dir = 's3://mybucket/source' + enable_cloudwatch_metrics = 'true' + tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, + training_steps=1000, evaluation_steps=10, train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, container_log_level=container_log_level, base_job_name='job', + source_dir=source_dir, enable_cloudwatch_metrics=enable_cloudwatch_metrics) + + job_name = 'doing something' + tf.fit(inputs='s3://mybucket/train', job_name=job_name) + + new_role = 'role' + model_server_workers = 2 + model = tf.create_model(role=new_role, model_server_workers=2) + + assert model.role == new_role + assert model.model_server_workers == model_server_workers + + def test_create_model_with_custom_image(sagemaker_session): container_log_level = '"logging.INFO"' source_dir = 's3://mybucket/source'