Skip to content

Commit 37a806b

Browse files
icywang86ruipengk19
authored andcommitted
feature: add TensorFlow 1.13 support (aws#860)
1 parent c32e7ac commit 37a806b

File tree

13 files changed

+67
-493
lines changed

13 files changed

+67
-493
lines changed

README.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ TensorFlow SageMaker Estimators
173173

174174
By using TensorFlow SageMaker Estimators, you can train and host TensorFlow models on Amazon SageMaker.
175175

176-
Supported versions of TensorFlow: ``1.4.1``, ``1.5.0``, ``1.6.0``, ``1.7.0``, ``1.8.0``, ``1.9.0``, ``1.10.0``, ``1.11.0``, ``1.12.0``.
176+
Supported versions of TensorFlow: ``1.4.1``, ``1.5.0``, ``1.6.0``, ``1.7.0``, ``1.8.0``, ``1.9.0``, ``1.10.0``, ``1.11.0``, ``1.12.0``, ``1.13.1``.
177177

178178
Supported versions of TensorFlow for Elastic Inference: ``1.11.0``, ``1.12.0``.
179179

doc/using_tf.rst

+3-13
Original file line numberDiff line numberDiff line change
@@ -443,20 +443,10 @@ After a TensorFlow estimator has been fit, it saves a TensorFlow SavedModel in
443443
the S3 location defined by ``output_path``. You can call ``deploy`` on a TensorFlow
444444
estimator to create a SageMaker Endpoint.
445445

446-
SageMaker provides two different options for deploying TensorFlow models to a SageMaker
447-
Endpoint:
446+
Your model will be deployed to a TensorFlow Serving-based server. The server provides a super-set of the
447+
`TensorFlow Serving REST API <https://www.tensorflow.org/serving/api_rest>`_.
448448

449-
- The first option uses a Python-based server that allows you to specify your own custom
450-
input and output handling functions in a Python script. This is the default option.
451-
452-
See `Deploying to Python-based Endpoints <https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/tensorflow/deploying_python.rst>`_ to learn how to use this option.
453-
454-
455-
- The second option uses a TensorFlow Serving-based server to provide a super-set of the
456-
`TensorFlow Serving REST API <https://www.tensorflow.org/serving/api_rest>`_. This option
457-
does not require (or allow) a custom python script.
458-
459-
See `Deploying to TensorFlow Serving Endpoints <https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/tensorflow/deploying_tensorflow_serving.rst>`_ to learn how to use this option.
449+
See `Deploying to TensorFlow Serving Endpoints <https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/tensorflow/deploying_tensorflow_serving.rst>`_ to learn how to deploy your model and make inference requests.
460450

461451

462452
SageMaker TensorFlow Docker containers

src/sagemaker/tensorflow/deploying_python.rst

-199
This file was deleted.

src/sagemaker/tensorflow/estimator.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from sagemaker.tensorflow.defaults import TF_VERSION
2727
from sagemaker.tensorflow.model import TensorFlowModel
2828
from sagemaker.tensorflow.serving import Model
29-
from sagemaker.utils import get_config_value
29+
from sagemaker import utils
3030
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
3131

3232
logger = logging.getLogger("sagemaker")
@@ -190,9 +190,11 @@ class TensorFlow(Framework):
190190

191191
__framework_name__ = "tensorflow"
192192

193-
LATEST_VERSION = "1.12"
193+
LATEST_VERSION = "1.13"
194194
"""The latest version of TensorFlow included in the SageMaker pre-built Docker images."""
195195

196+
_LOWEST_SCRIPT_MODE_ONLY_VERSION = [1, 13]
197+
196198
def __init__(
197199
self,
198200
training_steps=None,
@@ -321,6 +323,17 @@ def _validate_args(
321323
)
322324
)
323325

326+
if (not self._script_mode_enabled()) and self._only_script_mode_supported():
327+
logger.warning(
328+
"Legacy mode is deprecated in versions 1.13 and higher. Using script mode instead."
329+
)
330+
self.script_mode = True
331+
332+
def _only_script_mode_supported(self):
333+
return [
334+
int(s) for s in self.framework_version.split(".")
335+
] >= self._LOWEST_SCRIPT_MODE_ONLY_VERSION
336+
324337
def _validate_requirements_file(self, requirements_file):
325338
if not requirements_file:
326339
return
@@ -489,7 +502,7 @@ def _create_tfs_model(self, role=None, vpc_config_override=VPC_CONFIG_DEFAULT):
489502
image=self.image_name,
490503
name=self._current_job_name,
491504
container_log_level=self.container_log_level,
492-
framework_version=self.framework_version,
505+
framework_version=utils.get_short_version(self.framework_version),
493506
sagemaker_session=self.sagemaker_session,
494507
vpc_config=self.get_vpc_config(vpc_config_override),
495508
)
@@ -553,7 +566,7 @@ def hyperparameters(self):
553566
return hyperparameters
554567

555568
def _default_s3_path(self, directory, mpi=False):
556-
local_code = get_config_value("local.local_code", self.sagemaker_session.config)
569+
local_code = utils.get_config_value("local.local_code", self.sagemaker_session.config)
557570
if self.sagemaker_session.local_mode and local_code:
558571
return "/opt/ml/shared/{}".format(directory)
559572
elif mpi:

src/sagemaker/utils.py

+12
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,18 @@ def get_config_value(key_path, config):
123123
return current_section
124124

125125

126+
def get_short_version(framework_version):
127+
"""Return short version in the format of x.x
128+
129+
Args:
130+
framework_version: The version string to be shortened.
131+
132+
Returns:
133+
str: The short version string
134+
"""
135+
return ".".join(framework_version.split(".")[:2])
136+
137+
126138
def to_str(value):
127139
"""Convert the input to a string, unless it is a unicode string in Python 2.
128140

tests/data/tensorflow_mnist/mnist.py

-8
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,11 @@
1414

1515
import argparse
1616
import json
17-
import logging as _logging
1817
import numpy as np
1918
import os
20-
import sys as _sys
2119
import tensorflow as tf
22-
from tensorflow.python.platform import tf_logging
2320

2421
tf.logging.set_verbosity(tf.logging.DEBUG)
25-
_handler = _logging.StreamHandler(_sys.stdout)
26-
tf_logger = tf_logging._get_logger()
27-
tf_logger.handlers = [_handler]
2822

2923

3024
def cnn_model_fn(features, labels, mode):
@@ -179,5 +173,3 @@ def serving_input_fn():
179173

180174
if args.current_host == args.hosts[0]:
181175
mnist_classifier.export_savedmodel("/opt/ml/model", serving_input_fn)
182-
183-
tf_logger.info("====== Training finished =========")

tests/integ/test_local_mode.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,14 @@ def _create_model(output_path):
8585

8686
@pytest.mark.local_mode
8787
@pytest.mark.skipif(PYTHON_VERSION != "py2", reason="TensorFlow image supports only python 2.")
88-
def test_tf_local_mode(tf_full_version, sagemaker_local_session):
88+
def test_tf_local_mode(sagemaker_local_session):
8989
with timeout(minutes=5):
9090
script_path = os.path.join(DATA_DIR, "iris", "iris-dnn-classifier.py")
9191

9292
estimator = TensorFlow(
9393
entry_point=script_path,
9494
role="SageMakerRole",
95-
framework_version=tf_full_version,
95+
framework_version="1.12",
9696
training_steps=1,
9797
evaluation_steps=1,
9898
hyperparameters={"input_tensor_name": "inputs"},
@@ -135,6 +135,7 @@ def test_tf_distributed_local_mode(sagemaker_local_session):
135135
estimator = TensorFlow(
136136
entry_point=script_path,
137137
role="SageMakerRole",
138+
framework_version="1.12",
138139
training_steps=1,
139140
evaluation_steps=1,
140141
hyperparameters={"input_tensor_name": "inputs"},
@@ -176,6 +177,7 @@ def test_tf_local_data(sagemaker_local_session):
176177
estimator = TensorFlow(
177178
entry_point=script_path,
178179
role="SageMakerRole",
180+
framework_version="1.12",
179181
training_steps=1,
180182
evaluation_steps=1,
181183
hyperparameters={"input_tensor_name": "inputs"},
@@ -216,6 +218,7 @@ def test_tf_local_data_local_script():
216218
estimator = TensorFlow(
217219
entry_point=script_path,
218220
role="SageMakerRole",
221+
framework_version="1.12",
219222
training_steps=1,
220223
evaluation_steps=1,
221224
hyperparameters={"input_tensor_name": "inputs"},

0 commit comments

Comments
 (0)