Skip to content

Commit 4d48530

Browse files
authored
feat: instance specific jumpstart host requirements (#4397)
* feat: instance specific jumpstart host requirements * chore: add js support for copies resource requirement, enforce coupling with ResourceRequirements class * fix: typing * fix: pylint
1 parent d131264 commit 4d48530

File tree

7 files changed

+193
-15
lines changed

7 files changed

+193
-15
lines changed

src/sagemaker/jumpstart/artifacts/resource_requirements.py

+52-14
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""This module contains functions for obtaining JumpStart resoure requirements."""
1414
from __future__ import absolute_import
1515

16-
from typing import Optional
16+
from typing import Dict, Optional, Tuple
1717

1818
from sagemaker.jumpstart.constants import (
1919
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -28,6 +28,20 @@
2828
from sagemaker.session import Session
2929
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
3030

31+
REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP: Dict[
32+
str, Dict[str, Tuple[str, str]]
33+
] = {
34+
"requests": {
35+
"num_accelerators": ("num_accelerators", "num_accelerators"),
36+
"num_cpus": ("num_cpus", "num_cpus"),
37+
"copies": ("copies", "copy_count"),
38+
"min_memory_mb": ("memory", "min_memory"),
39+
},
40+
"limits": {
41+
"max_memory_mb": ("memory", "max_memory"),
42+
},
43+
}
44+
3145

3246
def _retrieve_default_resources(
3347
model_id: str,
@@ -37,6 +51,7 @@ def _retrieve_default_resources(
3751
tolerate_vulnerable_model: bool = False,
3852
tolerate_deprecated_model: bool = False,
3953
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
54+
instance_type: Optional[str] = None,
4055
) -> ResourceRequirements:
4156
"""Retrieves the default resource requirements for the model.
4257
@@ -60,6 +75,8 @@ def _retrieve_default_resources(
6075
object, used for SageMaker interactions. If not
6176
specified, one is created using the default AWS configuration
6277
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
78+
instance_type (str): An instance type to optionally supply in order to get
79+
host requirements specific for the instance type.
6380
Returns:
6481
str: The default resource requirements to use for the model or None.
6582
@@ -87,23 +104,44 @@ def _retrieve_default_resources(
87104
is_dynamic_container_deployment_supported = (
88105
model_specs.dynamic_container_deployment_supported
89106
)
90-
default_resource_requirements = model_specs.hosting_resource_requirements
107+
default_resource_requirements: Dict[str, int] = (
108+
model_specs.hosting_resource_requirements or {}
109+
)
91110
else:
92111
raise NotImplementedError(
93112
f"Unsupported script scope for retrieving default resource requirements: '{scope}'"
94113
)
95114

115+
instance_specific_resource_requirements: Dict[str, int] = (
116+
model_specs.hosting_instance_type_variants.get_instance_specific_resource_requirements(
117+
instance_type
118+
)
119+
if instance_type
120+
and getattr(model_specs, "hosting_instance_type_variants", None) is not None
121+
else {}
122+
)
123+
124+
default_resource_requirements = {
125+
**default_resource_requirements,
126+
**instance_specific_resource_requirements,
127+
}
128+
96129
if is_dynamic_container_deployment_supported:
97-
requests = {}
98-
if "num_accelerators" in default_resource_requirements:
99-
requests["num_accelerators"] = default_resource_requirements["num_accelerators"]
100-
if "min_memory_mb" in default_resource_requirements:
101-
requests["memory"] = default_resource_requirements["min_memory_mb"]
102-
if "num_cpus" in default_resource_requirements:
103-
requests["num_cpus"] = default_resource_requirements["num_cpus"]
104-
105-
limits = {}
106-
if "max_memory_mb" in default_resource_requirements:
107-
limits["memory"] = default_resource_requirements["max_memory_mb"]
108-
return ResourceRequirements(requests=requests, limits=limits)
130+
131+
all_resource_requirement_kwargs = {}
132+
133+
for (
134+
requirement_type,
135+
spec_field_to_resource_requirement_map,
136+
) in REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP.items():
137+
requirement_kwargs = {}
138+
for spec_field, resource_requirement in spec_field_to_resource_requirement_map.items():
139+
if spec_field in default_resource_requirements:
140+
requirement_kwargs[resource_requirement[0]] = default_resource_requirements[
141+
spec_field
142+
]
143+
144+
all_resource_requirement_kwargs[requirement_type] = requirement_kwargs
145+
146+
return ResourceRequirements(**all_resource_requirement_kwargs)
109147
return None

src/sagemaker/jumpstart/factory/model.py

+1
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
481481
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
482482
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
483483
sagemaker_session=kwargs.sagemaker_session,
484+
instance_type=kwargs.instance_type,
484485
)
485486

486487
return kwargs

src/sagemaker/jumpstart/types.py

+23
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,29 @@ def get_instance_specific_artifact_key(self, instance_type: str) -> Optional[str
478478
instance_type=instance_type, property_name="artifact_key"
479479
)
480480

481+
def get_instance_specific_resource_requirements(self, instance_type: str) -> Optional[str]:
482+
"""Returns instance specific resource requirements.
483+
484+
If a value exists for both the instance family and instance type, the instance type value
485+
is chosen.
486+
"""
487+
488+
instance_specific_resource_requirements: dict = (
489+
self.variants.get(instance_type, {})
490+
.get("properties", {})
491+
.get("resource_requirements", {})
492+
)
493+
494+
instance_type_family = get_instance_type_family(instance_type)
495+
496+
instance_family_resource_requirements: dict = (
497+
self.variants.get(instance_type_family, {})
498+
.get("properties", {})
499+
.get("resource_requirements", {})
500+
)
501+
502+
return {**instance_family_resource_requirements, **instance_specific_resource_requirements}
503+
481504
def _get_instance_specific_property(
482505
self, instance_type: str, property_name: str
483506
) -> Optional[str]:

src/sagemaker/resource_requirements.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import logging
1818
from typing import Optional
19+
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
1920

2021
from sagemaker.jumpstart import utils as jumpstart_utils
2122
from sagemaker.jumpstart import artifacts
@@ -33,7 +34,8 @@ def retrieve_default(
3334
tolerate_vulnerable_model: bool = False,
3435
tolerate_deprecated_model: bool = False,
3536
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
36-
) -> str:
37+
instance_type: Optional[str] = None,
38+
) -> ResourceRequirements:
3739
"""Retrieves the default resource requirements for the model matching the given arguments.
3840
3941
Args:
@@ -56,6 +58,8 @@ def retrieve_default(
5658
object, used for SageMaker interactions. If not
5759
specified, one is created using the default AWS configuration
5860
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
61+
instance_type (str): An instance type to optionally supply in order to get
62+
host requirements specific for the instance type.
5963
Returns:
6064
str: The default resource requirements to use for the model.
6165
@@ -79,4 +83,5 @@ def retrieve_default(
7983
tolerate_vulnerable_model,
8084
tolerate_deprecated_model,
8185
sagemaker_session=sagemaker_session,
86+
instance_type=instance_type,
8287
)

tests/unit/sagemaker/jumpstart/constants.py

+20
Original file line numberDiff line numberDiff line change
@@ -840,8 +840,22 @@
840840
"model_package_arn": "$gpu_model_package_arn",
841841
}
842842
},
843+
"g5": {
844+
"properties": {
845+
"resource_requirements": {
846+
"num_accelerators": 888810,
847+
"randon-field-2": 2222,
848+
}
849+
}
850+
},
843851
"m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
844852
"c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
853+
"ml.g5.xlarge": {
854+
"properties": {
855+
"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"},
856+
"resource_requirements": {"num_accelerators": 10},
857+
}
858+
},
845859
"ml.g5.48xlarge": {
846860
"properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}}
847861
},
@@ -857,6 +871,12 @@
857871
"framework_version": "1.5.0",
858872
"py_version": "py3",
859873
},
874+
"dynamic_container_deployment_supported": True,
875+
"hosting_resource_requirements": {
876+
"min_memory_mb": 81999,
877+
"num_accelerators": 1,
878+
"random_field_1": 1,
879+
},
860880
"hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz",
861881
"training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz",
862882
"hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz",

tests/unit/sagemaker/jumpstart/test_types.py

+23
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
"variants": {
3535
"ml.p2.12xlarge": {
3636
"properties": {
37+
"resource_requirements": {"req1": 1, "req2": {"1": 2, "2": 3}, "req3": 9},
3738
"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"},
3839
"supported_inference_instance_types": ["ml.p5.xlarge"],
3940
"default_inference_instance_type": "ml.p5.xlarge",
@@ -60,6 +61,11 @@
6061
"p2": {
6162
"regional_properties": {"image_uri": "$gpu_image_uri"},
6263
"properties": {
64+
"resource_requirements": {
65+
"req2": {"2": 5, "9": 999},
66+
"req3": 999,
67+
"req4": "blah",
68+
},
6369
"supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.xlarge"],
6470
"default_inference_instance_type": "ml.p2.xlarge",
6571
"metrics": [
@@ -879,3 +885,20 @@ def test_jumpstart_training_artifact_key_instance_variants():
879885
)
880886
is None
881887
)
888+
889+
890+
def test_jumpstart_resource_requirements_instance_variants():
891+
assert INSTANCE_TYPE_VARIANT.get_instance_specific_resource_requirements(
892+
instance_type="ml.p2.xlarge"
893+
) == {"req2": {"2": 5, "9": 999}, "req3": 999, "req4": "blah"}
894+
895+
assert INSTANCE_TYPE_VARIANT.get_instance_specific_resource_requirements(
896+
instance_type="ml.p2.12xlarge"
897+
) == {"req1": 1, "req2": {"1": 2, "2": 3}, "req3": 9, "req4": "blah"}
898+
899+
assert (
900+
INSTANCE_TYPE_VARIANT.get_instance_specific_resource_requirements(
901+
instance_type="ml.p99.12xlarge"
902+
)
903+
== {}
904+
)

tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py

+68
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
import pytest
1919

2020
from sagemaker import resource_requirements
21+
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
22+
from sagemaker.jumpstart.artifacts.resource_requirements import (
23+
REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP,
24+
)
2125

2226
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec
2327

@@ -50,6 +54,55 @@ def test_jumpstart_resource_requirements(patched_get_model_specs):
5054
patched_get_model_specs.reset_mock()
5155

5256

57+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
58+
def test_jumpstart_resource_requirements_instance_type_variants(patched_get_model_specs):
59+
60+
patched_get_model_specs.side_effect = get_special_model_spec
61+
region = "us-west-2"
62+
mock_client = boto3.client("s3")
63+
mock_session = Mock(s3_client=mock_client)
64+
65+
model_id, model_version = "variant-model", "*"
66+
default_inference_resource_requirements = resource_requirements.retrieve_default(
67+
region=region,
68+
model_id=model_id,
69+
model_version=model_version,
70+
scope="inference",
71+
sagemaker_session=mock_session,
72+
instance_type="ml.g5.xlarge",
73+
)
74+
assert default_inference_resource_requirements.requests == {
75+
"memory": 81999,
76+
"num_accelerators": 10,
77+
}
78+
79+
default_inference_resource_requirements = resource_requirements.retrieve_default(
80+
region=region,
81+
model_id=model_id,
82+
model_version=model_version,
83+
scope="inference",
84+
sagemaker_session=mock_session,
85+
instance_type="ml.g5.555xlarge",
86+
)
87+
assert default_inference_resource_requirements.requests == {
88+
"memory": 81999,
89+
"num_accelerators": 888810,
90+
}
91+
92+
default_inference_resource_requirements = resource_requirements.retrieve_default(
93+
region=region,
94+
model_id=model_id,
95+
model_version=model_version,
96+
scope="inference",
97+
sagemaker_session=mock_session,
98+
instance_type="ml.f9.555xlarge",
99+
)
100+
assert default_inference_resource_requirements.requests == {
101+
"memory": 81999,
102+
"num_accelerators": 1,
103+
}
104+
105+
53106
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
54107
def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs):
55108
patched_get_model_specs.side_effect = get_special_model_spec
@@ -80,3 +133,18 @@ def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs):
80133
resource_requirements.retrieve_default(
81134
region=region, model_id=model_id, model_version=model_version, scope="training"
82135
)
136+
137+
138+
def test_jumpstart_supports_all_resource_requirement_fields():
139+
140+
all_tracked_resource_requirement_fields = {
141+
field
142+
for requirements in REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP.values()
143+
for _, field in requirements.values()
144+
}
145+
146+
excluded_resource_requirement_fields = {"requests", "limits"}
147+
assert (
148+
set(ResourceRequirements().__dict__.keys()) - excluded_resource_requirement_fields
149+
== all_tracked_resource_requirement_fields
150+
)

0 commit comments

Comments
 (0)