
Commit 6872918

breaking: rename sagemaker.tensorflow.serving.Model/Predictor to sagemaker.tensorflow.model.TensorFlowModel/Predictor
This commit also changes two attributes to match the other framework model classes:

* FRAMEWORK_NAME --> __framework_name__
* _framework_version --> framework_version
1 parent d0eb4a2 commit 6872918
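
For SDK users, the change boils down to an import path and two attribute names. A minimal sketch of the rename as seen from calling code (nothing here goes beyond what this commit introduces):

    # Before this commit:
    #     from sagemaker.tensorflow.serving import Model, Predictor
    #     Model.FRAMEWORK_NAME          # class attribute
    #     a_model._framework_version    # instance attribute

    # After this commit:
    from sagemaker.tensorflow.model import TensorFlowModel, TensorFlowPredictor

    print(TensorFlowModel.__framework_name__)   # "tensorflow-serving", replaces FRAMEWORK_NAME
    # a_model.framework_version                 # replaces the private _framework_version attribute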

8 files changed: +87, -82 lines
src/sagemaker/cli/tensorflow.py (+2, -2)

@@ -68,9 +68,9 @@ def create_model(self, model_url):
         Args:
             model_url:
         """
-        from sagemaker.tensorflow.serving import Model
+        from sagemaker.tensorflow.model import TensorFlowModel

-        return Model(
+        return TensorFlowModel(
             model_data=model_url,
             role=self.role_name,
             entry_point=self.script,

src/sagemaker/rl/estimator.py (+9, -10)

@@ -21,6 +21,7 @@
 import sagemaker.fw_utils as fw_utils
 from sagemaker.model import FrameworkModel, SAGEMAKER_OUTPUT_LOCATION
 from sagemaker.mxnet.model import MXNetModel
+from sagemaker.tensorflow.model import TensorFlowModel
 from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

 logger = logging.getLogger("sagemaker")

@@ -90,7 +91,7 @@ def __init__(
         :meth:`~sagemaker.amazon.estimator.Framework.deploy` creates a hosted
         SageMaker endpoint and based on the specified framework returns an
         :class:`~sagemaker.amazon.mxnet.model.MXNetPredictor` or
-        :class:`~sagemaker.amazon.tensorflow.serving.Predictor` instance that
+        :class:`~sagemaker.amazon.tensorflow.model.TensorFlowPredictor` instance that
         can be used to perform inference against the hosted model.

         Technical documentation on preparing RLEstimator scripts for

@@ -205,15 +206,15 @@ def create_model(
             sagemaker.model.FrameworkModel: Depending on input parameters returns
                 one of the following:

-                * :class:`~sagemaker.model.FrameworkModel` - if ``image_name`` was specified
+                * :class:`~sagemaker.model.FrameworkModel` - if ``image_name`` is specified
                     on the estimator;
-                * :class:`~sagemaker.mxnet.MXNetModel` - if ``image_name`` wasn't specified and
-                    MXNet was used as the RL backend;
-                * :class:`~sagemaker.tensorflow.serving.Model` - if ``image_name`` wasn't specified
-                    and TensorFlow was used as the RL backend.
+                * :class:`~sagemaker.mxnet.MXNetModel` - if ``image_name`` isn't specified and
+                    MXNet is used as the RL backend;
+                * :class:`~sagemaker.tensorflow.model.TensorFlowModel` - if ``image_name`` isn't
+                    specified and TensorFlow is used as the RL backend.

         Raises:
-            ValueError: If image_name was not specified and framework enum is not valid.
+            ValueError: If image_name is not specified and framework enum is not valid.
         """
         base_args = dict(
             model_data=self.model_data,

@@ -252,9 +253,7 @@ def create_model(
         )

         if self.framework == RLFramework.TENSORFLOW.value:
-            from sagemaker.tensorflow.serving import Model as tfsModel
-
-            return tfsModel(framework_version=self.framework_version, **base_args)
+            return TensorFlowModel(framework_version=self.framework_version, **base_args)
         if self.framework == RLFramework.MXNET.value:
             return MXNetModel(
                 framework_version=self.framework_version, py_version=PYTHON_VERSION, **extended_args
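
As a usage sketch, with TensorFlow as the RL backend ``create_model()`` now returns a ``TensorFlowModel`` directly. The estimator arguments below (entry point, toolkit version, role ARN, instance settings) are illustrative placeholders, not values taken from this commit:

    from sagemaker.rl import RLEstimator, RLToolkit, RLFramework
    from sagemaker.tensorflow.model import TensorFlowModel

    # Hypothetical RL estimator configuration (all argument values are placeholders).
    estimator = RLEstimator(
        entry_point="train_coach.py",
        toolkit=RLToolkit.COACH,
        toolkit_version="0.11.1",
        framework=RLFramework.TENSORFLOW,
        role="arn:aws:iam::123456789012:role/MySageMakerRole",
        train_instance_count=1,
        train_instance_type="ml.m4.xlarge",
    )

    # After training, the returned model is the renamed class (calls AWS, so commented out):
    # estimator.fit()
    # assert isinstance(estimator.create_model(), TensorFlowModel)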

src/sagemaker/tensorflow/__init__.py (+1, -0)

@@ -14,3 +14,4 @@
 from __future__ import absolute_import

 from sagemaker.tensorflow.estimator import TensorFlow  # noqa: F401 (imported but unused)
+from sagemaker.tensorflow.model import TensorFlowModel, TensorFlowPredictor  # noqa: F401
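
With this re-export, the new classes can be imported straight from the ``sagemaker.tensorflow`` package; a quick check:

    # Both import paths resolve to the same classes after this commit.
    from sagemaker.tensorflow import TensorFlowModel, TensorFlowPredictor
    from sagemaker.tensorflow import model as tf_model_module

    assert TensorFlowModel is tf_model_module.TensorFlowModel
    assert TensorFlowPredictor is tf_model_module.TensorFlowPredictor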

src/sagemaker/tensorflow/estimator.py (+13, -12)

@@ -23,7 +23,7 @@
 from sagemaker.estimator import Framework
 import sagemaker.fw_utils as fw
 from sagemaker.tensorflow import defaults
-from sagemaker.tensorflow.serving import Model
+from sagemaker.tensorflow.model import TensorFlowModel
 from sagemaker.transformer import Transformer
 from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

@@ -249,12 +249,13 @@ def create_model(
         dependencies=None,
         **kwargs
     ):
-        """Create a ``Model`` object that can be used for creating SageMaker model entities,
-        deploying to a SageMaker endpoint, or starting SageMaker Batch Transform jobs.
+        """Create a ``TensorFlowModel`` object that can be used for creating
+        SageMaker model entities, deploying to a SageMaker endpoint, or
+        starting SageMaker Batch Transform jobs.

         Args:
-            role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also
-                used during transform jobs. If not specified, the role from the Estimator is used.
+            role (str): The ``TensorFlowModel``, which is also used during transform jobs.
+                If not specified, the role from the Estimator is used.
             vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on the
                 model. Default: use subnets and security groups from this Estimator.

@@ -267,11 +268,12 @@ def create_model(
                 source code dependencies aside from the entry point file (default: None).
             dependencies (list[str]): A list of paths to directories (absolute or relative) with
                 any additional libraries that will be exported to the container (default: None).
-            **kwargs: Additional kwargs passed to :class:`~sagemaker.tensorflow.serving.Model`.
+            **kwargs: Additional kwargs passed to
+                :class:`~sagemaker.tensorflow.model.TensorFlowModel`.

         Returns:
-            sagemaker.tensorflow.serving.Model: A ``Model`` object.
-                See :class:`~sagemaker.tensorflow.serving.Model` for full details.
+            sagemaker.tensorflow.model.TensorFlowModel: A ``TensorFlowModel`` object.
+                See :class:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
         """
         if "image" not in kwargs:
             kwargs["image"] = self.image_name

@@ -282,7 +284,7 @@ def create_model(
         if "enable_network_isolation" not in kwargs:
             kwargs["enable_network_isolation"] = self.enable_network_isolation()

-        return Model(
+        return TensorFlowModel(
             model_data=self.model_data,
             role=role or self.role,
             container_log_level=self.container_log_level,

@@ -418,9 +420,8 @@ def transformer(
                 container in MB.
             tags (list[dict]): List of tags for labeling a transform job. If none specified, then
                 the tags used for the training job are used for the transform job.
-            role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also
-                used during transform jobs. If not specified, the role from the Estimator will be
-                used.
+            role (str): The IAM Role ARN for the ``TensorFlowModel``, which is also used
+                during transform jobs. If not specified, the role from the Estimator is used.
             volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
                 compute instance (default: None).
             entry_point (str): Path (absolute or relative) to the local Python source file which
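
For context, a sketch of how the updated ``create_model`` is reached from an estimator. The training script, role ARN, and instance settings are placeholders, and the AWS-calling steps are commented out:

    from sagemaker.tensorflow import TensorFlow, TensorFlowModel

    # Hypothetical estimator; all argument values below are placeholders.
    estimator = TensorFlow(
        entry_point="train.py",
        role="arn:aws:iam::123456789012:role/MySageMakerRole",
        train_instance_count=1,
        train_instance_type="ml.p2.xlarge",
        framework_version="1.11",
        py_version="py3",
        script_mode=True,
    )

    # estimator.fit("s3://my-bucket/training-data")   # placeholder input channel
    # tf_model = estimator.create_model()             # now returns a TensorFlowModel
    # assert isinstance(tf_model, TensorFlowModel)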

src/sagemaker/tensorflow/serving.py renamed to src/sagemaker/tensorflow/model.py (+19, -19)

@@ -10,7 +10,7 @@
 # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
 # ANY KIND, either express or implied. See the License for the specific
 # language governing permissions and limitations under the License.
-"""Placeholder docstring"""
+"""Classes for using TensorFlow on Amazon SageMaker for inference."""
 from __future__ import absolute_import

 import logging

@@ -22,7 +22,7 @@
 from sagemaker.tensorflow.defaults import TF_VERSION


-class Predictor(sagemaker.RealTimePredictor):
+class TensorFlowPredictor(sagemaker.RealTimePredictor):
     """A ``RealTimePredictor`` implementation for inference against TensorFlow
     Serving endpoints.
     """

@@ -37,7 +37,7 @@ def __init__(
         model_name=None,
         model_version=None,
     ):
-        """Initialize a ``TFSPredictor``. See ``sagemaker.RealTimePredictor``
+        """Initialize a ``TensorFlowPredictor``. See :class:`~sagemaker.predictor.RealTimePredictor`
         for more info about parameters.

         Args:

@@ -61,7 +61,7 @@ def __init__(
             that should handle the request. If not specified, the latest
             version of the model will be used.
         """
-        super(Predictor, self).__init__(
+        super(TensorFlowPredictor, self).__init__(
             endpoint_name, sagemaker_session, serializer, deserializer, content_type
         )

@@ -115,13 +115,13 @@ def predict(self, data, initial_args=None):
         else:
             args["CustomAttributes"] = self._model_attributes

-        return super(Predictor, self).predict(data, args)
+        return super(TensorFlowPredictor, self).predict(data, args)


-class Model(sagemaker.model.FrameworkModel):
-    """Placeholder docstring"""
+class TensorFlowModel(sagemaker.model.FrameworkModel):
+    """A ``FrameworkModel`` implementation for inference with TensorFlow Serving."""

-    FRAMEWORK_NAME = "tensorflow-serving"
+    __framework_name__ = "tensorflow-serving"
     LOG_LEVEL_PARAM_NAME = "SAGEMAKER_TFS_NGINX_LOGLEVEL"
     LOG_LEVEL_MAP = {
         logging.DEBUG: "debug",

@@ -140,7 +140,7 @@ def __init__(
         image=None,
         framework_version=TF_VERSION,
         container_log_level=None,
-        predictor_cls=Predictor,
+        predictor_cls=TensorFlowPredictor,
         **kwargs
     ):
         """Initialize a Model.

@@ -171,15 +171,15 @@ def __init__(
             :class:`~sagemaker.model.FrameworkModel` and
             :class:`~sagemaker.model.Model`.
         """
-        super(Model, self).__init__(
+        super(TensorFlowModel, self).__init__(
             model_data=model_data,
             role=role,
             image=image,
             predictor_cls=predictor_cls,
             entry_point=entry_point,
             **kwargs
         )
-        self._framework_version = framework_version
+        self.framework_version = framework_version
         self._container_log_level = container_log_level

     def deploy(

@@ -196,10 +196,10 @@ def deploy(
     ):

         if accelerator_type and not self._eia_supported():
-            msg = "The TensorFlow version %s doesn't support EIA." % self._framework_version
-
+            msg = "The TensorFlow version %s doesn't support EIA." % self.framework_version
             raise AttributeError(msg)
-        return super(Model, self).deploy(
+
+        return super(TensorFlowModel, self).deploy(
             initial_instance_count=initial_instance_count,
             instance_type=instance_type,
             accelerator_type=accelerator_type,

@@ -213,7 +213,7 @@ def deploy(

     def _eia_supported(self):
         """Return true if TF version is EIA enabled"""
-        return [int(s) for s in self._framework_version.split(".")][:2] <= self.LATEST_EIA_VERSION
+        return [int(s) for s in self.framework_version.split(".")][:2] <= self.LATEST_EIA_VERSION

     def prepare_container_def(self, instance_type, accelerator_type=None):
         """

@@ -249,12 +249,12 @@ def _get_container_env(self):
         if not self._container_log_level:
             return self.env

-        if self._container_log_level not in Model.LOG_LEVEL_MAP:
+        if self._container_log_level not in self.LOG_LEVEL_MAP:
             logging.warning("ignoring invalid container log level: %s", self._container_log_level)
             return self.env

         env = dict(self.env)
-        env[Model.LOG_LEVEL_PARAM_NAME] = Model.LOG_LEVEL_MAP[self._container_log_level]
+        env[self.LOG_LEVEL_PARAM_NAME] = self.LOG_LEVEL_MAP[self._container_log_level]
         return env

     def _get_image_uri(self, instance_type, accelerator_type=None):

@@ -269,9 +269,9 @@ def _get_image_uri(self, instance_type, accelerator_type=None):
         region_name = self.sagemaker_session.boto_region_name
         return create_image_uri(
             region_name,
-            Model.FRAMEWORK_NAME,
+            self.__framework_name__,
             instance_type,
-            self._framework_version,
+            self.framework_version,
             accelerator_type=accelerator_type,
         )
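
Putting the renamed pieces together, deploying a packaged model looks roughly like the sketch below. The S3 artifact, role ARN, and instance type are placeholders, and the AWS-calling lines are commented out:

    from sagemaker.tensorflow.model import TensorFlowModel

    tf_model = TensorFlowModel(
        model_data="s3://my-bucket/model.tar.gz",                 # placeholder artifact location
        role="arn:aws:iam::123456789012:role/MySageMakerRole",    # placeholder role ARN
        framework_version="1.11",
    )

    # deploy() returns a TensorFlowPredictor because predictor_cls defaults to it.
    # predictor = tf_model.deploy(initial_instance_count=1, instance_type="ml.c5.xlarge")
    # result = predictor.predict({"instances": [[1.0, 2.0, 3.0]]})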

tests/unit/sagemaker/tensorflow/test_estimator.py (+11, -11)

@@ -21,7 +21,7 @@
 import pytest

 from sagemaker.estimator import _TrainingJob
-from sagemaker.tensorflow import defaults, serving, TensorFlow
+from sagemaker.tensorflow import defaults, model, TensorFlow
 from tests.unit import DATA_DIR

 SCRIPT_FILE = "dummy_script.py"

@@ -188,7 +188,7 @@ def test_create_model(sagemaker_session, tf_version):
     model = tf.create_model()

     assert model.sagemaker_session == sagemaker_session
-    assert model._framework_version == tf_version
+    assert model.framework_version == tf_version
     assert model.entry_point is None
     assert model.role == ROLE
     assert model.name == job_name

@@ -372,17 +372,17 @@ def test_script_mode_create_model(sagemaker_session):
     )
     tf._prepare_for_training()  # set output_path and job name as if training happened

-    model = tf.create_model()
+    tf_model = tf.create_model()

-    assert isinstance(model, serving.Model)
+    assert isinstance(tf_model, model.TensorFlowModel)

-    assert model.model_data == tf.model_data
-    assert model.role == tf.role
-    assert model.name == tf._current_job_name
-    assert model.container_log_level == tf.container_log_level
-    assert model._framework_version == "1.11"
-    assert model.sagemaker_session == sagemaker_session
-    assert model.enable_network_isolation()
+    assert tf_model.model_data == tf.model_data
+    assert tf_model.role == tf.role
+    assert tf_model.name == tf._current_job_name
+    assert tf_model.container_log_level == tf.container_log_level
+    assert tf_model.framework_version == "1.11"
+    assert tf_model.sagemaker_session == sagemaker_session
+    assert tf_model.enable_network_isolation()


@patch("time.strftime", return_value=TIMESTAMP)
