Skip to content

Commit ba48521

Browse files
committed
breaking: remove legacy TensorFlowModel and TensorFlowPredictor classes
This change also removes the associated serialization/deserialization code used by TensorFlowPredictor and the locally copied TFS APIs.
1 parent 5b078f7 commit ba48521

20 files changed: +22 −3,944 lines

src/sagemaker/cli/tensorflow.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,12 @@ def create_model(self, model_url):
6868
Args:
6969
model_url:
7070
"""
71-
from sagemaker.tensorflow.model import TensorFlowModel
71+
from sagemaker.tensorflow.serving import Model
7272

73-
return TensorFlowModel(
73+
return Model(
7474
model_data=model_url,
7575
role=self.role_name,
7676
entry_point=self.script,
77-
py_version=self.python,
7877
name=self.endpoint_name,
7978
env=self.environment,
8079
)

src/sagemaker/tensorflow/__init__.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Placeholder docstring"""
13+
"""Classes for using TensorFlow and TensorFlow Serving with Amazon SageMaker."""
1414
from __future__ import absolute_import
1515

16-
import sys
17-
import os
18-
19-
# Hack to use our local copy of tensorflow_serving.apis, which contains the protobuf-generated
20-
# classes for tensorflow serving. Currently tensorflow_serving_api can only be pip-installed for
21-
# python 2.
22-
sys.path.append(os.path.dirname(__file__))
23-
24-
from sagemaker.tensorflow.estimator import ( # noqa: E402, F401 # pylint: disable=wrong-import-position
25-
TensorFlow,
26-
)
27-
from sagemaker.tensorflow.model import ( # noqa: E402, F401 # pylint: disable=wrong-import-position
28-
TensorFlowModel,
29-
TensorFlowPredictor,
30-
)
16+
from sagemaker.tensorflow.estimator import TensorFlow # noqa: F401 (imported but unused)

src/sagemaker/tensorflow/estimator.py

Lines changed: 12 additions & 226 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,7 @@
2323
from sagemaker.estimator import Framework
2424
import sagemaker.fw_utils as fw
2525
from sagemaker.tensorflow import defaults
26-
from sagemaker.tensorflow.model import TensorFlowModel
2726
from sagemaker.tensorflow.serving import Model
28-
from sagemaker.transformer import Transformer
2927
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
3028

3129
logger = logging.getLogger("sagemaker")
@@ -252,10 +250,8 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
252250

253251
def create_model(
254252
self,
255-
model_server_workers=None,
256253
role=None,
257254
vpc_config_override=VPC_CONFIG_DEFAULT,
258-
endpoint_type=None,
259255
entry_point=None,
260256
source_dir=None,
261257
dependencies=None,
@@ -266,43 +262,25 @@ def create_model(
266262
267263
Args:
268264
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also
269-
used during transform jobs. If not specified, the role from the Estimator will be
270-
used.
271-
model_server_workers (int): Optional. The number of worker processes used by the
272-
inference server. If None, server will use one worker per vCPU.
265+
used during transform jobs. If not specified, the role from the Estimator is used.
273266
vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on the
274-
model.
275-
Default: use subnets and security groups from this Estimator.
267+
model. Default: use subnets and security groups from this Estimator.
268+
276269
* 'Subnets' (list[str]): List of subnet ids.
277270
* 'SecurityGroupIds' (list[str]): List of security group ids.
278-
endpoint_type (str): Optional. Selects the software stack used by the inference server.
279-
If not specified, the model will be configured to use the default
280-
SageMaker model server. If 'tensorflow-serving', the model will be configured to
281-
use the SageMaker Tensorflow Serving container.
271+
282272
entry_point (str): Path (absolute or relative) to the local Python source file which
283-
should be executed as the entry point to training. If not specified and
284-
``endpoint_type`` is 'tensorflow-serving', no entry point is used. If
285-
``endpoint_type`` is also ``None``, then the training entry point is used.
273+
should be executed as the entry point to training (default: None).
286274
source_dir (str): Path (absolute or relative) to a directory with any other serving
287-
source code dependencies aside from the entry point file. If not specified and
288-
``endpoint_type`` is 'tensorflow-serving', no source_dir is used. If
289-
``endpoint_type`` is also ``None``, then the model source directory from training
290-
is used.
275+
source code dependencies aside from the entry point file (default: None).
291276
dependencies (list[str]): A list of paths to directories (absolute or relative) with
292-
any additional libraries that will be exported to the container.
293-
If not specified and ``endpoint_type`` is 'tensorflow-serving', ``dependencies`` is
294-
set to ``None``.
295-
If ``endpoint_type`` is also ``None``, then the dependencies from training are used.
296-
**kwargs: Additional kwargs passed to :class:`~sagemaker.tensorflow.serving.Model`
297-
and :class:`~sagemaker.tensorflow.model.TensorFlowModel` constructors.
277+
any additional libraries that will be exported to the container (default: None).
278+
**kwargs: Additional kwargs passed to :class:`~sagemaker.tensorflow.serving.Model`.
298279
299280
Returns:
300-
sagemaker.tensorflow.model.TensorFlowModel or sagemaker.tensorflow.serving.Model: A
301-
``Model`` object. See :class:`~sagemaker.tensorflow.serving.Model` or
302-
:class:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
281+
sagemaker.tensorflow.serving.Model: A ``Model`` object.
282+
See :class:`~sagemaker.tensorflow.serving.Model` for full details.
303283
"""
304-
role = role or self.role
305-
306284
if "image" not in kwargs:
307285
kwargs["image"] = self.image_name
308286

@@ -312,41 +290,11 @@ def create_model(
312290
if "enable_network_isolation" not in kwargs:
313291
kwargs["enable_network_isolation"] = self.enable_network_isolation()
314292

315-
if endpoint_type == "tensorflow-serving" or self._script_mode_enabled:
316-
return self._create_tfs_model(
317-
role=role,
318-
vpc_config_override=vpc_config_override,
319-
entry_point=entry_point,
320-
source_dir=source_dir,
321-
dependencies=dependencies,
322-
**kwargs
323-
)
324-
325-
return self._create_default_model(
326-
model_server_workers=model_server_workers,
327-
role=role,
328-
vpc_config_override=vpc_config_override,
329-
entry_point=entry_point,
330-
source_dir=source_dir,
331-
dependencies=dependencies,
332-
**kwargs
333-
)
334-
335-
def _create_tfs_model(
336-
self,
337-
role=None,
338-
vpc_config_override=VPC_CONFIG_DEFAULT,
339-
entry_point=None,
340-
source_dir=None,
341-
dependencies=None,
342-
**kwargs
343-
):
344-
"""Placeholder docstring"""
345293
return Model(
346294
model_data=self.model_data,
347-
role=role,
295+
role=role or self.role,
348296
container_log_level=self.container_log_level,
349-
framework_version=utils.get_short_version(self.framework_version),
297+
framework_version=self.framework_version,
350298
sagemaker_session=self.sagemaker_session,
351299
vpc_config=self.get_vpc_config(vpc_config_override),
352300
entry_point=entry_point,
@@ -355,34 +303,6 @@ def _create_tfs_model(
355303
**kwargs
356304
)
357305

358-
def _create_default_model(
359-
self,
360-
model_server_workers,
361-
role,
362-
vpc_config_override,
363-
entry_point=None,
364-
source_dir=None,
365-
dependencies=None,
366-
**kwargs
367-
):
368-
"""Placeholder docstring"""
369-
return TensorFlowModel(
370-
self.model_data,
371-
role,
372-
entry_point or self.entry_point,
373-
source_dir=source_dir or self._model_source_dir(),
374-
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
375-
container_log_level=self.container_log_level,
376-
code_location=self.code_location,
377-
py_version=self.py_version,
378-
framework_version=self.framework_version,
379-
model_server_workers=model_server_workers,
380-
sagemaker_session=self.sagemaker_session,
381-
vpc_config=self.get_vpc_config(vpc_config_override),
382-
dependencies=dependencies or self.dependencies,
383-
**kwargs
384-
)
385-
386306
def hyperparameters(self):
387307
"""Return hyperparameters used by your custom TensorFlow code during model training."""
388308
hyperparameters = super(TensorFlow, self).hyperparameters()
@@ -464,137 +384,3 @@ def train_image(self):
464384
)
465385

466386
return super(TensorFlow, self).train_image()
467-
468-
def transformer(
469-
self,
470-
instance_count,
471-
instance_type,
472-
strategy=None,
473-
assemble_with=None,
474-
output_path=None,
475-
output_kms_key=None,
476-
accept=None,
477-
env=None,
478-
max_concurrent_transforms=None,
479-
max_payload=None,
480-
tags=None,
481-
role=None,
482-
model_server_workers=None,
483-
volume_kms_key=None,
484-
endpoint_type=None,
485-
entry_point=None,
486-
vpc_config_override=VPC_CONFIG_DEFAULT,
487-
enable_network_isolation=None,
488-
model_name=None,
489-
):
490-
"""Return a ``Transformer`` that uses a SageMaker Model based on the training job. It
491-
reuses the SageMaker Session and base job name used by the Estimator.
492-
493-
Args:
494-
instance_count (int): Number of EC2 instances to use.
495-
instance_type (str): Type of EC2 instance to use, for example, 'ml.c4.xlarge'.
496-
strategy (str): The strategy used to decide how to batch records in a single request
497-
(default: None). Valid values: 'MultiRecord' and 'SingleRecord'.
498-
assemble_with (str): How the output is assembled (default: None). Valid values: 'Line'
499-
or 'None'.
500-
output_path (str): S3 location for saving the transform result. If not specified,
501-
results are stored to a default bucket.
502-
output_kms_key (str): Optional. KMS key ID for encrypting the transform output
503-
(default: None).
504-
accept (str): The accept header passed by the client to
505-
the inference endpoint. If it is supported by the endpoint,
506-
it will be the format of the batch transform output.
507-
env (dict): Environment variables to be set for use during the transform job
508-
(default: None).
509-
max_concurrent_transforms (int): The maximum number of HTTP requests to be made to
510-
each individual transform container at one time.
511-
max_payload (int): Maximum size of the payload in a single HTTP request to the
512-
container in MB.
513-
tags (list[dict]): List of tags for labeling a transform job. If none specified, then
514-
the tags used for the training job are used for the transform job.
515-
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also
516-
used during transform jobs. If not specified, the role from the Estimator will be
517-
used.
518-
model_server_workers (int): Optional. The number of worker processes used by the
519-
inference server. If None, server will use one worker per vCPU.
520-
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
521-
compute instance (default: None).
522-
endpoint_type (str): Optional. Selects the software stack used by the inference server.
523-
If not specified, the model will be configured to use the default
524-
SageMaker model server.
525-
If 'tensorflow-serving', the model will be configured to
526-
use the SageMaker Tensorflow Serving container.
527-
entry_point (str): Path (absolute or relative) to the local Python source file which
528-
should be executed as the entry point to training. If not specified and
529-
``endpoint_type`` is 'tensorflow-serving', no entry point is used. If
530-
``endpoint_type`` is also ``None``, then the training entry point is used.
531-
vpc_config_override (dict[str, list[str]]): Optional override for
532-
the VpcConfig set on the model.
533-
Default: use subnets and security groups from this Estimator.
534-
535-
* 'Subnets' (list[str]): List of subnet ids.
536-
* 'SecurityGroupIds' (list[str]): List of security group ids.
537-
538-
enable_network_isolation (bool): Specifies whether container will
539-
run in network isolation mode. Network isolation mode restricts
540-
the container access to outside networks (such as the internet).
541-
The container does not make any inbound or outbound network
542-
calls. If True, a channel named "code" will be created for any
543-
user entry script for inference. Also known as Internet-free mode.
544-
If not specified, this setting is taken from the estimator's
545-
current configuration.
546-
model_name (str): Name to use for creating an Amazon SageMaker
547-
model. If not specified, the name of the training job is used.
548-
"""
549-
role = role or self.role
550-
551-
if self.latest_training_job is None:
552-
logging.warning(
553-
"No finished training job found associated with this estimator. Please make sure "
554-
"this estimator is only used for building workflow config"
555-
)
556-
return Transformer(
557-
model_name or self._current_job_name,
558-
instance_count,
559-
instance_type,
560-
strategy=strategy,
561-
assemble_with=assemble_with,
562-
output_path=output_path,
563-
output_kms_key=output_kms_key,
564-
accept=accept,
565-
max_concurrent_transforms=max_concurrent_transforms,
566-
max_payload=max_payload,
567-
env=env or {},
568-
tags=tags,
569-
base_transform_job_name=self.base_job_name,
570-
volume_kms_key=volume_kms_key,
571-
sagemaker_session=self.sagemaker_session,
572-
)
573-
574-
if enable_network_isolation is None:
575-
enable_network_isolation = self.enable_network_isolation()
576-
577-
model = self.create_model(
578-
model_server_workers=model_server_workers,
579-
role=role,
580-
vpc_config_override=vpc_config_override,
581-
endpoint_type=endpoint_type,
582-
entry_point=entry_point,
583-
enable_network_isolation=enable_network_isolation,
584-
name=model_name,
585-
)
586-
587-
return model.transformer(
588-
instance_count,
589-
instance_type,
590-
strategy=strategy,
591-
assemble_with=assemble_with,
592-
output_path=output_path,
593-
output_kms_key=output_kms_key,
594-
accept=accept,
595-
env=env,
596-
max_concurrent_transforms=max_concurrent_transforms,
597-
max_payload=max_payload,
598-
tags=tags,
599-
volume_kms_key=volume_kms_key,
600-
)

0 commit comments

Comments (0)