Skip to content

Commit 1344b23

Browse files
author
huilgolr
committed
Fix black
1 parent 8bffacc commit 1344b23

File tree

1 file changed

+34
-96
lines changed

1 file changed

+34
-96
lines changed

src/sagemaker/image_uris.py

+34-96
Original file line numberDiff line numberDiff line change
@@ -167,20 +167,13 @@ def retrieve(
167167
)
168168
else:
169169
_framework = framework
170-
if (
171-
framework == HUGGING_FACE_FRAMEWORK
172-
or framework in TRAINIUM_ALLOWED_FRAMEWORKS
173-
):
170+
if framework == HUGGING_FACE_FRAMEWORK or framework in TRAINIUM_ALLOWED_FRAMEWORKS:
174171
inference_tool = _get_inference_tool(inference_tool, instance_type)
175172
if inference_tool in ["neuron", "neuronx"]:
176173
_framework = f"{framework}-{inference_tool}"
177-
final_image_scope = _get_final_image_scope(
178-
framework, instance_type, image_scope
179-
)
174+
final_image_scope = _get_final_image_scope(framework, instance_type, image_scope)
180175
_validate_for_suppported_frameworks_and_instance_type(framework, instance_type)
181-
config = _config_for_framework_and_scope(
182-
_framework, final_image_scope, accelerator_type
183-
)
176+
config = _config_for_framework_and_scope(_framework, final_image_scope, accelerator_type)
184177

185178
original_version = version
186179
version = _validate_version_and_set_if_needed(version, config, framework)
@@ -191,14 +184,10 @@ def retrieve(
191184
full_base_framework_version = version_config["version_aliases"].get(
192185
base_framework_version, base_framework_version
193186
)
194-
_validate_arg(
195-
full_base_framework_version, list(version_config.keys()), "base framework"
196-
)
187+
_validate_arg(full_base_framework_version, list(version_config.keys()), "base framework")
197188
version_config = version_config.get(full_base_framework_version)
198189

199-
py_version = _validate_py_version_and_set_if_needed(
200-
py_version, version_config, framework
201-
)
190+
py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework)
202191
version_config = version_config.get(py_version) or version_config
203192
registry = _registry_from_region(region, version_config["registries"])
204193
endpoint_data = utils._botocore_resolver().construct_endpoint("ecr", region)
@@ -226,9 +215,7 @@ def retrieve(
226215

227216
if framework == HUGGING_FACE_FRAMEWORK:
228217
pt_or_tf_version = (
229-
re.compile("^(pytorch|tensorflow)(.*)$")
230-
.match(base_framework_version)
231-
.group(2)
218+
re.compile("^(pytorch|tensorflow)(.*)$").match(base_framework_version).group(2)
232219
)
233220
_version = original_version
234221

@@ -252,13 +239,11 @@ def retrieve(
252239
.get("version_aliases", {})
253240
.get(base_framework_version, {})
254241
):
255-
_base_framework_version = config.get("versions")[_version][
256-
"version_aliases"
257-
][base_framework_version]
242+
_base_framework_version = config.get("versions")[_version]["version_aliases"][
243+
base_framework_version
244+
]
258245
pt_or_tf_version = (
259-
re.compile("^(pytorch|tensorflow)(.*)$")
260-
.match(_base_framework_version)
261-
.group(2)
246+
re.compile("^(pytorch|tensorflow)(.*)$").match(_base_framework_version).group(2)
262247
)
263248

264249
tag_prefix = f"{pt_or_tf_version}-transformers{_version}"
@@ -285,9 +270,7 @@ def retrieve(
285270
if tag:
286271
repo += ":{}".format(tag)
287272

288-
return ECR_URI_TEMPLATE.format(
289-
registry=registry, hostname=hostname, repository=repo
290-
)
273+
return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo)
291274

292275

293276
def _get_image_tag(
@@ -326,13 +309,9 @@ def _get_image_tag(
326309
}
327310
tag = version_to_arm64_tag_mapping[framework][version]
328311
else:
329-
tag = _format_tag(
330-
tag_prefix, processor, py_version, container_version, inference_tool
331-
)
312+
tag = _format_tag(tag_prefix, processor, py_version, container_version, inference_tool)
332313
else:
333-
tag = _format_tag(
334-
tag_prefix, processor, py_version, container_version, inference_tool
335-
)
314+
tag = _format_tag(tag_prefix, processor, py_version, container_version, inference_tool)
336315

337316
if instance_type is not None and _should_auto_select_container_version(
338317
instance_type, distribution
@@ -383,11 +362,7 @@ def _config_for_framework_and_scope(framework, image_scope, accelerator_type=Non
383362
)
384363
image_scope = available_scopes[0]
385364

386-
if (
387-
not image_scope
388-
and "scope" in config
389-
and set(available_scopes) == {"training", "inference"}
390-
):
365+
if not image_scope and "scope" in config and set(available_scopes) == {"training", "inference"}:
391366
logger.info(
392367
"Same images used for training and inference. Defaulting to image scope: %s.",
393368
available_scopes[0],
@@ -419,27 +394,20 @@ def _validate_for_suppported_frameworks_and_instance_type(framework, instance_ty
419394
and "trn" in instance_type
420395
and framework not in TRAINIUM_ALLOWED_FRAMEWORKS
421396
):
422-
_validate_framework(
423-
framework, TRAINIUM_ALLOWED_FRAMEWORKS, "framework", "Trainium"
424-
)
397+
_validate_framework(framework, TRAINIUM_ALLOWED_FRAMEWORKS, "framework", "Trainium")
425398

426399
# Validate for Graviton allowed frameowrks
427400
if (
428401
instance_type is not None
429-
and utils.get_instance_type_family(instance_type)
430-
in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
402+
and utils.get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
431403
and framework not in GRAVITON_ALLOWED_FRAMEWORKS
432404
):
433-
_validate_framework(
434-
framework, GRAVITON_ALLOWED_FRAMEWORKS, "framework", "Graviton"
435-
)
405+
_validate_framework(framework, GRAVITON_ALLOWED_FRAMEWORKS, "framework", "Graviton")
436406

437407

438408
def config_for_framework(framework):
439409
"""Loads the JSON config for the given framework."""
440-
fname = os.path.join(
441-
os.path.dirname(__file__), "image_uri_config", "{}.json".format(framework)
442-
)
410+
fname = os.path.join(os.path.dirname(__file__), "image_uri_config", "{}.json".format(framework))
443411
with open(fname) as f:
444412
return json.load(f)
445413

@@ -448,8 +416,7 @@ def _get_final_image_scope(framework, instance_type, image_scope):
448416
"""Return final image scope based on provided framework and instance type."""
449417
if (
450418
framework in GRAVITON_ALLOWED_FRAMEWORKS
451-
and utils.get_instance_type_family(instance_type)
452-
in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
419+
and utils.get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
453420
):
454421
return INFERENCE_GRAVITON
455422
if image_scope is None and framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK):
@@ -465,9 +432,7 @@ def _get_inference_tool(inference_tool, instance_type):
465432
"""Extract the inference tool name from instance type."""
466433
if not inference_tool:
467434
instance_type_family = utils.get_instance_type_family(instance_type)
468-
if instance_type_family.startswith("inf") or instance_type_family.startswith(
469-
"trn"
470-
):
435+
if instance_type_family.startswith("inf") or instance_type_family.startswith("trn"):
471436
return "neuron"
472437
return inference_tool
473438

@@ -479,15 +444,10 @@ def _get_latest_versions(list_of_versions):
479444

480445
def _validate_accelerator_type(accelerator_type):
481446
"""Raises a ``ValueError`` if ``accelerator_type`` is invalid."""
482-
if (
483-
not accelerator_type.startswith("ml.eia")
484-
and accelerator_type != "local_sagemaker_notebook"
485-
):
447+
if not accelerator_type.startswith("ml.eia") and accelerator_type != "local_sagemaker_notebook":
486448
raise ValueError(
487449
"Invalid SageMaker Elastic Inference accelerator type: {}. "
488-
"See https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html".format(
489-
accelerator_type
490-
)
450+
"See https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html".format(accelerator_type)
491451
)
492452

493453

@@ -497,15 +457,11 @@ def _validate_version_and_set_if_needed(version, config, framework):
497457
aliased_versions = list(config.get("version_aliases", {}).keys())
498458

499459
if len(available_versions) == 1 and version not in aliased_versions:
500-
log_message = (
501-
"Defaulting to the only supported framework/algorithm version: {}.".format(
502-
available_versions[0]
503-
)
460+
log_message = "Defaulting to the only supported framework/algorithm version: {}.".format(
461+
available_versions[0]
504462
)
505463
if version and version != available_versions[0]:
506-
logger.warning(
507-
"%s Ignoring framework/algorithm version: %s.", log_message, version
508-
)
464+
logger.warning("%s Ignoring framework/algorithm version: %s.", log_message, version)
509465
elif not version:
510466
logger.info(log_message)
511467

@@ -518,9 +474,7 @@ def _validate_version_and_set_if_needed(version, config, framework):
518474
]:
519475
version = _get_latest_versions(available_versions)
520476

521-
_validate_arg(
522-
version, available_versions + aliased_versions, "{} version".format(framework)
523-
)
477+
_validate_arg(version, available_versions + aliased_versions, "{} version".format(framework))
524478
return version
525479

526480

@@ -546,9 +500,7 @@ def _processor(instance_type, available_processors, serverless_inference_config=
546500
return None
547501

548502
if len(available_processors) == 1 and not instance_type:
549-
logger.info(
550-
"Defaulting to only supported image scope: %s.", available_processors[0]
551-
)
503+
logger.info("Defaulting to only supported image scope: %s.", available_processors[0])
552504
return available_processors[0]
553505

554506
if serverless_inference_config is not None:
@@ -585,9 +537,7 @@ def _processor(instance_type, available_processors, serverless_inference_config=
585537
else:
586538
raise ValueError(
587539
"Invalid SageMaker instance type: {}. For options, see: "
588-
"https://aws.amazon.com/sagemaker/pricing/instance-types".format(
589-
instance_type
590-
)
540+
"https://aws.amazon.com/sagemaker/pricing/instance-types".format(instance_type)
591541
)
592542

593543
_validate_arg(processor, available_processors, "processor")
@@ -626,9 +576,7 @@ def _validate_py_version_and_set_if_needed(py_version, version_config, framework
626576
return None
627577

628578
if py_version is None and len(available_versions) == 1:
629-
logger.info(
630-
"Defaulting to only available Python version: %s", available_versions[0]
631-
)
579+
logger.info("Defaulting to only available Python version: %s", available_versions[0])
632580
return available_versions[0]
633581

634582
_validate_arg(py_version, available_versions, "Python version")
@@ -641,9 +589,7 @@ def _validate_arg(arg, available_options, arg_name):
641589
raise ValueError(
642590
"Unsupported {arg_name}: {arg}. You may need to upgrade your SDK version "
643591
"(pip install -U sagemaker) for newer {arg_name}s. Supported {arg_name}(s): "
644-
"{options}.".format(
645-
arg_name=arg_name, arg=arg, options=", ".join(available_options)
646-
)
592+
"{options}.".format(arg_name=arg_name, arg=arg, options=", ".join(available_options))
647593
)
648594

649595

@@ -656,17 +602,11 @@ def _validate_framework(framework, allowed_frameworks, arg_name, hardware_name):
656602
)
657603

658604

659-
def _format_tag(
660-
tag_prefix, processor, py_version, container_version, inference_tool=None
661-
):
605+
def _format_tag(tag_prefix, processor, py_version, container_version, inference_tool=None):
662606
"""Creates a tag for the image URI."""
663607
if inference_tool:
664-
return "-".join(
665-
x for x in (tag_prefix, inference_tool, py_version, container_version) if x
666-
)
667-
return "-".join(
668-
x for x in (tag_prefix, processor, py_version, container_version) if x
669-
)
608+
return "-".join(x for x in (tag_prefix, inference_tool, py_version, container_version) if x)
609+
return "-".join(x for x in (tag_prefix, processor, py_version, container_version) if x)
670610

671611

672612
@override_pipeline_parameter_var
@@ -775,6 +715,4 @@ def get_base_python_image_uri(region, py_version="310") -> str:
775715
repo = version_config["repository"] + "-" + py_version
776716
repo_and_tag = repo + ":" + version
777717

778-
return ECR_URI_TEMPLATE.format(
779-
registry=registry, hostname=hostname, repository=repo_and_tag
780-
)
718+
return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo_and_tag)

0 commit comments

Comments
 (0)