@@ -270,20 +270,6 @@ def retrieve(
270
270
return ECR_URI_TEMPLATE .format (registry = registry , hostname = hostname , repository = repo )
271
271
272
272
273
- def _get_instance_type_family (instance_type ):
274
- """Return the family of the instance type.
275
-
276
- Regex matches either "ml.<family>.<size>" or "ml_<family>. If input is None
277
- or there is no match, return an empty string.
278
- """
279
- instance_type_family = ""
280
- if isinstance (instance_type , str ):
281
- match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
282
- if match is not None :
283
- instance_type_family = match [1 ]
284
- return instance_type_family
285
-
286
-
287
273
def _get_image_tag (
288
274
container_version ,
289
275
distribution ,
@@ -297,7 +283,7 @@ def _get_image_tag(
297
283
version ,
298
284
):
299
285
"""Return image tag based on framework, container, and compute configuration(s)."""
300
- instance_type_family = _get_instance_type_family (instance_type )
286
+ instance_type_family = utils . get_instance_type_family (instance_type )
301
287
if framework in (XGBOOST_FRAMEWORK , SKLEARN_FRAMEWORK ):
302
288
if instance_type_family and final_image_scope == INFERENCE_GRAVITON :
303
289
_validate_arg (
@@ -385,7 +371,7 @@ def _config_for_framework_and_scope(framework, image_scope, accelerator_type=Non
385
371
386
372
def _validate_instance_deprecation (framework , instance_type , version ):
387
373
"""Check if instance type is deprecated for a certain framework with a certain version"""
388
- if _get_instance_type_family (instance_type ) == "p2" :
374
+ if utils . get_instance_type_family (instance_type ) == "p2" :
389
375
if (framework == "pytorch" and Version (version ) >= Version ("1.13" )) or (
390
376
framework == "tensorflow" and Version (version ) >= Version ("2.12" )
391
377
):
@@ -409,7 +395,7 @@ def _validate_for_suppported_frameworks_and_instance_type(framework, instance_ty
409
395
# Validate for Graviton allowed frameowrks
410
396
if (
411
397
instance_type is not None
412
- and _get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
398
+ and utils . get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
413
399
and framework not in GRAVITON_ALLOWED_FRAMEWORKS
414
400
):
415
401
_validate_framework (framework , GRAVITON_ALLOWED_FRAMEWORKS , "framework" , "Graviton" )
@@ -426,7 +412,7 @@ def _get_final_image_scope(framework, instance_type, image_scope):
426
412
"""Return final image scope based on provided framework and instance type."""
427
413
if (
428
414
framework in GRAVITON_ALLOWED_FRAMEWORKS
429
- and _get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
415
+ and utils . get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
430
416
):
431
417
return INFERENCE_GRAVITON
432
418
if image_scope is None and framework in (XGBOOST_FRAMEWORK , SKLEARN_FRAMEWORK ):
@@ -441,7 +427,7 @@ def _get_final_image_scope(framework, instance_type, image_scope):
441
427
def _get_inference_tool (inference_tool , instance_type ):
442
428
"""Extract the inference tool name from instance type."""
443
429
if not inference_tool :
444
- instance_type_family = _get_instance_type_family (instance_type )
430
+ instance_type_family = utils . get_instance_type_family (instance_type )
445
431
if instance_type_family .startswith ("inf" ) or instance_type_family .startswith ("trn" ):
446
432
return "neuron"
447
433
return inference_tool
@@ -529,7 +515,7 @@ def _processor(instance_type, available_processors, serverless_inference_config=
529
515
processor = "neuron"
530
516
else :
531
517
# looks for either "ml.<family>.<size>" or "ml_<family>"
532
- family = _get_instance_type_family (instance_type )
518
+ family = utils . get_instance_type_family (instance_type )
533
519
if family :
534
520
# For some frameworks, we have optimized images for specific families, e.g c5 or p3.
535
521
# In those cases, we use the family name in the image tag. In other cases, we use
@@ -559,7 +545,7 @@ def _should_auto_select_container_version(instance_type, distribution):
559
545
p4d = False
560
546
if instance_type :
561
547
# looks for either "ml.<family>.<size>" or "ml_<family>"
562
- family = _get_instance_type_family (instance_type )
548
+ family = utils . get_instance_type_family (instance_type )
563
549
if family :
564
550
p4d = family == "p4d"
565
551
0 commit comments