Skip to content

Commit 9d0e4ff

Browse files
committed
fix unit tests
1 parent 01b39ae commit 9d0e4ff

File tree

5 files changed

+16
-8
lines changed

5 files changed

+16
-8
lines changed

src/sagemaker/cli/compatibility/v2/ast_transformer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
modifiers.framework_version.FrameworkVersionEnforcer(),
2222
modifiers.tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader(),
2323
modifiers.tf_legacy_mode.TensorBoardParameterRemover(),
24+
modifiers.deprecated_params.TensorFlowScriptModeParameterRemover(),
2425
]
2526

2627

src/sagemaker/cli/compatibility/v2/modifiers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.cli.compatibility.v2.modifiers import ( # noqa: F401 (imported but unused)
17+
deprecated_params,
1718
framework_version,
1819
tf_legacy_mode,
1920
)

src/sagemaker/cli/compatibility/v2/modifiers/deprecated_params.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
1919

2020

21-
class TensorFlowScriptModeParamRemover(Modifier):
21+
class TensorFlowScriptModeParameterRemover(Modifier):
2222
"""A class to remove ``script_mode`` from TensorFlow estimators (because it's the only mode)."""
2323

2424
def node_should_be_modified(self, node):

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tf_legacy_mode.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import pasta
1818
import pytest
19-
from mock import patch
19+
from mock import MagicMock, patch
2020

2121
from sagemaker.cli.compatibility.v2.modifiers import tf_legacy_mode
2222
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call
@@ -118,7 +118,9 @@ def test_modify_node_set_image_name_from_args(create_image_uri, boto_session):
118118
assert expected_string == pasta.dump(node)
119119

120120

121-
def test_modify_node_set_hyperparameters():
121+
@patch("boto3.Session", MagicMock())
122+
@patch("sagemaker.fw_utils.create_image_uri", return_value=IMAGE_URI)
123+
def test_modify_node_set_hyperparameters(create_image_uri):
122124
tf_constructor = """TensorFlow(
123125
checkpoint_path='s3://foo/bar',
124126
training_steps=100,
@@ -140,7 +142,9 @@ def test_modify_node_set_hyperparameters():
140142
assert expected_hyperparameters == _hyperparameters_from_node(node)
141143

142144

143-
def test_modify_node_preserve_other_hyperparameters():
145+
@patch("boto3.Session", MagicMock())
146+
@patch("sagemaker.fw_utils.create_image_uri", return_value=IMAGE_URI)
147+
def test_modify_node_preserve_other_hyperparameters(create_image_uri):
144148
tf_constructor = """sagemaker.tensorflow.TensorFlow(
145149
training_steps=100,
146150
evaluation_steps=10,
@@ -164,7 +168,9 @@ def test_modify_node_preserve_other_hyperparameters():
164168
assert expected_hyperparameters == _hyperparameters_from_node(node)
165169

166170

167-
def test_modify_node_prefer_param_over_hyperparameter():
171+
@patch("boto3.Session", MagicMock())
172+
@patch("sagemaker.fw_utils.create_image_uri", return_value=IMAGE_URI)
173+
def test_modify_node_prefer_param_over_hyperparameter(create_image_uri):
168174
tf_constructor = """sagemaker.tensorflow.TensorFlow(
169175
training_steps=100,
170176
requirements_file='source/requirements.txt',

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tf_script_mode.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,15 @@ def test_node_should_be_modified_tf_constructor_script_mode():
2424
"sagemaker.tensorflow.TensorFlow(script_mode=True)",
2525
)
2626

27-
modifier = deprecated_params.TensorFlowScriptModeParamRemover()
27+
modifier = deprecated_params.TensorFlowScriptModeParameterRemover()
2828

2929
for constructor in tf_script_mode_constructors:
3030
node = ast_call(constructor)
3131
assert modifier.node_should_be_modified(node) is True
3232

3333

3434
def test_node_should_be_modified_not_tf_script_mode():
35-
modifier = deprecated_params.TensorFlowScriptModeParamRemover()
35+
modifier = deprecated_params.TensorFlowScriptModeParameterRemover()
3636

3737
for call in ("TensorFlow()", "random()"):
3838
node = ast_call(call)
@@ -41,7 +41,7 @@ def test_node_should_be_modified_not_tf_script_mode():
4141

4242
def test_modify_node():
4343
node = ast_call("TensorFlow(script_mode=True)")
44-
modifier = deprecated_params.TensorFlowScriptModeParamRemover()
44+
modifier = deprecated_params.TensorFlowScriptModeParameterRemover()
4545
modifier.modify_node(node)
4646

4747
assert "TensorFlow()" == pasta.dump(node)

0 commit comments

Comments
 (0)