Skip to content

Commit 07cc35c

Browse files
committed
feature: support for TensorFlow 1.14
1 parent 40fb338 commit 07cc35c

File tree

5 files changed

+54
-4
lines changed

5 files changed

+54
-4
lines changed

README.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ TensorFlow SageMaker Estimators
189189

190190
By using TensorFlow SageMaker Estimators, you can train and host TensorFlow models on Amazon SageMaker.
191191

192-
Supported versions of TensorFlow: ``1.4.1``, ``1.5.0``, ``1.6.0``, ``1.7.0``, ``1.8.0``, ``1.9.0``, ``1.10.0``, ``1.11.0``, ``1.12.0``, ``1.13.1``.
192+
Supported versions of TensorFlow: ``1.4.1``, ``1.5.0``, ``1.6.0``, ``1.7.0``, ``1.8.0``, ``1.9.0``, ``1.10.0``, ``1.11.0``, ``1.12.0``, ``1.13.1``, ``1.14``.
193193

194194
Supported versions of TensorFlow for Elastic Inference: ``1.11.0``, ``1.12.0``, ``1.13.0``
195195

doc/using_tf.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ models on SageMaker Hosting.
88

99
For general information about using the SageMaker Python SDK, see :ref:`overview:Using the SageMaker Python SDK`.
1010

11+
.. warning::
12+
The TensorFlow estimator is available only for Python 3, starting by the TensorFlow version 1.14.
13+
1114
.. warning::
1215
We have added a new format of your TensorFlow training script with TensorFlow version 1.11.
1316
This new way gives the user script more flexibility.

src/sagemaker/tensorflow/estimator.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,17 +195,18 @@ class TensorFlow(Framework):
195195

196196
__framework_name__ = "tensorflow"
197197

198-
LATEST_VERSION = "1.13"
198+
LATEST_VERSION = "1.14"
199199
"""The latest version of TensorFlow included in the SageMaker pre-built Docker images."""
200200

201201
_LOWEST_SCRIPT_MODE_ONLY_VERSION = [1, 13]
202+
_LOWEST_PYTHON_2_ONLY_VERSION = [1, 14]
202203

203204
def __init__(
204205
self,
205206
training_steps=None,
206207
evaluation_steps=None,
207208
checkpoint_path=None,
208-
py_version="py2",
209+
py_version=None,
209210
framework_version=None,
210211
model_dir=None,
211212
requirements_file="",
@@ -279,6 +280,9 @@ def __init__(
279280
logger.warning(fw.empty_framework_version_warning(TF_VERSION, self.LATEST_VERSION))
280281
self.framework_version = framework_version or TF_VERSION
281282

283+
if not py_version:
284+
py_version = "py3" if self._only_python_3_supported() else "py2"
285+
282286
super(TensorFlow, self).__init__(image_name=image_name, **kwargs)
283287
self.checkpoint_path = checkpoint_path
284288

@@ -337,6 +341,13 @@ def _validate_args(
337341
)
338342
)
339343

344+
if py_version == "py2" and self._only_python_3_supported():
345+
msg = (
346+
"Python 2 containers are only available until TensorFlow version 1.13.1. "
347+
"Please use a Python 3 container."
348+
)
349+
raise AttributeError(msg)
350+
340351
if (not self._script_mode_enabled()) and self._only_script_mode_supported():
341352
logger.warning(
342353
"Legacy mode is deprecated in versions 1.13 and higher. Using script mode instead."
@@ -349,6 +360,12 @@ def _only_script_mode_supported(self):
349360
int(s) for s in self.framework_version.split(".")
350361
] >= self._LOWEST_SCRIPT_MODE_ONLY_VERSION
351362

363+
def _only_python_3_supported(self):
364+
"""Placeholder docstring"""
365+
return [
366+
int(s) for s in self.framework_version.split(".")
367+
] >= self._LOWEST_PYTHON_2_ONLY_VERSION
368+
352369
def _validate_requirements_file(self, requirements_file):
353370
"""Placeholder docstring"""
354371
if not requirements_file:

tests/unit/test_fw_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,13 @@ def test_create_image_uri_gov_cloud():
137137

138138

139139
def test_create_image_uri_merged():
140+
image_uri = fw_utils.create_image_uri(
141+
"us-west-2", "tensorflow-scriptmode", "ml.p3.2xlarge", "1.14", "py3"
142+
)
143+
assert (
144+
image_uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:1.14-gpu-py3"
145+
)
146+
140147
image_uri = fw_utils.create_image_uri(
141148
"us-west-2", "tensorflow-scriptmode", "ml.p3.2xlarge", "1.13.1", "py3"
142149
)

tests/unit/test_tf_estimator.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from sagemaker.tensorflow import defaults, TensorFlow, TensorFlowModel, TensorFlowPredictor
2626
import sagemaker.tensorflow.estimator as tfe
2727

28-
2928
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
3029
SCRIPT_FILE = "dummy_script.py"
3130
SCRIPT_PATH = os.path.join(DATA_DIR, SCRIPT_FILE)
@@ -956,6 +955,30 @@ def test_script_mode_deprecated_args(sagemaker_session):
956955
) in str(e.value)
957956

958957

958+
def test_py2_version_deprecated(sagemaker_session):
959+
with pytest.raises(AttributeError) as e:
960+
_build_tf(sagemaker_session=sagemaker_session, framework_version="1.14", py_version="py2")
961+
962+
msg = "Python 2 containers are only available until TensorFlow version 1.13.1. Please use a Python 3 container."
963+
assert msg in str(e.value)
964+
965+
966+
def test_py3_is_default_version_after_tf1_14(sagemaker_session):
967+
estimator = _build_tf(sagemaker_session=sagemaker_session, framework_version="1.14")
968+
969+
assert estimator.py_version == "py3"
970+
971+
972+
def test_py3_is_default_version_before_tf1_14(sagemaker_session):
973+
estimator = _build_tf(sagemaker_session=sagemaker_session, framework_version="1.13")
974+
975+
assert estimator.py_version == "py2"
976+
977+
estimator = _build_tf(sagemaker_session=sagemaker_session, framework_version="1.10")
978+
979+
assert estimator.py_version == "py2"
980+
981+
959982
def test_legacy_mode_deprecated(sagemaker_session):
960983
tf = _build_tf(
961984
sagemaker_session=sagemaker_session,

0 commit comments

Comments
 (0)