Skip to content

Commit a3bce6e

Browse files
committed
reformatting
1 parent c2b6f94 commit a3bce6e

File tree

4 files changed

+22
-10
lines changed

4 files changed

+22
-10
lines changed

src/sagemaker/tensorflow/estimator.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ class TensorFlow(Framework):
190190

191191
__framework_name__ = "tensorflow"
192192

193-
LATEST_VERSION = '1.13'
193+
LATEST_VERSION = "1.13"
194194
"""The latest version of TensorFlow included in the SageMaker pre-built Docker images."""
195195

196196
_LOWEST_SCRIPT_MODE_ONLY_VERSION = [1, 13]
@@ -324,11 +324,15 @@ def _validate_args(
324324
)
325325

326326
if (not self._script_mode_enabled()) and self._only_script_mode_supported():
327-
logger.warning('Legacy mode is deprecated in versions 1.13 and higher. Using script mode instead.')
327+
logger.warning(
328+
"Legacy mode is deprecated in versions 1.13 and higher. Using script mode instead."
329+
)
328330
self.script_mode = True
329331

330332
def _only_script_mode_supported(self):
331-
return [int(s) for s in self.framework_version.split('.')] >= self._LOWEST_SCRIPT_MODE_ONLY_VERSION
333+
return [
334+
int(s) for s in self.framework_version.split(".")
335+
] >= self._LOWEST_SCRIPT_MODE_ONLY_VERSION
332336

333337
def _validate_requirements_file(self, requirements_file):
334338
if not requirements_file:

src/sagemaker/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def get_short_version(framework_version):
132132
Returns:
133133
str: The short version string
134134
"""
135-
return '.'.join(framework_version.split('.')[:2])
135+
return ".".join(framework_version.split(".")[:2])
136136

137137

138138
def to_str(value):

tests/unit/test_tf_estimator.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -871,12 +871,20 @@ def test_script_mode_deprecated_args(sagemaker_session):
871871

872872

873873
def test_legacy_mode_deprecated(sagemaker_session):
874-
tf = _build_tf(sagemaker_session=sagemaker_session, framework_version='1.13.1',
875-
py_version='py2', script_mode=False)
874+
tf = _build_tf(
875+
sagemaker_session=sagemaker_session,
876+
framework_version="1.13.1",
877+
py_version="py2",
878+
script_mode=False,
879+
)
876880
assert tf._script_mode_enabled() is True
877881

878-
tf = _build_tf(sagemaker_session=sagemaker_session, framework_version='1.12',
879-
py_version='py2', script_mode=False)
882+
tf = _build_tf(
883+
sagemaker_session=sagemaker_session,
884+
framework_version="1.12",
885+
py_version="py2",
886+
script_mode=False,
887+
)
880888
assert tf._script_mode_enabled() is False
881889

882890

tests/unit/test_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ def test_get_config_value():
4848

4949

5050
def test_get_short_version():
51-
assert sagemaker.utils.get_short_version('1.13.1') == '1.13'
52-
assert sagemaker.utils.get_short_version('1.13') == '1.13'
51+
assert sagemaker.utils.get_short_version("1.13.1") == "1.13"
52+
assert sagemaker.utils.get_short_version("1.13") == "1.13"
5353

5454

5555
def test_deferred_error():

0 commit comments

Comments
 (0)