pass accelerator_type for tfs model #667

Merged: 4 commits, Feb 26, 2019
Changes from all commits
1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -9,6 +9,7 @@ CHANGELOG
 * feature: ``Predictor``: delete SageMaker model
 * feature: ``Pipeline``: delete SageMaker model
 * bug-fix: Estimator.attach works with training jobs without hyperparameters
+* bug-fix: pass accelerator_type in ``deploy`` for REST API TFS ``Model``

 1.18.3.post1
 ============
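To make the changelog entry concrete: before this fix, an ``accelerator_type`` passed to ``deploy`` on a TFS ``Model`` was silently dropped before reaching the container definition. A hedged sketch of the affected call (the S3 path, role, and accelerator type are placeholders, not values from this PR):

# Illustrative sketch only; all argument values below are placeholders.
from sagemaker.tensorflow.serving import Model

model = Model('s3://my-bucket/model.tar.gz', role='SageMakerRole')

# accelerator_type is now forwarded into the container definition, so Elastic
# Inference validation sees it. Per the docs change later in this PR, TFS
# endpoints do not support EI, so this call now raises ValueError instead of
# quietly deploying without the accelerator.
predictor = model.deploy(initial_instance_count=1,
                         instance_type='ml.c5.xlarge',
                         accelerator_type='ml.eia1.medium')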
2 changes: 1 addition & 1 deletion src/sagemaker/fw_utils.py
@@ -106,7 +106,7 @@ def _accelerator_type_valid_for_framework(framework, accelerator_type=None, optimized_families=None):

     if framework not in VALID_EIA_FRAMEWORKS:
         raise ValueError('{} is not supported with Amazon Elastic Inference. Currently only '
-                         'TensorFlow and MXNet are supported for SageMaker.'.format(framework))
+                         'Python-based TensorFlow and MXNet are supported.'.format(framework))
 
     if optimized_families:
         raise ValueError('Neo does not support Amazon Elastic Inference.')
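For context, a minimal self-contained sketch of the check this message belongs to, assuming ``VALID_EIA_FRAMEWORKS`` holds only the Python-based ``'tensorflow'`` and ``'mxnet'`` framework names (an assumption inferred from the message, not shown in this diff):

# Standalone sketch; the real function lives in src/sagemaker/fw_utils.py.
VALID_EIA_FRAMEWORKS = ['tensorflow', 'mxnet']  # assumed contents


def _accelerator_type_valid_for_framework(framework, accelerator_type=None,
                                          optimized_families=None):
    # No accelerator requested: nothing to validate.
    if accelerator_type is None:
        return False
    if framework not in VALID_EIA_FRAMEWORKS:
        raise ValueError('{} is not supported with Amazon Elastic Inference. Currently only '
                         'Python-based TensorFlow and MXNet are supported.'.format(framework))
    if optimized_families:
        raise ValueError('Neo does not support Amazon Elastic Inference.')
    return True


# 'tensorflow-serving' is absent from the list, so a request fails fast:
# _accelerator_type_valid_for_framework('tensorflow-serving', 'ml.eia1.medium')  # ValueError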
2 changes: 1 addition & 1 deletion src/sagemaker/tensorflow/deploying_python.rst
@@ -25,7 +25,7 @@ like this:

 The code block above deploys a SageMaker Endpoint with one instance of the type 'ml.c4.xlarge'.
 
-TensorFlow serving on SageMaker has support for `Elastic Inference <https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html>`_, which allows for inference acceleration to a hosted endpoint for a fraction of the cost of using a full GPU instance. In order to attach an Elastic Inference accelerator to your endpoint provide the accelerator type to ``accelerator_type`` to your ``deploy`` call.
+Python-based TensorFlow serving on SageMaker has support for `Elastic Inference <https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html>`_, which allows for inference acceleration to a hosted endpoint for a fraction of the cost of using a full GPU instance. In order to attach an Elastic Inference accelerator to your endpoint provide the accelerator type to ``accelerator_type`` to your ``deploy`` call.
 
 .. code:: python
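The ``.. code:: python`` block that closes this hunk is truncated by the diff viewer; a plausible shape for the EI-enabled call it documents (values are illustrative, not recovered from the file) is:

# Illustrative values only; the documented block itself is truncated above.
predictor = estimator.deploy(initial_instance_count=1,
                             instance_type='ml.c5.xlarge',
                             accelerator_type='ml.eia1.medium')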
2 changes: 2 additions & 0 deletions src/sagemaker/tensorflow/deploying_tensorflow_serving.rst
@@ -34,6 +34,8 @@ estimator object to create a SageMaker Endpoint:

 The code block above deploys a SageMaker Endpoint with one instance of the type 'ml.c5.xlarge'.
 
+As of now, only the Python-based TensorFlow serving endpoints support Elastic Inference. For more information, see `Deploying to Python-based Endpoints <https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/tensorflow/deploying_python.rst#deploying-to-python-based-endpoints>`_.
+
 What happens when deploy is called
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6 changes: 3 additions & 3 deletions src/sagemaker/tensorflow/serving.py
@@ -123,7 +123,7 @@ def __init__(self, model_data, role, image=None, framework_version=TF_VERSION,
         self._container_log_level = container_log_level
 
     def prepare_container_def(self, instance_type, accelerator_type=None):
-        image = self._get_image_uri(instance_type)
+        image = self._get_image_uri(instance_type, accelerator_type)
         env = self._get_container_env()
         return sagemaker.container_def(image, self.model_data, env)
 
@@ -139,10 +139,10 @@ def _get_container_env(self):
             env[Model.LOG_LEVEL_PARAM_NAME] = Model.LOG_LEVEL_MAP[self._container_log_level]
         return env
 
-    def _get_image_uri(self, instance_type):
+    def _get_image_uri(self, instance_type, accelerator_type=None):
         if self.image:
             return self.image
 
         region_name = self.sagemaker_session.boto_region_name
         return create_image_uri(region_name, Model.FRAMEWORK_NAME, instance_type,
-                                self._framework_version)
+                                self._framework_version, accelerator_type=accelerator_type)
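Taken together, these two hunks thread ``accelerator_type`` from ``prepare_container_def`` through ``_get_image_uri`` into ``create_image_uri``, where the Elastic Inference check from ``fw_utils.py`` runs. A hedged demonstration with a stubbed session, mirroring the new unit test below (the stub and argument values are for illustration only):

# Demonstration sketch: a Mock stands in for a real sagemaker.Session, so no
# AWS credentials are needed.
from unittest.mock import Mock

from sagemaker.tensorflow.serving import Model

session = Mock()
session.boto_region_name = 'us-west-2'

model = Model('s3://my-bucket/model.tar.gz', role='SageMakerRole',
              sagemaker_session=session)

# The accelerator type now reaches create_image_uri, whose EI check rejects
# 'tensorflow-serving' with a ValueError rather than ignoring the argument.
try:
    model.prepare_container_def('ml.c5.xlarge', accelerator_type='ml.eia1.medium')
except ValueError as err:
    print(err)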
8 changes: 8 additions & 0 deletions tests/unit/test_tfs.py
@@ -25,6 +25,7 @@
 CSV_CONTENT_TYPE = 'text/csv'
 INSTANCE_COUNT = 1
 INSTANCE_TYPE = 'ml.c4.4xlarge'
+ACCELERATOR_TYPE = 'ml.eia.medium'
 ROLE = 'Dummy'
 REGION = 'us-west-2'
 PREDICT_INPUT = {'instances': [1.0, 2.0, 5.0]}
@@ -75,6 +76,13 @@ def test_tfs_model(sagemaker_session, tf_version):
     assert isinstance(predictor, Predictor)
 
 
+def test_tfs_model_image_accelerator(sagemaker_session, tf_version):
+    model = Model("s3://some/data.tar.gz", role=ROLE, framework_version=tf_version,
+                  sagemaker_session=sagemaker_session)
+    with pytest.raises(ValueError):
+        model.prepare_container_def(INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE)
+
+
 def test_tfs_model_with_log_level(sagemaker_session, tf_version):
     model = Model("s3://some/data.tar.gz", role=ROLE, framework_version=tf_version,
                   container_log_level=logging.INFO,
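Usage note: the new test pins this behavior end to end. A standard invocation such as ``pytest tests/unit/test_tfs.py -k test_tfs_model_image_accelerator`` (an assumed command, not taken from the PR) runs it in isolation.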