32
32
33
33
34
34
class FrameworkVersionEnforcer (Modifier ):
35
+ """A class to ensure that ``framework_version`` is defined when
36
+ instantiating a framework estimator or model.
37
+ """
38
+
35
39
def node_should_be_modified (self , node ):
36
- """Check if the ast.Call node instantiates a framework estimator or model,
37
- but doesn't specify the framework_version parameter.
40
+ """Checks if the ast.Call node instantiates a framework estimator or model,
41
+ but doesn't specify the `` framework_version`` parameter.
38
42
39
43
This looks for the following formats:
40
44
@@ -57,34 +61,37 @@ def node_should_be_modified(self, node):
57
61
return False
58
62
59
63
def _is_framework_constructor (self , node ):
60
- """Check if the ``ast.Call`` node represents a call of the form
64
+ """Checks if the ``ast.Call`` node represents a call of the form
61
65
<Framework> or sagemaker.<framework>.<Framework>.
62
66
"""
67
+ # Check for <Framework> call
63
68
if isinstance (node .func , ast .Name ):
64
69
if node .func .id in FRAMEWORK_CLASSES :
65
70
return True
66
71
67
- if (
68
- isinstance (node .func , ast .Attribute )
69
- and node .func .attr in FRAMEWORK_CLASSES
70
- and isinstance (node .func .value , ast .Attribute )
72
+ # Check for sagemaker.<framework>.<Framework> call
73
+ ends_with_framework_constructor = (
74
+ isinstance (node .func , ast .Attribute ) and node .func .attr in FRAMEWORK_CLASSES
75
+ )
76
+
77
+ is_in_framework_module = (
78
+ isinstance (node .func .value , ast .Attribute )
71
79
and node .func .value .attr in FRAMEWORK_MODULES
72
80
and isinstance (node .func .value .value , ast .Name )
73
81
and node .func .value .value .id == "sagemaker"
74
- ):
75
- return True
82
+ )
76
83
77
- return False
84
+ return ends_with_framework_constructor and is_in_framework_module
78
85
79
86
def _fw_version_in_keywords (self , node ):
80
- """Check if the ``ast.Call`` node's keywords contain ``framework_version``."""
87
+ """Checks if the ``ast.Call`` node's keywords contain ``framework_version``."""
81
88
for kw in node .keywords :
82
89
if kw .arg == "framework_version" and kw .value :
83
90
return True
84
91
return False
85
92
86
93
def modify_node (self , node ):
87
- """Modify the ``ast.Call`` node's keywords to include ``framework_version``.
94
+ """Modifies the ``ast.Call`` node's keywords to include ``framework_version``.
88
95
89
96
The ``framework_version`` value is determined by the framework:
90
97
@@ -103,7 +110,7 @@ def modify_node(self, node):
103
110
)
104
111
105
112
def _framework_name_from_node (self , node ):
106
- """Retrieve the framework name based on the function call.
113
+ """Retrieves the framework name based on the function call.
107
114
108
115
Args:
109
116
node (ast.Call): a node that represents the constructor of a framework class.
0 commit comments