Skip to content

Commit ab405bc

Browse files
committed
address PR comment
1 parent 6da2454 commit ab405bc

File tree

4 files changed

+32
-15
lines changed

4 files changed

+32
-15
lines changed

src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,14 @@ def modify_node(self, node):
9898
framework, is_model = _framework_from_node(node)
9999

100100
# if framework_version is not supplied, get default and append keyword
101-
framework_version = parsing.arg_value(node, FRAMEWORK_ARG)
102-
if framework_version is None:
101+
if matching.has_arg(node, FRAMEWORK_ARG):
102+
framework_version = parsing.arg_value(node, FRAMEWORK_ARG)
103+
else:
103104
framework_version = FRAMEWORK_DEFAULTS[framework]
104105
node.keywords.append(ast.keyword(arg=FRAMEWORK_ARG, value=ast.Str(s=framework_version)))
105106

106107
# if py_version is not supplied, get a conditional default, and if not None, append keyword
107-
py_version = parsing.arg_value(node, PY_ARG)
108-
if py_version is None:
108+
if not matching.has_arg(node, PY_ARG):
109109
py_version = _py_version_defaults(framework, framework_version, is_model)
110110
if py_version:
111111
node.keywords.append(ast.keyword(arg=PY_ARG, value=ast.Str(s=py_version)))
@@ -175,13 +175,13 @@ def _version_args_needed(node, image_arg):
175175
Applies similar logic as ``validate_version_or_image_args``
176176
"""
177177
# if image_arg is present, no need to supply version arguments
178-
image_name = parsing.arg_value(node, image_arg)
179-
if image_name:
178+
if matching.has_arg(node, image_arg):
180179
return False
181180

182181
# if framework_version is None, need args
183-
framework_version = parsing.arg_value(node, FRAMEWORK_ARG)
184-
if framework_version is None:
182+
if matching.has_arg(node, FRAMEWORK_ARG):
183+
framework_version = parsing.arg_value(node, FRAMEWORK_ARG)
184+
else:
185185
return True
186186

187187
# check if we expect py_version and we don't get it -- framework and model dependent

src/sagemaker/cli/compatibility/v2/modifiers/matching.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,7 @@ def has_arg(node, arg):
116116
Returns:
117117
bool: if the node has the given argument.
118118
"""
119-
return parsing.arg_value(node, arg) is not None
119+
try:
120+
return parsing.arg_value(node, arg) is not None
121+
except KeyError:
122+
return False

src/sagemaker/cli/compatibility/v2/modifiers/parsing.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Functions for parsing AST nodes."""
1414
from __future__ import absolute_import
1515

16+
import pasta
17+
1618

1719
def arg_from_keywords(node, arg):
1820
"""Retrieves a keyword argument from the node's keywords.
@@ -41,10 +43,13 @@ def arg_value(node, arg):
4143
arg (str): the name of the argument.
4244
4345
Returns:
44-
obj: the keyword argument's value if it is present. Otherwise, this returns ``None``.
46+
obj: the keyword argument's value.
47+
48+
Raises:
49+
KeyError: if the node's keywords do not contain the argument.
4550
"""
4651
keyword = arg_from_keywords(node, arg)
47-
if keyword and keyword.value:
48-
return getattr(keyword.value, keyword.value._fields[0], None)
52+
if keyword is None:
53+
raise KeyError("arg '{}' not found in call: {}".format(arg, pasta.dump(node)))
4954

50-
return None
55+
return getattr(keyword.value, keyword.value._fields[0], None) if keyword.value else None

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_parsing.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import pytest
16+
1517
from sagemaker.cli.compatibility.v2.modifiers import parsing
1618
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call
1719

@@ -45,7 +47,14 @@ def test_arg_value():
4547
call = ast_call("MXNet(enable_network_isolation=True)")
4648
assert parsing.arg_value(call, "enable_network_isolation") is True
4749

50+
call = ast_call("MXNet(source_dir=None)")
51+
assert parsing.arg_value(call, "source_dir") is None
52+
4853

4954
def test_arg_value_absent_keyword():
50-
call = ast_call("MXNet(entry_point='run')")
51-
assert parsing.arg_value(call, "framework_version") is None
55+
code = "MXNet(entry_point='run')"
56+
57+
with pytest.raises(KeyError) as e:
58+
parsing.arg_value(ast_call(code), "framework_version")
59+
60+
assert "arg 'framework_version' not found in call: {}".format(code) in str(e.value)

0 commit comments

Comments
 (0)