Skip to content

Commit 778a4ee

Browse files
authored
breaking: rename sagemaker.tensorflow.serving to sagemaker.tensorflow.model (#1541)
This changes the following two classes:

* sagemaker.tensorflow.serving.Model --> sagemaker.tensorflow.model.TensorFlowModel
* sagemaker.tensorflow.serving.Predictor --> sagemaker.tensorflow.model.TensorFlowPredictor

This commit also changes two attributes to match the other framework model classes:

* FRAMEWORK_NAME --> __framework_name__
* _framework_version --> framework_version
1 parent d0eb4a2 commit 778a4ee

File tree

13 files changed

+102
-97
lines changed

13 files changed

+102
-97
lines changed

doc/sagemaker.tensorflow.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@ TensorFlow Estimator
1313
TensorFlow Serving Model
1414
------------------------
1515

16-
.. autoclass:: sagemaker.tensorflow.serving.Model
16+
.. autoclass:: sagemaker.tensorflow.model.TensorFlowModel
1717
:members:
1818
:undoc-members:
1919
:show-inheritance:
2020

2121
TensorFlow Serving Predictor
2222
----------------------------
2323

24-
.. autoclass:: sagemaker.tensorflow.serving.Predictor
24+
.. autoclass:: sagemaker.tensorflow.model.TensorFlowPredictor
2525
:members:
2626
:undoc-members:
2727
:show-inheritance:

src/sagemaker/cli/tensorflow.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ def create_model(self, model_url):
6868
Args:
6969
model_url:
7070
"""
71-
from sagemaker.tensorflow.serving import Model
71+
from sagemaker.tensorflow.model import TensorFlowModel
7272

73-
return Model(
73+
return TensorFlowModel(
7474
model_data=model_url,
7575
role=self.role_name,
7676
entry_point=self.script,

src/sagemaker/rl/estimator.py

+9-10
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import sagemaker.fw_utils as fw_utils
2222
from sagemaker.model import FrameworkModel, SAGEMAKER_OUTPUT_LOCATION
2323
from sagemaker.mxnet.model import MXNetModel
24+
from sagemaker.tensorflow.model import TensorFlowModel
2425
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2526

2627
logger = logging.getLogger("sagemaker")
@@ -90,7 +91,7 @@ def __init__(
9091
:meth:`~sagemaker.amazon.estimator.Framework.deploy` creates a hosted
9192
SageMaker endpoint and based on the specified framework returns an
9293
:class:`~sagemaker.amazon.mxnet.model.MXNetPredictor` or
93-
:class:`~sagemaker.amazon.tensorflow.serving.Predictor` instance that
94+
:class:`~sagemaker.amazon.tensorflow.model.TensorFlowPredictor` instance that
9495
can be used to perform inference against the hosted model.
9596
9697
Technical documentation on preparing RLEstimator scripts for
@@ -205,15 +206,15 @@ def create_model(
205206
sagemaker.model.FrameworkModel: Depending on input parameters returns
206207
one of the following:
207208
208-
* :class:`~sagemaker.model.FrameworkModel` - if ``image_name`` was specified
209+
* :class:`~sagemaker.model.FrameworkModel` - if ``image_name`` is specified
209210
on the estimator;
210-
* :class:`~sagemaker.mxnet.MXNetModel` - if ``image_name`` wasn't specified and
211-
MXNet was used as the RL backend;
212-
* :class:`~sagemaker.tensorflow.serving.Model` - if ``image_name`` wasn't specified
213-
and TensorFlow was used as the RL backend.
211+
* :class:`~sagemaker.mxnet.MXNetModel` - if ``image_name`` isn't specified and
212+
MXNet is used as the RL backend;
213+
* :class:`~sagemaker.tensorflow.model.TensorFlowModel` - if ``image_name`` isn't
214+
specified and TensorFlow is used as the RL backend.
214215
215216
Raises:
216-
ValueError: If image_name was not specified and framework enum is not valid.
217+
ValueError: If image_name is not specified and framework enum is not valid.
217218
"""
218219
base_args = dict(
219220
model_data=self.model_data,
@@ -252,9 +253,7 @@ def create_model(
252253
)
253254

254255
if self.framework == RLFramework.TENSORFLOW.value:
255-
from sagemaker.tensorflow.serving import Model as tfsModel
256-
257-
return tfsModel(framework_version=self.framework_version, **base_args)
256+
return TensorFlowModel(framework_version=self.framework_version, **base_args)
258257
if self.framework == RLFramework.MXNET.value:
259258
return MXNetModel(
260259
framework_version=self.framework_version, py_version=PYTHON_VERSION, **extended_args

src/sagemaker/tensorflow/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.tensorflow.estimator import TensorFlow # noqa: F401 (imported but unused)
17+
from sagemaker.tensorflow.model import TensorFlowModel, TensorFlowPredictor # noqa: F401

src/sagemaker/tensorflow/estimator.py

+13-12
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from sagemaker.estimator import Framework
2424
import sagemaker.fw_utils as fw
2525
from sagemaker.tensorflow import defaults
26-
from sagemaker.tensorflow.serving import Model
26+
from sagemaker.tensorflow.model import TensorFlowModel
2727
from sagemaker.transformer import Transformer
2828
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2929

@@ -249,12 +249,13 @@ def create_model(
249249
dependencies=None,
250250
**kwargs
251251
):
252-
"""Create a ``Model`` object that can be used for creating SageMaker model entities,
253-
deploying to a SageMaker endpoint, or starting SageMaker Batch Transform jobs.
252+
"""Create a ``TensorFlowModel`` object that can be used for creating
253+
SageMaker model entities, deploying to a SageMaker endpoint, or
254+
starting SageMaker Batch Transform jobs.
254255
255256
Args:
256-
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also
257-
used during transform jobs. If not specified, the role from the Estimator is used.
257+
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``TensorFlowModel``, which is also used during transform jobs.
258+
If not specified, the role from the Estimator is used.
258259
vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on the
259260
model. Default: use subnets and security groups from this Estimator.
260261
@@ -267,11 +268,12 @@ def create_model(
267268
source code dependencies aside from the entry point file (default: None).
268269
dependencies (list[str]): A list of paths to directories (absolute or relative) with
269270
any additional libraries that will be exported to the container (default: None).
270-
**kwargs: Additional kwargs passed to :class:`~sagemaker.tensorflow.serving.Model`.
271+
**kwargs: Additional kwargs passed to
272+
:class:`~sagemaker.tensorflow.model.TensorFlowModel`.
271273
272274
Returns:
273-
sagemaker.tensorflow.serving.Model: A ``Model`` object.
274-
See :class:`~sagemaker.tensorflow.serving.Model` for full details.
275+
sagemaker.tensorflow.model.TensorFlowModel: A ``TensorFlowModel`` object.
276+
See :class:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
275277
"""
276278
if "image" not in kwargs:
277279
kwargs["image"] = self.image_name
@@ -282,7 +284,7 @@ def create_model(
282284
if "enable_network_isolation" not in kwargs:
283285
kwargs["enable_network_isolation"] = self.enable_network_isolation()
284286

285-
return Model(
287+
return TensorFlowModel(
286288
model_data=self.model_data,
287289
role=role or self.role,
288290
container_log_level=self.container_log_level,
@@ -418,9 +420,8 @@ def transformer(
418420
container in MB.
419421
tags (list[dict]): List of tags for labeling a transform job. If none specified, then
420422
the tags used for the training job are used for the transform job.
421-
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also
422-
used during transform jobs. If not specified, the role from the Estimator will be
423-
used.
423+
role (str): The IAM Role ARN for the ``TensorFlowModel``, which is also used
424+
during transform jobs. If not specified, the role from the Estimator is used.
424425
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
425426
compute instance (default: None).
426427
entry_point (str): Path (absolute or relative) to the local Python source file which

src/sagemaker/tensorflow/serving.py renamed to src/sagemaker/tensorflow/model.py

+19-19
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Placeholder docstring"""
13+
"""Classes for using TensorFlow on Amazon SageMaker for inference."""
1414
from __future__ import absolute_import
1515

1616
import logging
@@ -22,7 +22,7 @@
2222
from sagemaker.tensorflow.defaults import TF_VERSION
2323

2424

25-
class Predictor(sagemaker.RealTimePredictor):
25+
class TensorFlowPredictor(sagemaker.RealTimePredictor):
2626
"""A ``RealTimePredictor`` implementation for inference against TensorFlow
2727
Serving endpoints.
2828
"""
@@ -37,7 +37,7 @@ def __init__(
3737
model_name=None,
3838
model_version=None,
3939
):
40-
"""Initialize a ``TFSPredictor``. See ``sagemaker.RealTimePredictor``
40+
"""Initialize a ``TensorFlowPredictor``. See :class:`~sagemaker.predictor.RealTimePredictor`
4141
for more info about parameters.
4242
4343
Args:
@@ -61,7 +61,7 @@ def __init__(
6161
that should handle the request. If not specified, the latest
6262
version of the model will be used.
6363
"""
64-
super(Predictor, self).__init__(
64+
super(TensorFlowPredictor, self).__init__(
6565
endpoint_name, sagemaker_session, serializer, deserializer, content_type
6666
)
6767

@@ -115,13 +115,13 @@ def predict(self, data, initial_args=None):
115115
else:
116116
args["CustomAttributes"] = self._model_attributes
117117

118-
return super(Predictor, self).predict(data, args)
118+
return super(TensorFlowPredictor, self).predict(data, args)
119119

120120

121-
class Model(sagemaker.model.FrameworkModel):
122-
"""Placeholder docstring"""
121+
class TensorFlowModel(sagemaker.model.FrameworkModel):
122+
"""A ``FrameworkModel`` implementation for inference with TensorFlow Serving."""
123123

124-
FRAMEWORK_NAME = "tensorflow-serving"
124+
__framework_name__ = "tensorflow-serving"
125125
LOG_LEVEL_PARAM_NAME = "SAGEMAKER_TFS_NGINX_LOGLEVEL"
126126
LOG_LEVEL_MAP = {
127127
logging.DEBUG: "debug",
@@ -140,7 +140,7 @@ def __init__(
140140
image=None,
141141
framework_version=TF_VERSION,
142142
container_log_level=None,
143-
predictor_cls=Predictor,
143+
predictor_cls=TensorFlowPredictor,
144144
**kwargs
145145
):
146146
"""Initialize a Model.
@@ -171,15 +171,15 @@ def __init__(
171171
:class:`~sagemaker.model.FrameworkModel` and
172172
:class:`~sagemaker.model.Model`.
173173
"""
174-
super(Model, self).__init__(
174+
super(TensorFlowModel, self).__init__(
175175
model_data=model_data,
176176
role=role,
177177
image=image,
178178
predictor_cls=predictor_cls,
179179
entry_point=entry_point,
180180
**kwargs
181181
)
182-
self._framework_version = framework_version
182+
self.framework_version = framework_version
183183
self._container_log_level = container_log_level
184184

185185
def deploy(
@@ -196,10 +196,10 @@ def deploy(
196196
):
197197

198198
if accelerator_type and not self._eia_supported():
199-
msg = "The TensorFlow version %s doesn't support EIA." % self._framework_version
200-
199+
msg = "The TensorFlow version %s doesn't support EIA." % self.framework_version
201200
raise AttributeError(msg)
202-
return super(Model, self).deploy(
201+
202+
return super(TensorFlowModel, self).deploy(
203203
initial_instance_count=initial_instance_count,
204204
instance_type=instance_type,
205205
accelerator_type=accelerator_type,
@@ -213,7 +213,7 @@ def deploy(
213213

214214
def _eia_supported(self):
215215
"""Return true if TF version is EIA enabled"""
216-
return [int(s) for s in self._framework_version.split(".")][:2] <= self.LATEST_EIA_VERSION
216+
return [int(s) for s in self.framework_version.split(".")][:2] <= self.LATEST_EIA_VERSION
217217

218218
def prepare_container_def(self, instance_type, accelerator_type=None):
219219
"""
@@ -249,12 +249,12 @@ def _get_container_env(self):
249249
if not self._container_log_level:
250250
return self.env
251251

252-
if self._container_log_level not in Model.LOG_LEVEL_MAP:
252+
if self._container_log_level not in self.LOG_LEVEL_MAP:
253253
logging.warning("ignoring invalid container log level: %s", self._container_log_level)
254254
return self.env
255255

256256
env = dict(self.env)
257-
env[Model.LOG_LEVEL_PARAM_NAME] = Model.LOG_LEVEL_MAP[self._container_log_level]
257+
env[self.LOG_LEVEL_PARAM_NAME] = self.LOG_LEVEL_MAP[self._container_log_level]
258258
return env
259259

260260
def _get_image_uri(self, instance_type, accelerator_type=None):
@@ -269,9 +269,9 @@ def _get_image_uri(self, instance_type, accelerator_type=None):
269269
region_name = self.sagemaker_session.boto_region_name
270270
return create_image_uri(
271271
region_name,
272-
Model.FRAMEWORK_NAME,
272+
self.__framework_name__,
273273
instance_type,
274-
self._framework_version,
274+
self.framework_version,
275275
accelerator_type=accelerator_type,
276276
)
277277

tests/integ/test_data_capture_config.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import tests.integ
1919
import tests.integ.timeout
2020
from sagemaker.model_monitor import DataCaptureConfig, NetworkConfig
21-
from sagemaker.tensorflow.serving import Model
21+
from sagemaker.tensorflow.model import TensorFlowModel
2222
from sagemaker.utils import unique_name_from_base
2323
from tests.integ.retry import retries
2424

@@ -49,7 +49,7 @@ def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(
4949
key_prefix="tensorflow-serving/models",
5050
)
5151
with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
52-
model = Model(
52+
model = TensorFlowModel(
5353
model_data=model_data,
5454
role=ROLE,
5555
framework_version=tf_full_version,
@@ -106,7 +106,7 @@ def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
106106
key_prefix="tensorflow-serving/models",
107107
)
108108
with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
109-
model = Model(
109+
model = TensorFlowModel(
110110
model_data=model_data,
111111
role=ROLE,
112112
framework_version=tf_full_version,
@@ -192,7 +192,7 @@ def test_updating_data_capture_on_endpoint_shows_correct_data_capture_status(
192192
key_prefix="tensorflow-serving/models",
193193
)
194194
with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
195-
model = Model(
195+
model = TensorFlowModel(
196196
model_data=model_data,
197197
role=ROLE,
198198
framework_version=tf_full_version,

tests/integ/test_model_monitor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from sagemaker.model_monitor import CronExpressionGenerator
3737
from sagemaker.processing import ProcessingInput
3838
from sagemaker.processing import ProcessingOutput
39-
from sagemaker.tensorflow.serving import Model
39+
from sagemaker.tensorflow.model import TensorFlowModel
4040
from sagemaker.utils import unique_name_from_base
4141

4242
from tests.integ.kms_utils import get_or_create_kms_key
@@ -97,7 +97,7 @@ def predictor(sagemaker_session, tf_full_version):
9797
with tests.integ.timeout.timeout_and_delete_endpoint_by_name(
9898
endpoint_name=endpoint_name, sagemaker_session=sagemaker_session, hours=2
9999
):
100-
model = Model(
100+
model = TensorFlowModel(
101101
model_data=model_data,
102102
role=ROLE,
103103
framework_version=tf_full_version,

tests/integ/test_tfs.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import sagemaker.utils
2424
import tests.integ
2525
import tests.integ.timeout
26-
from sagemaker.tensorflow.serving import Model, Predictor
26+
from sagemaker.tensorflow.model import TensorFlowModel, TensorFlowPredictor
2727

2828

2929
@pytest.fixture(scope="module")
@@ -34,7 +34,7 @@ def tfs_predictor(sagemaker_session, tf_full_version):
3434
key_prefix="tensorflow-serving/models",
3535
)
3636
with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
37-
model = Model(
37+
model = TensorFlowModel(
3838
model_data=model_data,
3939
role="SageMakerRole",
4040
framework_version=tf_full_version,
@@ -62,7 +62,7 @@ def tfs_predictor_with_model_and_entry_point_same_tar(
6262
os.path.join(tests.integ.DATA_DIR, "tfs/tfs-test-model-with-inference"), tmpdir
6363
)
6464

65-
model = Model(
65+
model = TensorFlowModel(
6666
model_data="file://" + model_tar,
6767
role="SageMakerRole",
6868
framework_version=tf_full_version,
@@ -93,7 +93,7 @@ def tfs_predictor_with_model_and_entry_point_and_dependencies(
9393
tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz"
9494
)
9595

96-
model = Model(
96+
model = TensorFlowModel(
9797
entry_point=entry_point,
9898
model_data=model_data,
9999
role="SageMakerRole",
@@ -118,7 +118,7 @@ def tfs_predictor_with_accelerator(sagemaker_session, ei_tf_full_version, cpu_in
118118
key_prefix="tensorflow-serving/models",
119119
)
120120
with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
121-
model = Model(
121+
model = TensorFlowModel(
122122
model_data=model_data,
123123
role="SageMakerRole",
124124
framework_version=ei_tf_full_version,
@@ -235,7 +235,7 @@ def test_predict_csv(tfs_predictor):
235235
input_data = "1.0,2.0,5.0\n1.0,2.0,5.0"
236236
expected_result = {"predictions": [[3.5, 4.0, 5.5], [3.5, 4.0, 5.5]]}
237237

238-
predictor = Predictor(
238+
predictor = TensorFlowPredictor(
239239
tfs_predictor.endpoint,
240240
tfs_predictor.sagemaker_session,
241241
serializer=sagemaker.predictor.csv_serializer,

0 commit comments

Comments (0)