|
22 | 22 |
|
23 | 23 | import sagemaker.utils
|
24 | 24 | from sagemaker import s3
|
25 |
| -from sagemaker.utils import get_ecr_image_uri_prefix, ECR_URI_PATTERN |
26 | 25 |
|
27 | 26 | logger = logging.getLogger("sagemaker")
|
28 | 27 |
|
|
66 | 65 | "Please add framework_version={} to your constructor to avoid this error."
|
67 | 66 | )
|
68 | 67 |
|
69 |
| -VALID_PY_VERSIONS = ["py2", "py3", "py37"] |
70 |
| -VALID_EIA_FRAMEWORKS = [ |
71 |
| - "tensorflow", |
72 |
| - "tensorflow-serving", |
73 |
| - "mxnet", |
74 |
| - "mxnet-serving", |
75 |
| - "pytorch-serving", |
76 |
| -] |
77 |
| -PY2_RESTRICTED_EIA_FRAMEWORKS = ["pytorch-serving"] |
78 |
| -PY37_SUPPORTED_FRAMEWORKS = ["tensorflow-scriptmode"] |
79 |
| -VALID_ACCOUNTS_BY_REGION = { |
80 |
| - "us-gov-west-1": "246785580436", |
81 |
| - "us-iso-east-1": "744548109606", |
82 |
| - "cn-north-1": "422961961927", |
83 |
| - "cn-northwest-1": "423003514399", |
84 |
| -} |
85 |
| -ASIMOV_VALID_ACCOUNTS_BY_REGION = { |
86 |
| - "us-gov-west-1": "442386744353", |
87 |
| - "us-iso-east-1": "886529160074", |
88 |
| - "cn-north-1": "727897471807", |
89 |
| - "cn-northwest-1": "727897471807", |
90 |
| -} |
91 |
| -OPT_IN_ACCOUNTS_BY_REGION = {"ap-east-1": "057415533634", "me-south-1": "724002660598"} |
92 |
| -ASIMOV_OPT_IN_ACCOUNTS_BY_REGION = {"ap-east-1": "871362719292", "me-south-1": "217643126080"} |
93 |
| -DEFAULT_ACCOUNT = "520713654638" |
94 |
| -ASIMOV_PROD_ACCOUNT = "763104351884" |
95 |
| -ASIMOV_DEFAULT_ACCOUNT = ASIMOV_PROD_ACCOUNT |
96 | 68 | SINGLE_GPU_INSTANCE_TYPES = ("ml.p2.xlarge", "ml.p3.2xlarge")
|
97 | 69 |
|
98 |
| -MERGED_FRAMEWORKS_REPO_MAP = { |
99 |
| - "tensorflow-scriptmode": "tensorflow-training", |
100 |
| - "tensorflow-serving": "tensorflow-inference", |
101 |
| - "tensorflow-serving-eia": "tensorflow-inference-eia", |
102 |
| - "mxnet": "mxnet-training", |
103 |
| - "mxnet-serving": "mxnet-inference", |
104 |
| - "mxnet-serving-eia": "mxnet-inference-eia", |
105 |
| - "pytorch": "pytorch-training", |
106 |
| - "pytorch-serving": "pytorch-inference", |
107 |
| - "pytorch-serving-eia": "pytorch-inference-eia", |
108 |
| -} |
109 |
| - |
110 |
| -MERGED_FRAMEWORKS_LOWEST_VERSIONS = { |
111 |
| - "tensorflow-scriptmode": {"py3": [1, 13, 1], "py2": [1, 14, 0], "py37": [1, 15, 2]}, |
112 |
| - "tensorflow-serving": [1, 13, 0], |
113 |
| - "tensorflow-serving-eia": [1, 14, 0], |
114 |
| - "mxnet": {"py3": [1, 4, 1], "py2": [1, 6, 0]}, |
115 |
| - "mxnet-serving": {"py3": [1, 4, 1], "py2": [1, 6, 0]}, |
116 |
| - "mxnet-serving-eia": [1, 4, 1], |
117 |
| - "pytorch": [1, 2, 0], |
118 |
| - "pytorch-serving": [1, 2, 0], |
119 |
| - "pytorch-serving-eia": [1, 3, 1], |
120 |
| -} |
121 |
| - |
122 |
| -INFERENTIA_VERSION_RANGES = { |
123 |
| - "neo-mxnet": [[1, 5, 1], [1, 5, 1]], |
124 |
| - "neo-tensorflow": [[1, 15, 0], [1, 15, 0]], |
125 |
| -} |
126 |
| - |
127 |
| -INFERENTIA_SUPPORTED_REGIONS = ["us-east-1", "us-west-2"] |
128 |
| - |
129 | 70 | DEBUGGER_UNSUPPORTED_REGIONS = ["us-gov-west-1", "us-iso-east-1"]
|
130 | 71 |
|
131 | 72 |
|
@@ -163,229 +104,6 @@ def is_version_equal_or_lower(highest_version, framework_version):
|
163 | 104 | return version_list <= highest_version[0 : len(version_list)]
|
164 | 105 |
|
165 | 106 |
|
166 |
| -def _is_dlc_version(framework, framework_version, py_version): |
167 |
| - """Return if the framework's version uses the corresponding DLC image. |
168 |
| -
|
169 |
| - Args: |
170 |
| - framework (str): The framework name, e.g. "tensorflow-scriptmode" |
171 |
| - framework_version (str): The framework version |
172 |
| - py_version (str): The Python version, e.g. "py3" |
173 |
| -
|
174 |
| - Returns: |
175 |
| - bool: Whether or not the framework's version uses the DLC image. |
176 |
| - """ |
177 |
| - lowest_version_list = MERGED_FRAMEWORKS_LOWEST_VERSIONS.get(framework) |
178 |
| - if isinstance(lowest_version_list, dict): |
179 |
| - lowest_version_list = lowest_version_list[py_version] |
180 |
| - |
181 |
| - if lowest_version_list: |
182 |
| - return is_version_equal_or_higher(lowest_version_list, framework_version) |
183 |
| - return False |
184 |
| - |
185 |
| - |
186 |
| -def _is_inferentia_supported(framework, framework_version): |
187 |
| - """Return if Inferentia supports the framework and its version. |
188 |
| -
|
189 |
| - Args: |
190 |
| - framework (str): The framework name, e.g. "tensorflow" |
191 |
| - framework_version (str): The framework version |
192 |
| -
|
193 |
| - Returns: |
194 |
| - bool: Whether or not Inferentia supports the framework and its version. |
195 |
| - """ |
196 |
| - lowest_version_list = INFERENTIA_VERSION_RANGES.get(framework)[0] |
197 |
| - highest_version_list = INFERENTIA_VERSION_RANGES.get(framework)[1] |
198 |
| - return is_version_equal_or_higher( |
199 |
| - lowest_version_list, framework_version |
200 |
| - ) and is_version_equal_or_lower(highest_version_list, framework_version) |
201 |
| - |
202 |
| - |
203 |
| -def _registry_id(region, framework, py_version, account, framework_version): |
204 |
| - """Return the Amazon ECR registry number (or AWS account ID) for |
205 |
| - the given framework, framework version, Python version, and region. |
206 |
| -
|
207 |
| - Args: |
208 |
| - region (str): The AWS region. |
209 |
| - framework (str): The framework name, e.g. "tensorflow-scriptmode". |
210 |
| - py_version (str): The Python version, e.g. "py3". |
211 |
| - account (str): The AWS account ID to use as a default. |
212 |
| - framework_version (str): The framework version. |
213 |
| -
|
214 |
| - Returns: |
215 |
| - str: The appropriate Amazon ECR registry number. If there is no |
216 |
| - specific one for the framework, framework version, Python version, |
217 |
| - and region, then ``account`` is returned. |
218 |
| - """ |
219 |
| - if _is_dlc_version(framework, framework_version, py_version): |
220 |
| - if region in ASIMOV_OPT_IN_ACCOUNTS_BY_REGION: |
221 |
| - return ASIMOV_OPT_IN_ACCOUNTS_BY_REGION.get(region) |
222 |
| - if region in ASIMOV_VALID_ACCOUNTS_BY_REGION: |
223 |
| - return ASIMOV_VALID_ACCOUNTS_BY_REGION.get(region) |
224 |
| - return ASIMOV_DEFAULT_ACCOUNT |
225 |
| - if region in OPT_IN_ACCOUNTS_BY_REGION: |
226 |
| - return OPT_IN_ACCOUNTS_BY_REGION.get(region) |
227 |
| - return VALID_ACCOUNTS_BY_REGION.get(region, account) |
228 |
| - |
229 |
| - |
230 |
| -def create_image_uri( |
231 |
| - region, |
232 |
| - framework, |
233 |
| - instance_type, |
234 |
| - framework_version, |
235 |
| - py_version=None, |
236 |
| - account=None, |
237 |
| - accelerator_type=None, |
238 |
| - optimized_families=None, |
239 |
| -): |
240 |
| - """Return the ECR URI of an image. |
241 |
| -
|
242 |
| - Args: |
243 |
| - region (str): AWS region where the image is uploaded. |
244 |
| - framework (str): framework used by the image. |
245 |
| - instance_type (str): SageMaker instance type. Used to determine device |
246 |
| - type (cpu/gpu/family-specific optimized). |
247 |
| - framework_version (str): The version of the framework. |
248 |
| - py_version (str): Optional. Python version. If specified, should be one |
249 |
| - of 'py2' or 'py3'. If not specified, image uri will not include a |
250 |
| - python component. |
251 |
| - account (str): AWS account that contains the image. (default: |
252 |
| - '520713654638') |
253 |
| - accelerator_type (str): SageMaker Elastic Inference accelerator type. |
254 |
| - optimized_families (str): Instance families for which there exist |
255 |
| - specific optimized images. |
256 |
| -
|
257 |
| - Returns: |
258 |
| - str: The appropriate image URI based on the given parameters. |
259 |
| - """ |
260 |
| - logger.warning( |
261 |
| - "'create_image_uri' will be deprecated in favor of 'ImageURIProvider' class " |
262 |
| - "in SageMaker Python SDK v2." |
263 |
| - ) |
264 |
| - |
265 |
| - optimized_families = optimized_families or [] |
266 |
| - |
267 |
| - if py_version and py_version not in VALID_PY_VERSIONS: |
268 |
| - raise ValueError("invalid py_version argument: {}".format(py_version)) |
269 |
| - |
270 |
| - if py_version == "py37" and framework not in PY37_SUPPORTED_FRAMEWORKS: |
271 |
| - raise ValueError("{} does not support Python 3.7 at this time.".format(framework)) |
272 |
| - |
273 |
| - if _accelerator_type_valid_for_framework( |
274 |
| - framework=framework, |
275 |
| - py_version=py_version, |
276 |
| - accelerator_type=accelerator_type, |
277 |
| - optimized_families=optimized_families, |
278 |
| - ): |
279 |
| - framework += "-eia" |
280 |
| - |
281 |
| - # Handle account number for specific cases (e.g. GovCloud, opt-in regions, DLC images etc.) |
282 |
| - if account is None: |
283 |
| - account = _registry_id( |
284 |
| - region=region, |
285 |
| - framework=framework, |
286 |
| - py_version=py_version, |
287 |
| - account=DEFAULT_ACCOUNT, |
288 |
| - framework_version=framework_version, |
289 |
| - ) |
290 |
| - |
291 |
| - # Handle Local Mode |
292 |
| - if instance_type.startswith("local"): |
293 |
| - device_type = "cpu" if instance_type == "local" else "gpu" |
294 |
| - elif not instance_type.startswith("ml."): |
295 |
| - raise ValueError( |
296 |
| - "{} is not a valid SageMaker instance type. See: " |
297 |
| - "https://aws.amazon.com/sagemaker/pricing/instance-types/".format(instance_type) |
298 |
| - ) |
299 |
| - else: |
300 |
| - family = instance_type.split(".")[1] |
301 |
| - |
302 |
| - # For some frameworks, we have optimized images for specific families, e.g c5 or p3. |
303 |
| - # In those cases, we use the family name in the image tag. In other cases, we use |
304 |
| - # 'cpu' or 'gpu'. |
305 |
| - if family in optimized_families: |
306 |
| - device_type = family |
307 |
| - elif family.startswith("inf"): |
308 |
| - device_type = "inf" |
309 |
| - elif family[0] in ["g", "p"]: |
310 |
| - device_type = "gpu" |
311 |
| - else: |
312 |
| - device_type = "cpu" |
313 |
| - |
314 |
| - if device_type == "inf": |
315 |
| - if region not in INFERENTIA_SUPPORTED_REGIONS: |
316 |
| - raise ValueError( |
317 |
| - "Inferentia is not supported in region {}. Supported regions are {}".format( |
318 |
| - region, ", ".join(INFERENTIA_SUPPORTED_REGIONS) |
319 |
| - ) |
320 |
| - ) |
321 |
| - if framework not in INFERENTIA_VERSION_RANGES: |
322 |
| - raise ValueError( |
323 |
| - "Inferentia does not support {}. Currently it supports " |
324 |
| - "MXNet and TensorFlow with more frameworks coming soon.".format( |
325 |
| - framework.split("-")[-1] |
326 |
| - ) |
327 |
| - ) |
328 |
| - if not _is_inferentia_supported(framework, framework_version): |
329 |
| - raise ValueError( |
330 |
| - "Inferentia is not supported with {} version {}.".format( |
331 |
| - framework.split("-")[-1], framework_version |
332 |
| - ) |
333 |
| - ) |
334 |
| - |
335 |
| - use_dlc_image = _is_dlc_version(framework, framework_version, py_version) |
336 |
| - |
337 |
| - if not py_version or (use_dlc_image and framework == "tensorflow-serving-eia"): |
338 |
| - tag = "{}-{}".format(framework_version, device_type) |
339 |
| - else: |
340 |
| - tag = "{}-{}-{}".format(framework_version, device_type, py_version) |
341 |
| - |
342 |
| - if use_dlc_image: |
343 |
| - ecr_repo = MERGED_FRAMEWORKS_REPO_MAP[framework] |
344 |
| - else: |
345 |
| - ecr_repo = "sagemaker-{}".format(framework) |
346 |
| - |
347 |
| - return "{}/{}:{}".format(get_ecr_image_uri_prefix(account, region), ecr_repo, tag) |
348 |
| - |
349 |
| - |
350 |
| -def _accelerator_type_valid_for_framework( |
351 |
| - framework, py_version, accelerator_type=None, optimized_families=None |
352 |
| -): |
353 |
| - """ |
354 |
| - Args: |
355 |
| - framework: |
356 |
| - py_version: |
357 |
| - accelerator_type: |
358 |
| - optimized_families: |
359 |
| - """ |
360 |
| - if accelerator_type is None: |
361 |
| - return False |
362 |
| - |
363 |
| - if py_version == "py2" and framework in PY2_RESTRICTED_EIA_FRAMEWORKS: |
364 |
| - raise ValueError( |
365 |
| - "{} is not supported with Amazon Elastic Inference in Python 2.".format(framework) |
366 |
| - ) |
367 |
| - |
368 |
| - if framework not in VALID_EIA_FRAMEWORKS: |
369 |
| - raise ValueError( |
370 |
| - "{} is not supported with Amazon Elastic Inference. Currently only " |
371 |
| - "Python-based TensorFlow, MXNet, PyTorch are supported.".format(framework) |
372 |
| - ) |
373 |
| - |
374 |
| - if optimized_families: |
375 |
| - raise ValueError("Neo does not support Amazon Elastic Inference.") |
376 |
| - |
377 |
| - if ( |
378 |
| - not accelerator_type.startswith("ml.eia") |
379 |
| - and not accelerator_type == "local_sagemaker_notebook" |
380 |
| - ): |
381 |
| - raise ValueError( |
382 |
| - "{} is not a valid SageMaker Elastic Inference accelerator type. " |
383 |
| - "See: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html".format(accelerator_type) |
384 |
| - ) |
385 |
| - |
386 |
| - return True |
387 |
| - |
388 |
| - |
389 | 107 | def validate_source_dir(script, directory):
|
390 | 108 | """Validate that the source directory exists and it contains the user script
|
391 | 109 | Args:
|
@@ -505,7 +223,7 @@ def framework_name_from_image(image_uri):
|
505 | 223 | str: The framework name str: The Python version str: The image tag
|
506 | 224 | str: If the image is script mode
|
507 | 225 | """
|
508 |
| - sagemaker_pattern = re.compile(ECR_URI_PATTERN) |
| 226 | + sagemaker_pattern = re.compile(sagemaker.utils.ECR_URI_PATTERN) |
509 | 227 | sagemaker_match = sagemaker_pattern.match(image_uri)
|
510 | 228 | if sagemaker_match is None:
|
511 | 229 | return None, None, None, None
|
|
0 commit comments