Skip to content

breaking: remove legacy TensorFlowModel and TensorFlowPredictor classes #1531

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 0 additions & 16 deletions doc/sagemaker.tensorflow.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,6 @@ TensorFlow Estimator
:undoc-members:
:show-inheritance:

TensorFlow Model
----------------

.. autoclass:: sagemaker.tensorflow.model.TensorFlowModel
:members:
:undoc-members:
:show-inheritance:

TensorFlow Predictor
--------------------

.. autoclass:: sagemaker.tensorflow.model.TensorFlowPredictor
:members:
:undoc-members:
:show-inheritance:

TensorFlow Serving Model
------------------------

Expand Down
5 changes: 2 additions & 3 deletions src/sagemaker/cli/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,12 @@ def create_model(self, model_url):
Args:
model_url:
"""
from sagemaker.tensorflow.model import TensorFlowModel
from sagemaker.tensorflow.serving import Model

return TensorFlowModel(
return Model(
model_data=model_url,
role=self.role_name,
entry_point=self.script,
py_version=self.python,
name=self.endpoint_name,
env=self.environment,
)
18 changes: 2 additions & 16 deletions src/sagemaker/tensorflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,7 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Placeholder docstring"""
"""Classes for using TensorFlow and TensorFlow Serving with Amazon SageMaker."""
from __future__ import absolute_import

import sys
import os

# Hack to use our local copy of tensorflow_serving.apis, which contains the protobuf-generated
# classes for tensorflow serving. Currently tensorflow_serving_api can only be pip-installed for
# python 2.
sys.path.append(os.path.dirname(__file__))

from sagemaker.tensorflow.estimator import ( # noqa: E402, F401 # pylint: disable=wrong-import-position
TensorFlow,
)
from sagemaker.tensorflow.model import ( # noqa: E402, F401 # pylint: disable=wrong-import-position
TensorFlowModel,
TensorFlowPredictor,
)
from sagemaker.tensorflow.estimator import TensorFlow # noqa: F401 (imported but unused)
114 changes: 12 additions & 102 deletions src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from sagemaker.estimator import Framework
import sagemaker.fw_utils as fw
from sagemaker.tensorflow import defaults
from sagemaker.tensorflow.model import TensorFlowModel
from sagemaker.tensorflow.serving import Model
from sagemaker.transformer import Transformer
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
Expand Down Expand Up @@ -252,10 +251,8 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na

def create_model(
self,
model_server_workers=None,
role=None,
vpc_config_override=VPC_CONFIG_DEFAULT,
endpoint_type=None,
entry_point=None,
source_dir=None,
dependencies=None,
Expand All @@ -266,43 +263,25 @@ def create_model(

Args:
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also
used during transform jobs. If not specified, the role from the Estimator will be
used.
model_server_workers (int): Optional. The number of worker processes used by the
inference server. If None, server will use one worker per vCPU.
used during transform jobs. If not specified, the role from the Estimator is used.
vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on the
model.
Default: use subnets and security groups from this Estimator.
model. Default: use subnets and security groups from this Estimator.

* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
endpoint_type (str): Optional. Selects the software stack used by the inference server.
If not specified, the model will be configured to use the default
SageMaker model server. If 'tensorflow-serving', the model will be configured to
use the SageMaker Tensorflow Serving container.

entry_point (str): Path (absolute or relative) to the local Python source file which
should be executed as the entry point to training. If not specified and
``endpoint_type`` is 'tensorflow-serving', no entry point is used. If
``endpoint_type`` is also ``None``, then the training entry point is used.
should be executed as the entry point to training (default: None).
source_dir (str): Path (absolute or relative) to a directory with any other serving
source code dependencies aside from the entry point file. If not specified and
``endpoint_type`` is 'tensorflow-serving', no source_dir is used. If
``endpoint_type`` is also ``None``, then the model source directory from training
is used.
source code dependencies aside from the entry point file (default: None).
dependencies (list[str]): A list of paths to directories (absolute or relative) with
any additional libraries that will be exported to the container.
If not specified and ``endpoint_type`` is 'tensorflow-serving', ``dependencies`` is
set to ``None``.
If ``endpoint_type`` is also ``None``, then the dependencies from training are used.
**kwargs: Additional kwargs passed to :class:`~sagemaker.tensorflow.serving.Model`
and :class:`~sagemaker.tensorflow.model.TensorFlowModel` constructors.
any additional libraries that will be exported to the container (default: None).
**kwargs: Additional kwargs passed to :class:`~sagemaker.tensorflow.serving.Model`.

Returns:
sagemaker.tensorflow.model.TensorFlowModel or sagemaker.tensorflow.serving.Model: A
``Model`` object. See :class:`~sagemaker.tensorflow.serving.Model` or
:class:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
sagemaker.tensorflow.serving.Model: A ``Model`` object.
See :class:`~sagemaker.tensorflow.serving.Model` for full details.
"""
role = role or self.role

if "image" not in kwargs:
kwargs["image"] = self.image_name

Expand All @@ -312,41 +291,11 @@ def create_model(
if "enable_network_isolation" not in kwargs:
kwargs["enable_network_isolation"] = self.enable_network_isolation()

if endpoint_type == "tensorflow-serving" or self._script_mode_enabled:
return self._create_tfs_model(
role=role,
vpc_config_override=vpc_config_override,
entry_point=entry_point,
source_dir=source_dir,
dependencies=dependencies,
**kwargs
)

return self._create_default_model(
model_server_workers=model_server_workers,
role=role,
vpc_config_override=vpc_config_override,
entry_point=entry_point,
source_dir=source_dir,
dependencies=dependencies,
**kwargs
)

def _create_tfs_model(
self,
role=None,
vpc_config_override=VPC_CONFIG_DEFAULT,
entry_point=None,
source_dir=None,
dependencies=None,
**kwargs
):
"""Placeholder docstring"""
return Model(
model_data=self.model_data,
role=role,
role=role or self.role,
container_log_level=self.container_log_level,
framework_version=utils.get_short_version(self.framework_version),
framework_version=self.framework_version,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
entry_point=entry_point,
Expand All @@ -355,34 +304,6 @@ def _create_tfs_model(
**kwargs
)

def _create_default_model(
self,
model_server_workers,
role,
vpc_config_override,
entry_point=None,
source_dir=None,
dependencies=None,
**kwargs
):
"""Placeholder docstring"""
return TensorFlowModel(
self.model_data,
role,
entry_point or self.entry_point,
source_dir=source_dir or self._model_source_dir(),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
container_log_level=self.container_log_level,
code_location=self.code_location,
py_version=self.py_version,
framework_version=self.framework_version,
model_server_workers=model_server_workers,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
dependencies=dependencies or self.dependencies,
**kwargs
)

def hyperparameters(self):
"""Return hyperparameters used by your custom TensorFlow code during model training."""
hyperparameters = super(TensorFlow, self).hyperparameters()
Expand Down Expand Up @@ -479,9 +400,7 @@ def transformer(
max_payload=None,
tags=None,
role=None,
model_server_workers=None,
volume_kms_key=None,
endpoint_type=None,
entry_point=None,
vpc_config_override=VPC_CONFIG_DEFAULT,
enable_network_isolation=None,
Expand Down Expand Up @@ -515,15 +434,8 @@ def transformer(
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also
used during transform jobs. If not specified, the role from the Estimator will be
used.
model_server_workers (int): Optional. The number of worker processes used by the
inference server. If None, server will use one worker per vCPU.
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
compute instance (default: None).
endpoint_type (str): Optional. Selects the software stack used by the inference server.
If not specified, the model will be configured to use the default
SageMaker model server.
If 'tensorflow-serving', the model will be configured to
use the SageMaker Tensorflow Serving container.
entry_point (str): Path (absolute or relative) to the local Python source file which
should be executed as the entry point to training. If not specified and
``endpoint_type`` is 'tensorflow-serving', no entry point is used. If
Expand Down Expand Up @@ -575,10 +487,8 @@ def transformer(
enable_network_isolation = self.enable_network_isolation()

model = self.create_model(
model_server_workers=model_server_workers,
role=role,
vpc_config_override=vpc_config_override,
endpoint_type=endpoint_type,
entry_point=entry_point,
enable_network_isolation=enable_network_isolation,
name=model_name,
Expand Down
Loading