Skip to content

Commit 0da45c3

Browse files
committed
feature: add TensorFlow 1.13 support
1 parent 0346abd commit 0da45c3

File tree

7 files changed

+31
-13
lines changed

7 files changed

+31
-13
lines changed

README.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ TensorFlow SageMaker Estimators
172172

173173
By using TensorFlow SageMaker ``Estimators``, you can train and host TensorFlow models on Amazon SageMaker.
174174

175-
Supported versions of TensorFlow: ``1.4.1``, ``1.5.0``, ``1.6.0``, ``1.7.0``, ``1.8.0``, ``1.9.0``, ``1.10.0``, ``1.11.0``, ``1.12.0``.
175+
Supported versions of TensorFlow: ``1.4.1``, ``1.5.0``, ``1.6.0``, ``1.7.0``, ``1.8.0``, ``1.9.0``, ``1.10.0``, ``1.11.0``, ``1.12.0``, ``1.13.1``.
176176

177177
Supported versions of TensorFlow for Elastic Inference: ``1.11.0``, ``1.12.0``.
178178

src/sagemaker/tensorflow/estimator.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from sagemaker.tensorflow.defaults import TF_VERSION
2727
from sagemaker.tensorflow.model import TensorFlowModel
2828
from sagemaker.tensorflow.serving import Model
29-
from sagemaker.utils import get_config_value
29+
from sagemaker.utils import get_config_value, get_short_version
3030
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
3131

3232
logger = logging.getLogger('sagemaker')
@@ -171,9 +171,11 @@ class TensorFlow(Framework):
171171

172172
__framework_name__ = 'tensorflow'
173173

174-
LATEST_VERSION = '1.12'
174+
LATEST_VERSION = '1.13'
175175
"""The latest version of TensorFlow included in the SageMaker pre-built Docker images."""
176176

177+
_LOWEST_SCRIPT_MODE_ONLY_VERSION = [1, 13]
178+
177179
def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None, py_version='py2',
178180
framework_version=None, model_dir=None, requirements_file='', image_name=None,
179181
script_mode=False, distributions=None, **kwargs):
@@ -276,6 +278,13 @@ def _validate_args(self, py_version, script_mode, framework_version, training_st
276278
.format(', '.join(_FRAMEWORK_MODE_ARGS), ', '.join(found_args))
277279
)
278280

281+
if (not self._script_mode_enabled()) and \
282+
[int(s) for s in self.framework_version.split('.')] >= self._LOWEST_SCRIPT_MODE_ONLY_VERSION:
283+
raise AttributeError(
284+
'Legacy mode is deprecated in versions 1.13 and higher.'
285+
'Please set the script_mode argument to True to use Script Mode'
286+
)
287+
279288
def _validate_requirements_file(self, requirements_file):
280289
if not requirements_file:
281290
return
@@ -427,7 +436,7 @@ def _create_tfs_model(self, role=None, vpc_config_override=VPC_CONFIG_DEFAULT):
427436
image=self.image_name,
428437
name=self._current_job_name,
429438
container_log_level=self.container_log_level,
430-
framework_version=self.framework_version,
439+
framework_version=get_short_version(self.framework_version),
431440
sagemaker_session=self.sagemaker_session,
432441
vpc_config=self.get_vpc_config(vpc_config_override))
433442

src/sagemaker/utils.py

+4
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,10 @@ def get_config_value(key_path, config):
122122
return current_section
123123

124124

125+
def get_short_version(framework_version):
126+
return '.'.join(framework_version.split('.')[:2])
127+
128+
125129
def to_str(value):
126130
"""Convert the input to a string, unless it is a unicode string in Python 2.
127131

tests/conftest.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ def sklearn_version(request):
125125

126126
@pytest.fixture(scope='module', params=['1.4', '1.4.1', '1.5', '1.5.0', '1.6', '1.6.0',
127127
'1.7', '1.7.0', '1.8', '1.8.0', '1.9', '1.9.0',
128-
'1.10', '1.10.0', '1.11', '1.11.0', '1.12', '1.12.0'])
128+
'1.10', '1.10.0', '1.11', '1.11.0', '1.12', '1.12.0',
129+
'1.13', '1.13.1'])
129130
def tf_version(request):
130131
return request.param
131132

tests/data/tensorflow_mnist/mnist.py

-8
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,11 @@
1414

1515
import argparse
1616
import json
17-
import logging as _logging
1817
import numpy as np
1918
import os
20-
import sys as _sys
2119
import tensorflow as tf
22-
from tensorflow.python.platform import tf_logging
2320

2421
tf.logging.set_verbosity(tf.logging.DEBUG)
25-
_handler = _logging.StreamHandler(_sys.stdout)
26-
tf_logger = tf_logging._get_logger()
27-
tf_logger.handlers = [_handler]
2822

2923
def cnn_model_fn(features, labels, mode):
3024
"""Model function for CNN."""
@@ -188,5 +182,3 @@ def serving_input_fn():
188182

189183
if args.current_host == args.hosts[0]:
190184
mnist_classifier.export_savedmodel('/opt/ml/model', serving_input_fn)
191-
192-
tf_logger.info('====== Training finished =========')

tests/unit/test_tf_estimator.py

+7
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,13 @@ def test_script_mode_deprecated_args(sagemaker_session):
738738
assert _deprecated_args_msg('training_steps, evaluation_steps, requirements_file, checkpoint_path') in str(e.value)
739739

740740

741+
def test_legacy_mode_deprecation_error(sagemaker_session):
742+
with pytest.raises(AttributeError) as e:
743+
_build_tf(sagemaker_session=sagemaker_session, framework_version='1.13.1',
744+
py_version='py2', script_mode=False)
745+
assert 'Legacy mode is deprecated' in str(e.value)
746+
747+
741748
def test_script_mode_enabled(sagemaker_session):
742749
tf = _build_tf(sagemaker_session=sagemaker_session, py_version='py3')
743750
assert tf._script_mode_enabled() is True

tests/unit/test_utils.py

+5
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ def test_get_config_value():
5252
assert sagemaker.utils.get_config_value('other.key', None) is None
5353

5454

55+
def test_get_short_version():
56+
assert sagemaker.utils.get_short_version('1.13.1') == '1.13'
57+
assert sagemaker.utils.get_short_version('1.13') == '1.13'
58+
59+
5560
def test_deferred_error():
5661
de = sagemaker.utils.DeferredError(ImportError("pretend the import failed"))
5762
with pytest.raises(ImportError) as _: # noqa: F841

0 commit comments

Comments
 (0)