
Commit 8b33a30

pass accelerator_type for tfs model (#667)
1 parent d2430a1 commit 8b33a30
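
In short: before this commit, an ``accelerator_type`` passed to ``deploy`` on a TensorFlow Serving (TFS) ``Model`` never reached image-URI selection and was silently dropped. A minimal sketch of the user-facing effect, assuming the 1.x SDK API (the bucket, role, and accelerator type below are illustrative):

from sagemaker.tensorflow.serving import Model

model = Model('s3://mybucket/model.tar.gz', role='SageMakerRole')

# accelerator_type now flows through prepare_container_def() into
# create_image_uri(). TFS containers do not support Elastic Inference
# at this point, so this fails fast with ValueError instead of quietly
# deploying without the accelerator.
model.deploy(initial_instance_count=1,
             instance_type='ml.c5.xlarge',
             accelerator_type='ml.eia.medium')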

File tree

6 files changed: +16 -5 lines changed

6 files changed

+16
-5
lines changed

CHANGELOG.rst

+1
@@ -9,6 +9,7 @@ CHANGELOG
 * feature: ``Predictor``: delete SageMaker model
 * feature: ``Pipeline``: delete SageMaker model
 * bug-fix: Estimator.attach works with training jobs without hyperparameters
+* bug-fix: pass accelerator_type in ``deploy`` for REST API TFS ``Model``

 1.18.3.post1
 ============

src/sagemaker/fw_utils.py

+1 -1

@@ -106,7 +106,7 @@ def _accelerator_type_valid_for_framework(framework, accelerator_type=None, opti

     if framework not in VALID_EIA_FRAMEWORKS:
         raise ValueError('{} is not supported with Amazon Elastic Inference. Currently only '
-                         'TensorFlow and MXNet are supported for SageMaker.'.format(framework))
+                         'Python-based TensorFlow and MXNet are supported.'.format(framework))

     if optimized_families:
         raise ValueError('Neo does not support Amazon Elastic Inference.')
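
For context, a rough sketch of how this validation behaves, assuming ``VALID_EIA_FRAMEWORKS`` contains only the Python-based ``'tensorflow'`` and ``'mxnet'`` frameworks at this point (an assumption inferred from the reworded message):

from sagemaker.fw_utils import _accelerator_type_valid_for_framework

# Python-based TensorFlow is in VALID_EIA_FRAMEWORKS: no exception.
_accelerator_type_valid_for_framework('tensorflow',
                                      accelerator_type='ml.eia.medium')

# The TFS framework name is not, so this raises ValueError with the
# message above ('tensorflow-serving' is assumed here as the
# Model.FRAMEWORK_NAME value).
_accelerator_type_valid_for_framework('tensorflow-serving',
                                      accelerator_type='ml.eia.medium')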

src/sagemaker/tensorflow/deploying_python.rst

+1 -1

@@ -25,7 +25,7 @@ like this:

 The code block above deploys a SageMaker Endpoint with one instance of the type 'ml.c4.xlarge'.

-TensorFlow serving on SageMaker has support for `Elastic Inference <https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html>`_, which allows for inference acceleration to a hosted endpoint for a fraction of the cost of using a full GPU instance. In order to attach an Elastic Inference accelerator to your endpoint provide the accelerator type to ``accelerator_type`` to your ``deploy`` call.
+Python-based TensorFlow serving on SageMaker has support for `Elastic Inference <https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html>`_, which allows for inference acceleration to a hosted endpoint for a fraction of the cost of using a full GPU instance. In order to attach an Elastic Inference accelerator to your endpoint, provide the accelerator type via ``accelerator_type`` in your ``deploy`` call.

 .. code:: python

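The ``.. code:: python`` block that follows this hunk is not shown in the diff; per the surrounding prose it amounts to something like the sketch below (the instance and accelerator types are placeholders):

predictor = estimator.deploy(initial_instance_count=1,
                             instance_type='ml.c4.xlarge',
                             accelerator_type='ml.eia1.medium')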
src/sagemaker/tensorflow/deploying_tensorflow_serving.rst

+2
@@ -34,6 +34,8 @@ estimator object to create a SageMaker Endpoint:

 The code block above deploys a SageMaker Endpoint with one instance of the type 'ml.c5.xlarge'.

+As of now, only the Python-based TensorFlow serving endpoints support Elastic Inference. For more information, see `Deploying to Python-based Endpoints <https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/tensorflow/deploying_python.rst#deploying-to-python-based-endpoints>`_.
+
 What happens when deploy is called
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

src/sagemaker/tensorflow/serving.py

+3 -3

@@ -123,7 +123,7 @@ def __init__(self, model_data, role, image=None, framework_version=TF_VERSION,
         self._container_log_level = container_log_level

     def prepare_container_def(self, instance_type, accelerator_type=None):
-        image = self._get_image_uri(instance_type)
+        image = self._get_image_uri(instance_type, accelerator_type)
         env = self._get_container_env()
         return sagemaker.container_def(image, self.model_data, env)

@@ -139,10 +139,10 @@ def _get_container_env(self):
             env[Model.LOG_LEVEL_PARAM_NAME] = Model.LOG_LEVEL_MAP[self._container_log_level]
         return env

-    def _get_image_uri(self, instance_type):
+    def _get_image_uri(self, instance_type, accelerator_type=None):
         if self.image:
             return self.image

         region_name = self.sagemaker_session.boto_region_name
         return create_image_uri(region_name, Model.FRAMEWORK_NAME, instance_type,
-                                self._framework_version)
+                                self._framework_version, accelerator_type=accelerator_type)
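
With both changes, the accelerator type is threaded all the way to ``create_image_uri``, which can then pick (or reject) an EIA-specific image. A hedged sketch of the call path (the instance and accelerator types are illustrative):

# Previously accelerator_type was dropped at the first hop of:
# deploy() -> prepare_container_def() -> _get_image_uri() -> create_image_uri()
container = model.prepare_container_def('ml.c5.xlarge',
                                        accelerator_type='ml.eia.medium')

For the TFS framework, ``create_image_uri`` currently raises ValueError via the validation in ``fw_utils.py`` above, which is exactly the fail-fast behavior the new unit test below pins down.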

tests/unit/test_tfs.py

+8
@@ -25,6 +25,7 @@
 CSV_CONTENT_TYPE = 'text/csv'
 INSTANCE_COUNT = 1
 INSTANCE_TYPE = 'ml.c4.4xlarge'
+ACCELERATOR_TYPE = 'ml.eia.medium'
 ROLE = 'Dummy'
 REGION = 'us-west-2'
 PREDICT_INPUT = {'instances': [1.0, 2.0, 5.0]}

@@ -75,6 +76,13 @@ def test_tfs_model(sagemaker_session, tf_version):
     assert isinstance(predictor, Predictor)


+def test_tfs_model_image_accelerator(sagemaker_session, tf_version):
+    model = Model("s3://some/data.tar.gz", role=ROLE, framework_version=tf_version,
+                  sagemaker_session=sagemaker_session)
+    with pytest.raises(ValueError):
+        model.prepare_container_def(INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE)
+
+
 def test_tfs_model_with_log_level(sagemaker_session, tf_version):
     model = Model("s3://some/data.tar.gz", role=ROLE, framework_version=tf_version,
                   container_log_level=logging.INFO,
