17
17
18
18
from sagemaker .cli .compatibility .v2 .modifiers .modifier import Modifier
19
19
20
+ FRAMEWORK_ARG = "framework_version"
21
+ PY_ARG = "py_version"
22
+
20
23
FRAMEWORK_DEFAULTS = {
21
24
"Chainer" : "4.1.0" ,
22
25
"MXNet" : "1.2.0" ,
25
28
"TensorFlow" : "1.11.0" ,
26
29
}
27
30
28
- FRAMEWORKS = list (FRAMEWORK_DEFAULTS .keys ())
31
+ FRAMEWORK_CLASSES = list (FRAMEWORK_DEFAULTS .keys ())
32
+ MODEL_CLASSES = ["{}Model" .format (fw ) for fw in FRAMEWORK_CLASSES ]
33
+
29
34
# TODO: check for sagemaker.tensorflow.serving.Model
30
- FRAMEWORK_CLASSES = FRAMEWORKS + ["{}Model" .format (fw ) for fw in FRAMEWORKS ]
31
- FRAMEWORK_MODULES = [fw .lower () for fw in FRAMEWORKS ]
35
+ FRAMEWORK_MODULES = [fw .lower () for fw in FRAMEWORK_CLASSES ]
32
36
FRAMEWORK_SUBMODULES = ("model" , "estimator" )
33
37
34
38
@@ -39,7 +43,8 @@ class FrameworkVersionEnforcer(Modifier):
39
43
40
44
def node_should_be_modified (self , node ):
41
45
"""Checks if the ast.Call node instantiates a framework estimator or model,
42
- but doesn't specify the ``framework_version`` parameter.
46
+ but doesn't specify the ``framework_version`` and ``py_version`` parameter,
47
+ as appropriate.
43
48
44
49
This looks for the following formats:
45
50
@@ -56,49 +61,12 @@ def node_should_be_modified(self, node):
56
61
bool: If the ``ast.Call`` is instantiating a framework class that
57
62
should specify ``framework_version``, but doesn't.
58
63
"""
59
- if self . _is_framework_constructor (node ):
60
- return not self . _fw_version_in_keywords (node )
64
+ if _is_named_constructor (node , FRAMEWORK_CLASSES ):
65
+ return _version_args_needed (node , "image_name" )
61
66
62
- return False
67
+ if _is_named_constructor (node , MODEL_CLASSES ):
68
+ return _version_args_needed (node , "image" )
63
69
64
- def _is_framework_constructor (self , node ):
65
- """Checks if the ``ast.Call`` node represents a call of the form
66
- <Framework> or sagemaker.<framework>.<Framework>.
67
- """
68
- # Check for <Framework> call
69
- if isinstance (node .func , ast .Name ):
70
- return node .func .id in FRAMEWORK_CLASSES
71
-
72
- # Check for something.that.ends.with.<framework>.<Framework> call
73
- if not (isinstance (node .func , ast .Attribute ) and node .func .attr in FRAMEWORK_CLASSES ):
74
- return False
75
-
76
- # Check for sagemaker.<frameworks>.<estimator/model>.<Framework> call
77
- if (
78
- isinstance (node .func .value , ast .Attribute )
79
- and node .func .value .attr in FRAMEWORK_SUBMODULES
80
- ):
81
- return self ._is_in_framework_module (node .func .value )
82
-
83
- # Check for sagemaker.<framework>.<Framework> call
84
- return self ._is_in_framework_module (node .func )
85
-
86
- def _is_in_framework_module (self , node ):
87
- """Checks if the node is an ``ast.Attribute`` that represents a
88
- ``sagemaker.<framework>`` module.
89
- """
90
- return (
91
- isinstance (node .value , ast .Attribute )
92
- and node .value .attr in FRAMEWORK_MODULES
93
- and isinstance (node .value .value , ast .Name )
94
- and node .value .value .id == "sagemaker"
95
- )
96
-
97
- def _fw_version_in_keywords (self , node ):
98
- """Checks if the ``ast.Call`` node's keywords contain ``framework_version``."""
99
- for kw in node .keywords :
100
- if kw .arg == "framework_version" and kw .value :
101
- return True
102
70
return False
103
71
104
72
def modify_node (self , node ):
@@ -112,30 +80,146 @@ def modify_node(self, node):
112
80
- SKLearn: "0.20.0"
113
81
- TensorFlow: "1.11.0"
114
82
83
+ The ``py_version`` value is determined by the framework, framework_version, and if it is a
84
+ model, whether the model accepts a py_version
85
+
115
86
Args:
116
87
node (ast.Call): a node that represents the constructor of a framework class.
117
88
"""
118
- framework = self ._framework_name_from_node (node )
119
- node .keywords .append (
120
- ast .keyword (arg = "framework_version" , value = ast .Str (s = FRAMEWORK_DEFAULTS [framework ]))
121
- )
89
+ framework , is_model = _framework_from_node (node )
122
90
123
- def _framework_name_from_node (self , node ):
124
- """Retrieves the framework name based on the function call.
91
+ # if framework_version is not supplied, get default and append keyword
92
+ framework_version = _arg_value (node , FRAMEWORK_ARG )
93
+ if framework_version is None :
94
+ framework_version = FRAMEWORK_DEFAULTS [framework ]
95
+ node .keywords .append (ast .keyword (arg = FRAMEWORK_ARG , value = ast .Str (s = framework_version )))
125
96
126
- Args:
127
- node (ast.Call): a node that represents the constructor of a framework class.
128
- This can represent either <Framework> or sagemaker.<framework>.<Framework>.
97
+ # if py_version is not supplied, get a conditional default, and if not None, append keyword
98
+ py_version = _arg_value (node , PY_ARG )
99
+ if py_version is None :
100
+ py_version = _py_version_defaults (framework , framework_version , is_model )
101
+ if py_version :
102
+ node .keywords .append (ast .keyword (arg = PY_ARG , value = ast .Str (s = py_version )))
129
103
130
- Returns:
131
- str: the (capitalized) framework name.
132
- """
133
- if isinstance (node .func , ast .Name ):
134
- framework = node .func .id
135
- elif isinstance (node .func , ast .Attribute ):
136
- framework = node .func .attr
137
104
138
- if framework .endswith ("Model" ):
139
- framework = framework [: framework .find ("Model" )]
105
+ def _py_version_defaults (framework , framework_version , is_model = False ):
106
+ """Gets the py_version required for the framework_version and if it's a model
107
+
108
+ Args:
109
+ framework (str): name of the framework
110
+ framework_version (str): version of the framework
111
+ is_model (bool): whether it is a constructor for a model or not
112
+
113
+ Returns:
114
+ str: the default py version, as appropriate. None if no default py_version
115
+ """
116
+ if framework in ("Chainer" , "PyTorch" ):
117
+ return "py3"
118
+ if framework == "SKLearn" and not is_model :
119
+ return "py3"
120
+ if framework == "MXNet" :
121
+ return "py2"
122
+ if framework == "TensorFlow" and not is_model :
123
+ return _tf_py_version_default (framework_version )
124
+ return None
125
+
126
+
127
+ def _tf_py_version_default (framework_version ):
128
+ """Gets the py_version default based on framework_version for TensorFlow."""
129
+ if not framework_version :
130
+ return "py2"
131
+ version = [int (s ) for s in framework_version .split ("." )]
132
+ if version < [1 , 12 ]:
133
+ return "py2"
134
+ if version < [2 , 2 ]:
135
+ return "py3"
136
+ return "py37"
137
+
138
+
139
+ def _framework_from_node (node ):
140
+ """Retrieves the framework class name based on the function call, and if it was a model
141
+
142
+ Args:
143
+ node (ast.Call): a node that represents the constructor of a framework class.
144
+ This can represent either <Framework> or sagemaker.<framework>.<Framework>.
145
+
146
+ Returns:
147
+ str, bool: the (capitalized) framework class name, and if it is a model class
148
+ """
149
+ if isinstance (node .func , ast .Name ):
150
+ framework = node .func .id
151
+ elif isinstance (node .func , ast .Attribute ):
152
+ framework = node .func .attr
153
+ else :
154
+ framework = ""
155
+
156
+ is_model = framework .endswith ("Model" )
157
+ if is_model :
158
+ framework = framework [: framework .find ("Model" )]
159
+
160
+ return framework , is_model
161
+
162
+
163
+ def _is_named_constructor (node , names ):
164
+ """Checks if the ``ast.Call`` node represents a call to particular named constructors.
165
+
166
+ Forms that qualify are either <Framework> or sagemaker.<framework>.<Framework>
167
+ where <Framework> belongs to the list of names passed in.
168
+ """
169
+ # Check for call from particular names of constructors
170
+ if isinstance (node .func , ast .Name ):
171
+ return node .func .id in names
172
+
173
+ # Check for something.that.ends.with.<framework>.<Framework> call for Framework in names
174
+ if not (isinstance (node .func , ast .Attribute ) and node .func .attr in names ):
175
+ return False
176
+
177
+ # Check for sagemaker.<frameworks>.<estimator/model>.<Framework> call
178
+ if isinstance (node .func .value , ast .Attribute ) and node .func .value .attr in FRAMEWORK_SUBMODULES :
179
+ return _is_in_framework_module (node .func .value )
180
+
181
+ # Check for sagemaker.<framework>.<Framework> call
182
+ return _is_in_framework_module (node .func )
183
+
184
+
185
+ def _is_in_framework_module (node ):
186
+ """Checks if node is an ``ast.Attribute`` representing a ``sagemaker.<framework>`` module."""
187
+ return (
188
+ isinstance (node .value , ast .Attribute )
189
+ and node .value .attr in FRAMEWORK_MODULES
190
+ and isinstance (node .value .value , ast .Name )
191
+ and node .value .value .id == "sagemaker"
192
+ )
193
+
194
+
195
+ def _version_args_needed (node , image_arg ):
196
+ """Determines if image_arg or version_arg was supplied
197
+
198
+ Applies similar logic as ``validate_version_or_image_args``
199
+ """
200
+ # if image_arg is present, no need to supply version arguments
201
+ image_name = _arg_value (node , image_arg )
202
+ if image_name :
203
+ return False
204
+
205
+ # if framework_version is None, need args
206
+ framework_version = _arg_value (node , FRAMEWORK_ARG )
207
+ if framework_version is None :
208
+ return True
209
+
210
+ # check if we expect py_version and we don't get it -- framework and model dependent
211
+ framework , is_model = _framework_from_node (node )
212
+ expecting_py_version = _py_version_defaults (framework , framework_version , is_model )
213
+ if expecting_py_version :
214
+ py_version = _arg_value (node , PY_ARG )
215
+ return py_version is None
216
+
217
+ return False
218
+
140
219
141
- return framework
220
+ def _arg_value (node , arg ):
221
+ """Gets the value associated with the arg keyword, if present"""
222
+ for kw in node .keywords :
223
+ if kw .arg == arg and kw .value :
224
+ return kw .value .s
225
+ return None
0 commit comments