Skip to content

Commit 6f3cf42

Browse files
authored
fix: add logic to use asimov image for TF 1.14 py2 (aws#997)
* fix: add logic to use asimov image for TF 1.14 py2
1 parent fa79418 commit 6f3cf42

File tree

5 files changed

+32
-13
lines changed

5 files changed

+32
-13
lines changed

doc/using_tf.rst

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@ models on SageMaker Hosting.
88

99
For general information about using the SageMaker Python SDK, see :ref:`overview:Using the SageMaker Python SDK`.
1010

11-
.. warning::
12-
The TensorFlow estimator is available only for Python 3, starting by the TensorFlow version 1.14.
13-
1411
.. warning::
1512
We have added a new format of your TensorFlow training script with TensorFlow version 1.11.
1613
This new way gives the user script more flexibility.

src/sagemaker/fw_utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,26 @@ def _using_merged_images(region, framework, py_version, accelerator_type, framew
113113
is_gov_region = region in VALID_ACCOUNTS_BY_REGION
114114
is_py3 = py_version == "py3" or py_version is None
115115
is_merged_versions = _is_merged_versions(framework, framework_version)
116-
return (not is_gov_region) and is_merged_versions and is_py3 and accelerator_type is None
116+
return (
117+
(not is_gov_region)
118+
and is_merged_versions
119+
and (is_py3 or _is_tf_14_or_later(framework, framework_version))
120+
and accelerator_type is None
121+
)
122+
123+
124+
def _is_tf_14_or_later(framework, framework_version):
125+
"""
126+
Args:
127+
framework:
128+
framework_version:
129+
"""
130+
# Asimov team now owns Tensorflow 1.14.0 py2 and py3
131+
asimov_lowest_tf_py2 = [1, 14, 0]
132+
version = [int(s) for s in framework_version.split(".")]
133+
return (
134+
framework == "tensorflow-scriptmode" and version >= asimov_lowest_tf_py2[0 : len(version)]
135+
)
117136

118137

119138
def _registry_id(region, framework, py_version, account, accelerator_type, framework_version):

src/sagemaker/tensorflow/estimator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,8 @@ class TensorFlow(Framework):
199199
"""The latest version of TensorFlow included in the SageMaker pre-built Docker images."""
200200

201201
_LOWEST_SCRIPT_MODE_ONLY_VERSION = [1, 13]
202+
# 1.14.0 now supports py2
203+
# we will need to update this version number if future versions do not support py2 anymore
202204
_LOWEST_PYTHON_2_ONLY_VERSION = [1, 14]
203205

204206
def __init__(
@@ -343,7 +345,7 @@ def _validate_args(
343345

344346
if py_version == "py2" and self._only_python_3_supported():
345347
msg = (
346-
"Python 2 containers are only available until TensorFlow version 1.13.1. "
348+
"Python 2 containers are only available until TensorFlow version 1.14.0. "
347349
"Please use a Python 3 container."
348350
)
349351
raise AttributeError(msg)

tests/unit/test_fw_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,13 @@ def test_create_image_uri_merged_py2():
187187
== "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-scriptmode:1.13.1-gpu-py2"
188188
)
189189

190+
image_uri = fw_utils.create_image_uri(
191+
"us-west-2", "tensorflow-scriptmode", "ml.p3.2xlarge", "1.14", "py2"
192+
)
193+
assert (
194+
image_uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:1.14-gpu-py2"
195+
)
196+
190197
image_uri = fw_utils.create_image_uri("us-west-2", "mxnet", "ml.p3.2xlarge", "1.4.1", "py2")
191198
assert image_uri == "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.4.1-gpu-py2"
192199

tests/unit/test_tf_estimator.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -957,18 +957,12 @@ def test_script_mode_deprecated_args(sagemaker_session):
957957

958958
def test_py2_version_deprecated(sagemaker_session):
959959
with pytest.raises(AttributeError) as e:
960-
_build_tf(sagemaker_session=sagemaker_session, framework_version="1.14", py_version="py2")
960+
_build_tf(sagemaker_session=sagemaker_session, framework_version="1.14.1", py_version="py2")
961961

962-
msg = "Python 2 containers are only available until TensorFlow version 1.13.1. Please use a Python 3 container."
962+
msg = "Python 2 containers are only available until TensorFlow version 1.14.0. Please use a Python 3 container."
963963
assert msg in str(e.value)
964964

965965

966-
def test_py3_is_default_version_after_tf1_14(sagemaker_session):
967-
estimator = _build_tf(sagemaker_session=sagemaker_session, framework_version="1.14")
968-
969-
assert estimator.py_version == "py3"
970-
971-
972966
def test_py3_is_default_version_before_tf1_14(sagemaker_session):
973967
estimator = _build_tf(sagemaker_session=sagemaker_session, framework_version="1.13")
974968

0 commit comments

Comments
 (0)