Skip to content

Commit 5b0ce21

Browse files
authored
Merge branch 'master' into fix/sagemaker-session-region-not-being-used
2 parents e8ad3b3 + 12bcf05 commit 5b0ce21

File tree

11 files changed

+216
-23
lines changed

11 files changed

+216
-23
lines changed

requirements/extras/test_requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ awslogs==0.14.0
1212
black==22.3.0
1313
stopit==1.1.2
1414
# Update tox.ini to have correct version of airflow constraints file
15-
apache-airflow==2.8.1
15+
apache-airflow==2.8.2
1616
apache-airflow-providers-amazon==7.2.1
1717
attrs>=23.1.0,<24
1818
fabric==2.6.0

src/sagemaker/config/config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,9 @@ def _load_config_from_file(file_path: str) -> dict:
181181
f"Provide a valid file path"
182182
)
183183
logger.debug("Fetching defaults config from location: %s", file_path)
184-
return yaml.safe_load(open(inferred_file_path, "r"))
184+
with open(inferred_file_path, "r") as f:
185+
content = yaml.safe_load(f)
186+
return content
185187

186188

187189
def _load_config_from_s3(s3_uri, s3_resource_for_config) -> dict:

src/sagemaker/jumpstart/artifacts/resource_requirements.py

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""This module contains functions for obtaining JumpStart resoure requirements."""
1414
from __future__ import absolute_import
1515

16-
from typing import Optional
16+
from typing import Dict, Optional, Tuple
1717

1818
from sagemaker.jumpstart.constants import (
1919
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -27,6 +27,20 @@
2727
from sagemaker.session import Session
2828
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
2929

30+
REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP: Dict[
31+
str, Dict[str, Tuple[str, str]]
32+
] = {
33+
"requests": {
34+
"num_accelerators": ("num_accelerators", "num_accelerators"),
35+
"num_cpus": ("num_cpus", "num_cpus"),
36+
"copies": ("copies", "copy_count"),
37+
"min_memory_mb": ("memory", "min_memory"),
38+
},
39+
"limits": {
40+
"max_memory_mb": ("memory", "max_memory"),
41+
},
42+
}
43+
3044

3145
def _retrieve_default_resources(
3246
model_id: str,
@@ -36,6 +50,7 @@ def _retrieve_default_resources(
3650
tolerate_vulnerable_model: bool = False,
3751
tolerate_deprecated_model: bool = False,
3852
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
53+
instance_type: Optional[str] = None,
3954
) -> ResourceRequirements:
4055
"""Retrieves the default resource requirements for the model.
4156
@@ -59,6 +74,8 @@ def _retrieve_default_resources(
5974
object, used for SageMaker interactions. If not
6075
specified, one is created using the default AWS configuration
6176
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
77+
instance_type (str): An instance type to optionally supply in order to get
78+
host requirements specific for the instance type.
6279
Returns:
6380
str: The default resource requirements to use for the model or None.
6481
@@ -86,23 +103,44 @@ def _retrieve_default_resources(
86103
is_dynamic_container_deployment_supported = (
87104
model_specs.dynamic_container_deployment_supported
88105
)
89-
default_resource_requirements = model_specs.hosting_resource_requirements
106+
default_resource_requirements: Dict[str, int] = (
107+
model_specs.hosting_resource_requirements or {}
108+
)
90109
else:
91110
raise NotImplementedError(
92111
f"Unsupported script scope for retrieving default resource requirements: '{scope}'"
93112
)
94113

114+
instance_specific_resource_requirements: Dict[str, int] = (
115+
model_specs.hosting_instance_type_variants.get_instance_specific_resource_requirements(
116+
instance_type
117+
)
118+
if instance_type
119+
and getattr(model_specs, "hosting_instance_type_variants", None) is not None
120+
else {}
121+
)
122+
123+
default_resource_requirements = {
124+
**default_resource_requirements,
125+
**instance_specific_resource_requirements,
126+
}
127+
95128
if is_dynamic_container_deployment_supported:
96-
requests = {}
97-
if "num_accelerators" in default_resource_requirements:
98-
requests["num_accelerators"] = default_resource_requirements["num_accelerators"]
99-
if "min_memory_mb" in default_resource_requirements:
100-
requests["memory"] = default_resource_requirements["min_memory_mb"]
101-
if "num_cpus" in default_resource_requirements:
102-
requests["num_cpus"] = default_resource_requirements["num_cpus"]
103-
104-
limits = {}
105-
if "max_memory_mb" in default_resource_requirements:
106-
limits["memory"] = default_resource_requirements["max_memory_mb"]
107-
return ResourceRequirements(requests=requests, limits=limits)
129+
130+
all_resource_requirement_kwargs = {}
131+
132+
for (
133+
requirement_type,
134+
spec_field_to_resource_requirement_map,
135+
) in REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP.items():
136+
requirement_kwargs = {}
137+
for spec_field, resource_requirement in spec_field_to_resource_requirement_map.items():
138+
if spec_field in default_resource_requirements:
139+
requirement_kwargs[resource_requirement[0]] = default_resource_requirements[
140+
spec_field
141+
]
142+
143+
all_resource_requirement_kwargs[requirement_type] = requirement_kwargs
144+
145+
return ResourceRequirements(**all_resource_requirement_kwargs)
108146
return None

src/sagemaker/jumpstart/factory/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
483483
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
484484
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
485485
sagemaker_session=kwargs.sagemaker_session,
486+
instance_type=kwargs.instance_type,
486487
)
487488

488489
return kwargs

src/sagemaker/jumpstart/types.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,29 @@ def get_instance_specific_artifact_key(self, instance_type: str) -> Optional[str
478478
instance_type=instance_type, property_name="artifact_key"
479479
)
480480

481+
def get_instance_specific_resource_requirements(self, instance_type: str) -> Optional[str]:
482+
"""Returns instance specific resource requirements.
483+
484+
If a value exists for both the instance family and instance type, the instance type value
485+
is chosen.
486+
"""
487+
488+
instance_specific_resource_requirements: dict = (
489+
self.variants.get(instance_type, {})
490+
.get("properties", {})
491+
.get("resource_requirements", {})
492+
)
493+
494+
instance_type_family = get_instance_type_family(instance_type)
495+
496+
instance_family_resource_requirements: dict = (
497+
self.variants.get(instance_type_family, {})
498+
.get("properties", {})
499+
.get("resource_requirements", {})
500+
)
501+
502+
return {**instance_family_resource_requirements, **instance_specific_resource_requirements}
503+
481504
def _get_instance_specific_property(
482505
self, instance_type: str, property_name: str
483506
) -> Optional[str]:

src/sagemaker/resource_requirements.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import logging
1818
from typing import Optional
19+
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
1920

2021
from sagemaker.jumpstart import utils as jumpstart_utils
2122
from sagemaker.jumpstart import artifacts
@@ -33,7 +34,8 @@ def retrieve_default(
3334
tolerate_vulnerable_model: bool = False,
3435
tolerate_deprecated_model: bool = False,
3536
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
36-
) -> str:
37+
instance_type: Optional[str] = None,
38+
) -> ResourceRequirements:
3739
"""Retrieves the default resource requirements for the model matching the given arguments.
3840
3941
Args:
@@ -56,6 +58,8 @@ def retrieve_default(
5658
object, used for SageMaker interactions. If not
5759
specified, one is created using the default AWS configuration
5860
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
61+
instance_type (str): An instance type to optionally supply in order to get
62+
host requirements specific for the instance type.
5963
Returns:
6064
str: The default resource requirements to use for the model.
6165
@@ -79,4 +83,5 @@ def retrieve_default(
7983
tolerate_vulnerable_model,
8084
tolerate_deprecated_model,
8185
sagemaker_session=sagemaker_session,
86+
instance_type=instance_type,
8287
)

tests/unit/sagemaker/config/test_config.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,23 @@
3434
@pytest.fixture()
3535
def config_file_as_yaml(get_data_dir):
3636
config_file_path = os.path.join(get_data_dir, "config.yaml")
37-
return open(config_file_path, "r").read()
37+
with open(config_file_path, "r") as f:
38+
content = f.read()
39+
return content
3840

3941

4042
@pytest.fixture()
4143
def expected_merged_config(get_data_dir):
4244
expected_merged_config_file_path = os.path.join(
4345
get_data_dir, "expected_output_config_after_merge.yaml"
4446
)
45-
return yaml.safe_load(open(expected_merged_config_file_path, "r").read())
47+
with open(expected_merged_config_file_path, "r") as f:
48+
content = yaml.safe_load(f.read())
49+
return content
50+
51+
52+
def _raise_valueerror(*args):
53+
raise ValueError(args)
4654

4755

4856
def test_config_when_default_config_file_and_user_config_file_is_not_found():
@@ -60,7 +68,8 @@ def test_config_when_overriden_default_config_file_is_not_found(get_data_dir):
6068
def test_invalid_config_file_which_has_python_code(get_data_dir):
6169
invalid_config_file_path = os.path.join(get_data_dir, "config_file_with_code.yaml")
6270
# no exceptions will be thrown with yaml.unsafe_load
63-
yaml.unsafe_load(open(invalid_config_file_path, "r"))
71+
with open(invalid_config_file_path, "r") as f:
72+
yaml.unsafe_load(f)
6473
# PyYAML will throw exceptions for yaml.safe_load. SageMaker Config is using
6574
# yaml.safe_load internally
6675
with pytest.raises(ConstructorError) as exception_info:
@@ -228,7 +237,8 @@ def test_merge_of_s3_default_config_file_and_regular_config_file(
228237
get_data_dir, expected_merged_config, s3_resource_mock
229238
):
230239
config_file_content_path = os.path.join(get_data_dir, "sample_config_for_merge.yaml")
231-
config_file_as_yaml = open(config_file_content_path, "r").read()
240+
with open(config_file_content_path, "r") as f:
241+
config_file_as_yaml = f.read()
232242
config_file_bucket = "config-file-bucket"
233243
config_file_s3_prefix = "config/config.yaml"
234244
config_file_s3_uri = "s3://{}/{}".format(config_file_bucket, config_file_s3_prefix)
@@ -440,8 +450,11 @@ def test_load_local_mode_config(mock_load_config):
440450
mock_load_config.assert_called_with(_DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH)
441451

442452

443-
def test_load_local_mode_config_when_config_file_is_not_found():
453+
@patch("sagemaker.config.config._load_config_from_file", side_effect=_raise_valueerror)
454+
def test_load_local_mode_config_when_config_file_is_not_found(mock_load_config):
455+
# Patch is needed because one might actually have a local config file
444456
assert load_local_mode_config() is None
457+
mock_load_config.assert_called_with(_DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH)
445458

446459

447460
@pytest.mark.parametrize(

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,8 +840,22 @@
840840
"model_package_arn": "$gpu_model_package_arn",
841841
}
842842
},
843+
"g5": {
844+
"properties": {
845+
"resource_requirements": {
846+
"num_accelerators": 888810,
847+
"randon-field-2": 2222,
848+
}
849+
}
850+
},
843851
"m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
844852
"c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
853+
"ml.g5.xlarge": {
854+
"properties": {
855+
"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"},
856+
"resource_requirements": {"num_accelerators": 10},
857+
}
858+
},
845859
"ml.g5.48xlarge": {
846860
"properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}}
847861
},
@@ -857,6 +871,12 @@
857871
"framework_version": "1.5.0",
858872
"py_version": "py3",
859873
},
874+
"dynamic_container_deployment_supported": True,
875+
"hosting_resource_requirements": {
876+
"min_memory_mb": 81999,
877+
"num_accelerators": 1,
878+
"random_field_1": 1,
879+
},
860880
"hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz",
861881
"training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz",
862882
"hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz",

tests/unit/sagemaker/jumpstart/test_types.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
"variants": {
3535
"ml.p2.12xlarge": {
3636
"properties": {
37+
"resource_requirements": {"req1": 1, "req2": {"1": 2, "2": 3}, "req3": 9},
3738
"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"},
3839
"supported_inference_instance_types": ["ml.p5.xlarge"],
3940
"default_inference_instance_type": "ml.p5.xlarge",
@@ -60,6 +61,11 @@
6061
"p2": {
6162
"regional_properties": {"image_uri": "$gpu_image_uri"},
6263
"properties": {
64+
"resource_requirements": {
65+
"req2": {"2": 5, "9": 999},
66+
"req3": 999,
67+
"req4": "blah",
68+
},
6369
"supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.xlarge"],
6470
"default_inference_instance_type": "ml.p2.xlarge",
6571
"metrics": [
@@ -879,3 +885,20 @@ def test_jumpstart_training_artifact_key_instance_variants():
879885
)
880886
is None
881887
)
888+
889+
890+
def test_jumpstart_resource_requirements_instance_variants():
891+
assert INSTANCE_TYPE_VARIANT.get_instance_specific_resource_requirements(
892+
instance_type="ml.p2.xlarge"
893+
) == {"req2": {"2": 5, "9": 999}, "req3": 999, "req4": "blah"}
894+
895+
assert INSTANCE_TYPE_VARIANT.get_instance_specific_resource_requirements(
896+
instance_type="ml.p2.12xlarge"
897+
) == {"req1": 1, "req2": {"1": 2, "2": 3}, "req3": 9, "req4": "blah"}
898+
899+
assert (
900+
INSTANCE_TYPE_VARIANT.get_instance_specific_resource_requirements(
901+
instance_type="ml.p99.12xlarge"
902+
)
903+
== {}
904+
)

0 commit comments

Comments
 (0)