104
104
"pytorch-serving-eia" : [1 , 3 , 1 ],
105
105
}
106
106
107
+ INFERENTIA_VERSION_RANGES = {
108
+ "neo-mxnet" : [[1 , 5 , 1 ], [1 , 5 , 1 ]],
109
+ "neo-tensorflow" : [[1 , 15 , 0 ], [1 , 15 , 0 ]],
110
+ }
111
+
112
+ INFERENTIA_SUPPORTED_REGIONS = ["us-east-1" , "us-west-2" ]
113
+
107
114
DEBUGGER_UNSUPPORTED_REGIONS = ["us-gov-west-1" , "us-iso-east-1" ]
108
115
109
116
@@ -124,6 +131,23 @@ def is_version_equal_or_higher(lowest_version, framework_version):
124
131
return version_list >= lowest_version [0 : len (version_list )]
125
132
126
133
134
+ def is_version_equal_or_lower (highest_version , framework_version ):
135
+ """Determine whether the ``framework_version`` is equal to or lower than
136
+ ``highest_version``
137
+
138
+ Args:
139
+ highest_version (List[int]): highest version represented in an integer
140
+ list
141
+ framework_version (str): framework version string
142
+
143
+ Returns:
144
+ bool: Whether or not ``framework_version`` is equal to or lower than
145
+ ``highest_version``
146
+ """
147
+ version_list = [int (s ) for s in framework_version .split ("." )]
148
+ return version_list <= highest_version [0 : len (version_list )]
149
+
150
+
127
151
def _is_dlc_version (framework , framework_version , py_version ):
128
152
"""Return if the framework's version uses the corresponding DLC image.
129
153
@@ -144,6 +168,23 @@ def _is_dlc_version(framework, framework_version, py_version):
144
168
return False
145
169
146
170
171
+ def _is_inferentia_supported (framework , framework_version ):
172
+ """Return if Inferentia supports the framework and its version.
173
+
174
+ Args:
175
+ framework (str): The framework name, e.g. "tensorflow"
176
+ framework_version (str): The framework version
177
+
178
+ Returns:
179
+ bool: Whether or not Inferentia supports the framework and its version.
180
+ """
181
+ lowest_version_list = INFERENTIA_VERSION_RANGES .get (framework )[0 ]
182
+ highest_version_list = INFERENTIA_VERSION_RANGES .get (framework )[1 ]
183
+ return is_version_equal_or_higher (
184
+ lowest_version_list , framework_version
185
+ ) and is_version_equal_or_lower (highest_version_list , framework_version )
186
+
187
+
147
188
def _registry_id (region , framework , py_version , account , framework_version ):
148
189
"""Return the Amazon ECR registry number (or AWS account ID) for
149
190
the given framework, framework version, Python version, and region.
@@ -240,11 +281,34 @@ def create_image_uri(
240
281
# 'cpu' or 'gpu'.
241
282
if family in optimized_families :
242
283
device_type = family
284
+ elif family .startswith ("inf" ):
285
+ device_type = "inf"
243
286
elif family [0 ] in ["g" , "p" ]:
244
287
device_type = "gpu"
245
288
else :
246
289
device_type = "cpu"
247
290
291
+ if device_type == "inf" :
292
+ if region not in INFERENTIA_SUPPORTED_REGIONS :
293
+ raise ValueError (
294
+ "Inferentia is not supported in region {}. Supported regions are {}" .format (
295
+ region , ", " .join (INFERENTIA_SUPPORTED_REGIONS )
296
+ )
297
+ )
298
+ if framework not in INFERENTIA_VERSION_RANGES :
299
+ raise ValueError (
300
+ "Inferentia does not support {}. Currently it supports "
301
+ "MXNet and TensorFlow with more frameworks coming soon." .format (
302
+ framework .split ("-" )[- 1 ]
303
+ )
304
+ )
305
+ if not _is_inferentia_supported (framework , framework_version ):
306
+ raise ValueError (
307
+ "Inferentia is not supported with {} version {}." .format (
308
+ framework .split ("-" )[- 1 ], framework_version
309
+ )
310
+ )
311
+
248
312
use_dlc_image = _is_dlc_version (framework , framework_version , py_version )
249
313
250
314
if not py_version or (use_dlc_image and framework == "tensorflow-serving-eia" ):
0 commit comments