Skip to content

Commit 5c0ec63

Browse files
authored
Merge branch 'master' into feat/s3-prefix-model-data-for-jumpstart-model
2 parents c951271 + e100e0a commit 5c0ec63

File tree

12 files changed

+1195
-31
lines changed

12 files changed

+1195
-31
lines changed

src/sagemaker/environment_variables.py

+9
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from sagemaker.jumpstart import utils as jumpstart_utils
2121
from sagemaker.jumpstart import artifacts
2222
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
23+
from sagemaker.jumpstart.enums import JumpStartScriptScope
2324
from sagemaker.session import Session
2425

2526
logger = logging.getLogger(__name__)
@@ -33,6 +34,8 @@ def retrieve_default(
3334
tolerate_deprecated_model: bool = False,
3435
include_aws_sdk_env_vars: bool = True,
3536
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
37+
instance_type: Optional[str] = None,
38+
script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
3639
) -> Dict[str, str]:
3740
"""Retrieves the default container environment variables for the model matching the arguments.
3841
@@ -58,6 +61,10 @@ def retrieve_default(
5861
object, used for SageMaker interactions. If not
5962
specified, one is created using the default AWS configuration
6063
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
64+
instance_type (str): An instance type to optionally supply in order to get environment
65+
variables specific to the instance type.
66+
script (JumpStartScriptScope): The JumpStart script for which to retrieve environment
67+
variables.
6168
Returns:
6269
dict: The variables to use for the model.
6370
@@ -78,4 +85,6 @@ def retrieve_default(
7885
tolerate_deprecated_model,
7986
include_aws_sdk_env_vars,
8087
sagemaker_session=sagemaker_session,
88+
instance_type=instance_type,
89+
script=script,
8190
)

src/sagemaker/image_uris.py

+7-21
Original file line numberDiff line numberDiff line change
@@ -270,20 +270,6 @@ def retrieve(
270270
return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo)
271271

272272

273-
def _get_instance_type_family(instance_type):
274-
"""Return the family of the instance type.
275-
276-
Regex matches either "ml.<family>.<size>" or "ml_<family>". If input is None
277-
or there is no match, return an empty string.
278-
"""
279-
instance_type_family = ""
280-
if isinstance(instance_type, str):
281-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
282-
if match is not None:
283-
instance_type_family = match[1]
284-
return instance_type_family
285-
286-
287273
def _get_image_tag(
288274
container_version,
289275
distribution,
@@ -297,7 +283,7 @@ def _get_image_tag(
297283
version,
298284
):
299285
"""Return image tag based on framework, container, and compute configuration(s)."""
300-
instance_type_family = _get_instance_type_family(instance_type)
286+
instance_type_family = utils.get_instance_type_family(instance_type)
301287
if framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK):
302288
if instance_type_family and final_image_scope == INFERENCE_GRAVITON:
303289
_validate_arg(
@@ -385,7 +371,7 @@ def _config_for_framework_and_scope(framework, image_scope, accelerator_type=Non
385371

386372
def _validate_instance_deprecation(framework, instance_type, version):
387373
"""Check if instance type is deprecated for a certain framework with a certain version"""
388-
if _get_instance_type_family(instance_type) == "p2":
374+
if utils.get_instance_type_family(instance_type) == "p2":
389375
if (framework == "pytorch" and Version(version) >= Version("1.13")) or (
390376
framework == "tensorflow" and Version(version) >= Version("2.12")
391377
):
@@ -409,7 +395,7 @@ def _validate_for_suppported_frameworks_and_instance_type(framework, instance_ty
409395
# Validate for Graviton allowed frameworks
410396
if (
411397
instance_type is not None
412-
and _get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
398+
and utils.get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
413399
and framework not in GRAVITON_ALLOWED_FRAMEWORKS
414400
):
415401
_validate_framework(framework, GRAVITON_ALLOWED_FRAMEWORKS, "framework", "Graviton")
@@ -426,7 +412,7 @@ def _get_final_image_scope(framework, instance_type, image_scope):
426412
"""Return final image scope based on provided framework and instance type."""
427413
if (
428414
framework in GRAVITON_ALLOWED_FRAMEWORKS
429-
and _get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
415+
and utils.get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
430416
):
431417
return INFERENCE_GRAVITON
432418
if image_scope is None and framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK):
@@ -441,7 +427,7 @@ def _get_final_image_scope(framework, instance_type, image_scope):
441427
def _get_inference_tool(inference_tool, instance_type):
442428
"""Extract the inference tool name from instance type."""
443429
if not inference_tool:
444-
instance_type_family = _get_instance_type_family(instance_type)
430+
instance_type_family = utils.get_instance_type_family(instance_type)
445431
if instance_type_family.startswith("inf") or instance_type_family.startswith("trn"):
446432
return "neuron"
447433
return inference_tool
@@ -529,7 +515,7 @@ def _processor(instance_type, available_processors, serverless_inference_config=
529515
processor = "neuron"
530516
else:
531517
# looks for either "ml.<family>.<size>" or "ml_<family>"
532-
family = _get_instance_type_family(instance_type)
518+
family = utils.get_instance_type_family(instance_type)
533519
if family:
534520
# For some frameworks, we have optimized images for specific families, e.g c5 or p3.
535521
# In those cases, we use the family name in the image tag. In other cases, we use
@@ -559,7 +545,7 @@ def _should_auto_select_container_version(instance_type, distribution):
559545
p4d = False
560546
if instance_type:
561547
# looks for either "ml.<family>.<size>" or "ml_<family>"
562-
family = _get_instance_type_family(instance_type)
548+
family = utils.get_instance_type_family(instance_type)
563549
if family:
564550
p4d = family == "p4d"
565551

src/sagemaker/jumpstart/artifacts/environment_variables.py

+31-5
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ def _retrieve_default_environment_variables(
3434
tolerate_deprecated_model: bool = False,
3535
include_aws_sdk_env_vars: bool = True,
3636
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
37+
instance_type: Optional[str] = None,
38+
script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
3739
) -> Dict[str, str]:
3840
"""Retrieves the default environment variables for the model matching the given arguments.
3941
@@ -59,6 +61,10 @@ def _retrieve_default_environment_variables(
5961
object, used for SageMaker interactions. If not
6062
specified, one is created using the default AWS configuration
6163
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
64+
instance_type (str): An instance type to optionally supply in order to get
65+
environment variables specific to the instance type.
66+
script (JumpStartScriptScope): The JumpStart script for which to retrieve
67+
environment variables.
6268
Returns:
6369
dict: the environment variables to use for the model and the given script scope.
6470
"""
@@ -69,17 +75,37 @@ def _retrieve_default_environment_variables(
6975
model_specs = verify_model_region_and_return_specs(
7076
model_id=model_id,
7177
version=model_version,
72-
scope=JumpStartScriptScope.INFERENCE,
78+
scope=script,
7379
region=region,
7480
tolerate_vulnerable_model=tolerate_vulnerable_model,
7581
tolerate_deprecated_model=tolerate_deprecated_model,
7682
sagemaker_session=sagemaker_session,
7783
)
7884

7985
default_environment_variables: Dict[str, str] = {}
80-
for environment_variable in model_specs.inference_environment_variables:
81-
if include_aws_sdk_env_vars or environment_variable.required_for_model_class:
82-
default_environment_variables[environment_variable.name] = str(
83-
environment_variable.default
86+
if script == JumpStartScriptScope.INFERENCE:
87+
for environment_variable in model_specs.inference_environment_variables:
88+
if include_aws_sdk_env_vars or environment_variable.required_for_model_class:
89+
default_environment_variables[environment_variable.name] = str(
90+
environment_variable.default
91+
)
92+
93+
if instance_type:
94+
if script == JumpStartScriptScope.INFERENCE and getattr(
95+
model_specs, "hosting_instance_type_variants", None
96+
):
97+
default_environment_variables.update(
98+
model_specs.hosting_instance_type_variants.get_instance_specific_environment_variables( # noqa E501 # pylint: disable=c0301
99+
instance_type
100+
)
101+
)
102+
elif script == JumpStartScriptScope.TRAINING and getattr(
103+
model_specs, "training_instance_type_variants", None
104+
):
105+
default_environment_variables.update(
106+
model_specs.training_instance_type_variants.get_instance_specific_environment_variables( # noqa E501 # pylint: disable=c0301
107+
instance_type
108+
)
84109
)
110+
85111
return default_environment_variables

src/sagemaker/jumpstart/artifacts/image_uris.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,35 @@ def _retrieve_image_uri(
118118
)
119119

120120
if image_scope == JumpStartScriptScope.INFERENCE:
121+
hosting_instance_type_variants = model_specs.hosting_instance_type_variants
122+
if hosting_instance_type_variants:
123+
image_uri = hosting_instance_type_variants.get_image_uri(
124+
instance_type=instance_type, region=region
125+
)
126+
if image_uri is not None:
127+
return image_uri
121128
ecr_specs = model_specs.hosting_ecr_specs
129+
if ecr_specs is None:
130+
raise ValueError(
131+
f"No inference ECR configuration found for JumpStart model ID '{model_id}' "
132+
f"with {instance_type} instance type in {region}. "
133+
"Please try another instance type or region."
134+
)
122135
elif image_scope == JumpStartScriptScope.TRAINING:
136+
training_instance_type_variants = model_specs.training_instance_type_variants
137+
if training_instance_type_variants:
138+
image_uri = training_instance_type_variants.get_image_uri(
139+
instance_type=instance_type, region=region
140+
)
141+
if image_uri is not None:
142+
return image_uri
123143
ecr_specs = model_specs.training_ecr_specs
124-
144+
if ecr_specs is None:
145+
raise ValueError(
146+
f"No training ECR configuration found for JumpStart model ID '{model_id}' "
147+
f"with {instance_type} instance type in {region}. "
148+
"Please try another instance type or region."
149+
)
125150
if framework is not None and framework != ecr_specs.framework:
126151
raise ValueError(
127152
f"Incorrect container framework '{framework}' for JumpStart model ID '{model_id}' "

src/sagemaker/jumpstart/factory/model.py

+2
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,8 @@ def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKw
304304
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
305305
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
306306
sagemaker_session=kwargs.sagemaker_session,
307+
script=JumpStartScriptScope.INFERENCE,
308+
instance_type=kwargs.instance_type,
307309
)
308310

309311
for key, value in extra_env_vars.items():

0 commit comments

Comments
 (0)