63
63
"tensorflow-scriptmode" : "tensorflow-training" ,
64
64
"mxnet" : "mxnet-training" ,
65
65
"tensorflow-serving" : "tensorflow-inference" ,
66
- "mxnet-serving" : "mxnet-inference" ,
66
+ "tensorflow-serving-eia" : "tensorflow-inference-eia" ,
67
+ "mxnet-serving-eia" : "mxnet-inference-eia" ,
67
68
}
68
69
69
70
MERGED_FRAMEWORKS_LOWEST_VERSIONS = {
70
71
"tensorflow-scriptmode" : [1 , 13 , 1 ],
71
72
"mxnet" : [1 , 4 , 1 ],
72
73
"tensorflow-serving" : [1 , 13 , 0 ],
73
- "mxnet-serving" : [1 , 4 , 1 ],
74
+ "tensorflow-serving-eia" : [1 , 14 , 0 ],
75
+ "mxnet-serving-eia" : [1 , 4 , 1 ],
74
76
}
75
77
76
78
@@ -101,7 +103,7 @@ def _is_merged_versions(framework, framework_version):
101
103
return False
102
104
103
105
104
- def _using_merged_images (region , framework , py_version , accelerator_type , framework_version ):
106
+ def _using_merged_images (region , framework , py_version , framework_version ):
105
107
"""
106
108
Args:
107
109
region:
@@ -116,8 +118,11 @@ def _using_merged_images(region, framework, py_version, accelerator_type, framew
116
118
return (
117
119
(not is_gov_region )
118
120
and is_merged_versions
119
- and (is_py3 or _is_tf_14_or_later (framework , framework_version ))
120
- and accelerator_type is None
121
+ and (
122
+ is_py3
123
+ or _is_tf_14_or_later (framework , framework_version )
124
+ or _is_mxnet_serving_141_or_later (framework , framework_version )
125
+ )
121
126
)
122
127
123
128
@@ -135,7 +140,25 @@ def _is_tf_14_or_later(framework, framework_version):
135
140
)
136
141
137
142
138
- def _registry_id (region , framework , py_version , account , accelerator_type , framework_version ):
143
+ def _is_mxnet_serving_141_or_later (framework , framework_version ):
144
+ """
145
+ Args:
146
+ framework:
147
+ framework_version:
148
+ """
149
+ asimov_lowest_mxnet = [1 , 4 , 1 ]
150
+
151
+ version = [int (s ) for s in framework_version .split ("." )]
152
+
153
+ if len (version ) == 2 :
154
+ version .append (0 )
155
+
156
+ return (
157
+ framework .startswith ("mxnet-serving" ) and version >= asimov_lowest_mxnet [0 : len (version )]
158
+ )
159
+
160
+
161
+ def _registry_id (region , framework , py_version , account , framework_version ):
139
162
"""
140
163
Args:
141
164
region:
@@ -145,7 +168,7 @@ def _registry_id(region, framework, py_version, account, accelerator_type, frame
145
168
accelerator_type:
146
169
framework_version:
147
170
"""
148
- if _using_merged_images (region , framework , py_version , accelerator_type , framework_version ):
171
+ if _using_merged_images (region , framework , py_version , framework_version ):
149
172
if region in ASIMOV_OPT_IN_ACCOUNTS_BY_REGION :
150
173
return ASIMOV_OPT_IN_ACCOUNTS_BY_REGION .get (region )
151
174
return "763104351884"
@@ -187,13 +210,19 @@ def create_image_uri(
187
210
if py_version and py_version not in VALID_PY_VERSIONS :
188
211
raise ValueError ("invalid py_version argument: {}" .format (py_version ))
189
212
213
+ if _accelerator_type_valid_for_framework (
214
+ framework = framework ,
215
+ accelerator_type = accelerator_type ,
216
+ optimized_families = optimized_families ,
217
+ ):
218
+ framework += "-eia"
219
+
190
220
# Handle Account Number for Gov Cloud and frameworks with DLC merged images
191
221
account = _registry_id (
192
222
region = region ,
193
223
framework = framework ,
194
224
py_version = py_version ,
195
225
account = account ,
196
- accelerator_type = accelerator_type ,
197
226
framework_version = framework_version ,
198
227
)
199
228
@@ -218,19 +247,14 @@ def create_image_uri(
218
247
else :
219
248
device_type = "cpu"
220
249
221
- if py_version :
222
- tag = "{}-{}-{}" .format (framework_version , device_type , py_version )
223
- else :
224
- tag = "{}-{}" .format (framework_version , device_type )
250
+ using_merged_images = _using_merged_images (region , framework , py_version , framework_version )
225
251
226
- if _accelerator_type_valid_for_framework (
227
- framework = framework ,
228
- accelerator_type = accelerator_type ,
229
- optimized_families = optimized_families ,
230
- ):
231
- framework += "-eia"
252
+ if not py_version or (using_merged_images and framework == "tensorflow-serving-eia" ):
253
+ tag = "{}-{}" .format (framework_version , device_type )
254
+ else :
255
+ tag = "{}-{}-{}" .format (framework_version , device_type , py_version )
232
256
233
- if _using_merged_images ( region , framework , py_version , accelerator_type , framework_version ) :
257
+ if using_merged_images :
234
258
return "{}/{}:{}" .format (
235
259
get_ecr_image_uri_prefix (account , region ), MERGED_FRAMEWORKS_REPO_MAP [framework ], tag
236
260
)
0 commit comments