Skip to content

Commit 87c21bd

Browse files
authored
Merge branch 'zwei' into migration-script-tf-legacy-parameters
2 parents de6ebfe + c65c80f commit 87c21bd

22 files changed

+137
-3837
lines changed

doc/sagemaker.tensorflow.rst

-16
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,6 @@ TensorFlow Estimator
1010
:undoc-members:
1111
:show-inheritance:
1212

13-
TensorFlow Model
14-
----------------
15-
16-
.. autoclass:: sagemaker.tensorflow.model.TensorFlowModel
17-
:members:
18-
:undoc-members:
19-
:show-inheritance:
20-
21-
TensorFlow Predictor
22-
--------------------
23-
24-
.. autoclass:: sagemaker.tensorflow.model.TensorFlowPredictor
25-
:members:
26-
:undoc-members:
27-
:show-inheritance:
28-
2913
TensorFlow Serving Model
3014
------------------------
3115

src/sagemaker/cli/tensorflow.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,12 @@ def create_model(self, model_url):
6868
Args:
6969
model_url:
7070
"""
71-
from sagemaker.tensorflow.model import TensorFlowModel
71+
from sagemaker.tensorflow.serving import Model
7272

73-
return TensorFlowModel(
73+
return Model(
7474
model_data=model_url,
7575
role=self.role_name,
7676
entry_point=self.script,
77-
py_version=self.python,
7877
name=self.endpoint_name,
7978
env=self.environment,
8079
)

src/sagemaker/tensorflow/__init__.py

+2-16
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Placeholder docstring"""
13+
"""Classes for using TensorFlow and TensorFlow Serving with Amazon SageMaker."""
1414
from __future__ import absolute_import
1515

16-
import sys
17-
import os
18-
19-
# Hack to use our local copy of tensorflow_serving.apis, which contains the protobuf-generated
20-
# classes for tensorflow serving. Currently tensorflow_serving_api can only be pip-installed for
21-
# python 2.
22-
sys.path.append(os.path.dirname(__file__))
23-
24-
from sagemaker.tensorflow.estimator import ( # noqa: E402, F401 # pylint: disable=wrong-import-position
25-
TensorFlow,
26-
)
27-
from sagemaker.tensorflow.model import ( # noqa: E402, F401 # pylint: disable=wrong-import-position
28-
TensorFlowModel,
29-
TensorFlowPredictor,
30-
)
16+
from sagemaker.tensorflow.estimator import TensorFlow # noqa: F401 (imported but unused)

src/sagemaker/tensorflow/estimator.py

+12-102
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from sagemaker.estimator import Framework
2424
import sagemaker.fw_utils as fw
2525
from sagemaker.tensorflow import defaults
26-
from sagemaker.tensorflow.model import TensorFlowModel
2726
from sagemaker.tensorflow.serving import Model
2827
from sagemaker.transformer import Transformer
2928
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
@@ -252,10 +251,8 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
252251

253252
def create_model(
254253
self,
255-
model_server_workers=None,
256254
role=None,
257255
vpc_config_override=VPC_CONFIG_DEFAULT,
258-
endpoint_type=None,
259256
entry_point=None,
260257
source_dir=None,
261258
dependencies=None,
@@ -266,43 +263,25 @@ def create_model(
266263
267264
Args:
268265
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also
269-
used during transform jobs. If not specified, the role from the Estimator will be
270-
used.
271-
model_server_workers (int): Optional. The number of worker processes used by the
272-
inference server. If None, server will use one worker per vCPU.
266+
used during transform jobs. If not specified, the role from the Estimator is used.
273267
vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on the
274-
model.
275-
Default: use subnets and security groups from this Estimator.
268+
model. Default: use subnets and security groups from this Estimator.
269+
276270
* 'Subnets' (list[str]): List of subnet ids.
277271
* 'SecurityGroupIds' (list[str]): List of security group ids.
278-
endpoint_type (str): Optional. Selects the software stack used by the inference server.
279-
If not specified, the model will be configured to use the default
280-
SageMaker model server. If 'tensorflow-serving', the model will be configured to
281-
use the SageMaker Tensorflow Serving container.
272+
282273
entry_point (str): Path (absolute or relative) to the local Python source file which
283-
should be executed as the entry point to training. If not specified and
284-
``endpoint_type`` is 'tensorflow-serving', no entry point is used. If
285-
``endpoint_type`` is also ``None``, then the training entry point is used.
274+
should be executed as the entry point to training (default: None).
286275
source_dir (str): Path (absolute or relative) to a directory with any other serving
287-
source code dependencies aside from the entry point file. If not specified and
288-
``endpoint_type`` is 'tensorflow-serving', no source_dir is used. If
289-
``endpoint_type`` is also ``None``, then the model source directory from training
290-
is used.
276+
source code dependencies aside from the entry point file (default: None).
291277
dependencies (list[str]): A list of paths to directories (absolute or relative) with
292-
any additional libraries that will be exported to the container.
293-
If not specified and ``endpoint_type`` is 'tensorflow-serving', ``dependencies`` is
294-
set to ``None``.
295-
If ``endpoint_type`` is also ``None``, then the dependencies from training are used.
296-
**kwargs: Additional kwargs passed to :class:`~sagemaker.tensorflow.serving.Model`
297-
and :class:`~sagemaker.tensorflow.model.TensorFlowModel` constructors.
278+
any additional libraries that will be exported to the container (default: None).
279+
**kwargs: Additional kwargs passed to :class:`~sagemaker.tensorflow.serving.Model`.
298280
299281
Returns:
300-
sagemaker.tensorflow.model.TensorFlowModel or sagemaker.tensorflow.serving.Model: A
301-
``Model`` object. See :class:`~sagemaker.tensorflow.serving.Model` or
302-
:class:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
282+
sagemaker.tensorflow.serving.Model: A ``Model`` object.
283+
See :class:`~sagemaker.tensorflow.serving.Model` for full details.
303284
"""
304-
role = role or self.role
305-
306285
if "image" not in kwargs:
307286
kwargs["image"] = self.image_name
308287

@@ -312,41 +291,11 @@ def create_model(
312291
if "enable_network_isolation" not in kwargs:
313292
kwargs["enable_network_isolation"] = self.enable_network_isolation()
314293

315-
if endpoint_type == "tensorflow-serving" or self._script_mode_enabled:
316-
return self._create_tfs_model(
317-
role=role,
318-
vpc_config_override=vpc_config_override,
319-
entry_point=entry_point,
320-
source_dir=source_dir,
321-
dependencies=dependencies,
322-
**kwargs
323-
)
324-
325-
return self._create_default_model(
326-
model_server_workers=model_server_workers,
327-
role=role,
328-
vpc_config_override=vpc_config_override,
329-
entry_point=entry_point,
330-
source_dir=source_dir,
331-
dependencies=dependencies,
332-
**kwargs
333-
)
334-
335-
def _create_tfs_model(
336-
self,
337-
role=None,
338-
vpc_config_override=VPC_CONFIG_DEFAULT,
339-
entry_point=None,
340-
source_dir=None,
341-
dependencies=None,
342-
**kwargs
343-
):
344-
"""Placeholder docstring"""
345294
return Model(
346295
model_data=self.model_data,
347-
role=role,
296+
role=role or self.role,
348297
container_log_level=self.container_log_level,
349-
framework_version=utils.get_short_version(self.framework_version),
298+
framework_version=self.framework_version,
350299
sagemaker_session=self.sagemaker_session,
351300
vpc_config=self.get_vpc_config(vpc_config_override),
352301
entry_point=entry_point,
@@ -355,34 +304,6 @@ def _create_tfs_model(
355304
**kwargs
356305
)
357306

358-
def _create_default_model(
359-
self,
360-
model_server_workers,
361-
role,
362-
vpc_config_override,
363-
entry_point=None,
364-
source_dir=None,
365-
dependencies=None,
366-
**kwargs
367-
):
368-
"""Placeholder docstring"""
369-
return TensorFlowModel(
370-
self.model_data,
371-
role,
372-
entry_point or self.entry_point,
373-
source_dir=source_dir or self._model_source_dir(),
374-
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
375-
container_log_level=self.container_log_level,
376-
code_location=self.code_location,
377-
py_version=self.py_version,
378-
framework_version=self.framework_version,
379-
model_server_workers=model_server_workers,
380-
sagemaker_session=self.sagemaker_session,
381-
vpc_config=self.get_vpc_config(vpc_config_override),
382-
dependencies=dependencies or self.dependencies,
383-
**kwargs
384-
)
385-
386307
def hyperparameters(self):
387308
"""Return hyperparameters used by your custom TensorFlow code during model training."""
388309
hyperparameters = super(TensorFlow, self).hyperparameters()
@@ -479,9 +400,7 @@ def transformer(
479400
max_payload=None,
480401
tags=None,
481402
role=None,
482-
model_server_workers=None,
483403
volume_kms_key=None,
484-
endpoint_type=None,
485404
entry_point=None,
486405
vpc_config_override=VPC_CONFIG_DEFAULT,
487406
enable_network_isolation=None,
@@ -515,15 +434,8 @@ def transformer(
515434
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also
516435
used during transform jobs. If not specified, the role from the Estimator will be
517436
used.
518-
model_server_workers (int): Optional. The number of worker processes used by the
519-
inference server. If None, server will use one worker per vCPU.
520437
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
521438
compute instance (default: None).
522-
endpoint_type (str): Optional. Selects the software stack used by the inference server.
523-
If not specified, the model will be configured to use the default
524-
SageMaker model server.
525-
If 'tensorflow-serving', the model will be configured to
526-
use the SageMaker Tensorflow Serving container.
527439
entry_point (str): Path (absolute or relative) to the local Python source file which
528440
should be executed as the entry point to training. If not specified and
529441
``endpoint_type`` is 'tensorflow-serving', no entry point is used. If
@@ -575,10 +487,8 @@ def transformer(
575487
enable_network_isolation = self.enable_network_isolation()
576488

577489
model = self.create_model(
578-
model_server_workers=model_server_workers,
579490
role=role,
580491
vpc_config_override=vpc_config_override,
581-
endpoint_type=endpoint_type,
582492
entry_point=entry_point,
583493
enable_network_isolation=enable_network_isolation,
584494
name=model_name,

0 commit comments

Comments (0)