Skip to content

Commit de2a0e7

Browse files
authored
Merge branch 'master' into 2PR
2 parents 8abf8a9 + 790bd87 commit de2a0e7

File tree

15 files changed

+240
-35
lines changed

15 files changed

+240
-35
lines changed

doc/api/prep_data/feature_store.rst

+16-2
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ Feature Definition
6060
:members:
6161
:show-inheritance:
6262

63+
6364
Inputs
6465
******
6566

@@ -181,9 +182,13 @@ Feature Processor Data Source
181182
:members:
182183
:show-inheritance:
183184

185+
.. autoclass:: sagemaker.feature_store.feature_processor.PySparkDataSource
186+
:members:
187+
:show-inheritance:
184188

185-
Feature Processor Scheduler
186-
***************************
189+
190+
Feature Processor Scheduler and Triggers
191+
****************************************
187192

188193
.. automethod:: sagemaker.feature_store.feature_processor.to_pipeline
189194

@@ -196,3 +201,12 @@ Feature Processor Scheduler
196201
.. automethod:: sagemaker.feature_store.feature_processor.describe
197202

198203
.. automethod:: sagemaker.feature_store.feature_processor.list_pipelines
204+
205+
.. automethod:: sagemaker.feature_store.feature_processor.put_trigger
206+
207+
.. automethod:: sagemaker.feature_store.feature_processor.enable_trigger
208+
209+
.. automethod:: sagemaker.feature_store.feature_processor.disable_trigger
210+
211+
.. automethod:: sagemaker.feature_store.feature_processor.delete_trigger
212+

requirements/extras/test_requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ awslogs==0.14.0
1212
black==22.3.0
1313
stopit==1.1.2
1414
# Update tox.ini to have correct version of airflow constraints file
15-
apache-airflow==2.8.1
15+
apache-airflow==2.8.2
1616
apache-airflow-providers-amazon==7.2.1
1717
attrs>=23.1.0,<24
1818
fabric==2.6.0

src/sagemaker/config/config.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,9 @@ def _load_config_from_file(file_path: str) -> dict:
181181
f"Provide a valid file path"
182182
)
183183
logger.debug("Fetching defaults config from location: %s", file_path)
184-
return yaml.safe_load(open(inferred_file_path, "r"))
184+
with open(inferred_file_path, "r") as f:
185+
content = yaml.safe_load(f)
186+
return content
185187

186188

187189
def _load_config_from_s3(s3_uri, s3_resource_for_config) -> dict:

src/sagemaker/jumpstart/artifacts/resource_requirements.py

+52-14
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""This module contains functions for obtaining JumpStart resoure requirements."""
1414
from __future__ import absolute_import
1515

16-
from typing import Optional
16+
from typing import Dict, Optional, Tuple
1717

1818
from sagemaker.jumpstart.constants import (
1919
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -28,6 +28,20 @@
2828
from sagemaker.session import Session
2929
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
3030

31+
REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP: Dict[
32+
str, Dict[str, Tuple[str, str]]
33+
] = {
34+
"requests": {
35+
"num_accelerators": ("num_accelerators", "num_accelerators"),
36+
"num_cpus": ("num_cpus", "num_cpus"),
37+
"copies": ("copies", "copy_count"),
38+
"min_memory_mb": ("memory", "min_memory"),
39+
},
40+
"limits": {
41+
"max_memory_mb": ("memory", "max_memory"),
42+
},
43+
}
44+
3145

3246
def _retrieve_default_resources(
3347
model_id: str,
@@ -37,6 +51,7 @@ def _retrieve_default_resources(
3751
tolerate_vulnerable_model: bool = False,
3852
tolerate_deprecated_model: bool = False,
3953
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
54+
instance_type: Optional[str] = None,
4055
) -> ResourceRequirements:
4156
"""Retrieves the default resource requirements for the model.
4257
@@ -60,6 +75,8 @@ def _retrieve_default_resources(
6075
object, used for SageMaker interactions. If not
6176
specified, one is created using the default AWS configuration
6277
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
78+
instance_type (str): An instance type to optionally supply in order to get
79+
host requirements specific for the instance type.
6380
Returns:
6481
str: The default resource requirements to use for the model or None.
6582
@@ -87,23 +104,44 @@ def _retrieve_default_resources(
87104
is_dynamic_container_deployment_supported = (
88105
model_specs.dynamic_container_deployment_supported
89106
)
90-
default_resource_requirements = model_specs.hosting_resource_requirements
107+
default_resource_requirements: Dict[str, int] = (
108+
model_specs.hosting_resource_requirements or {}
109+
)
91110
else:
92111
raise NotImplementedError(
93112
f"Unsupported script scope for retrieving default resource requirements: '{scope}'"
94113
)
95114

115+
instance_specific_resource_requirements: Dict[str, int] = (
116+
model_specs.hosting_instance_type_variants.get_instance_specific_resource_requirements(
117+
instance_type
118+
)
119+
if instance_type
120+
and getattr(model_specs, "hosting_instance_type_variants", None) is not None
121+
else {}
122+
)
123+
124+
default_resource_requirements = {
125+
**default_resource_requirements,
126+
**instance_specific_resource_requirements,
127+
}
128+
96129
if is_dynamic_container_deployment_supported:
97-
requests = {}
98-
if "num_accelerators" in default_resource_requirements:
99-
requests["num_accelerators"] = default_resource_requirements["num_accelerators"]
100-
if "min_memory_mb" in default_resource_requirements:
101-
requests["memory"] = default_resource_requirements["min_memory_mb"]
102-
if "num_cpus" in default_resource_requirements:
103-
requests["num_cpus"] = default_resource_requirements["num_cpus"]
104-
105-
limits = {}
106-
if "max_memory_mb" in default_resource_requirements:
107-
limits["memory"] = default_resource_requirements["max_memory_mb"]
108-
return ResourceRequirements(requests=requests, limits=limits)
130+
131+
all_resource_requirement_kwargs = {}
132+
133+
for (
134+
requirement_type,
135+
spec_field_to_resource_requirement_map,
136+
) in REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP.items():
137+
requirement_kwargs = {}
138+
for spec_field, resource_requirement in spec_field_to_resource_requirement_map.items():
139+
if spec_field in default_resource_requirements:
140+
requirement_kwargs[resource_requirement[0]] = default_resource_requirements[
141+
spec_field
142+
]
143+
144+
all_resource_requirement_kwargs[requirement_type] = requirement_kwargs
145+
146+
return ResourceRequirements(**all_resource_requirement_kwargs)
109147
return None

src/sagemaker/jumpstart/factory/model.py

+1
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
481481
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
482482
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
483483
sagemaker_session=kwargs.sagemaker_session,
484+
instance_type=kwargs.instance_type,
484485
)
485486

486487
return kwargs

src/sagemaker/jumpstart/types.py

+23
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,29 @@ def get_instance_specific_artifact_key(self, instance_type: str) -> Optional[str
478478
instance_type=instance_type, property_name="artifact_key"
479479
)
480480

481+
def get_instance_specific_resource_requirements(self, instance_type: str) -> Optional[str]:
482+
"""Returns instance specific resource requirements.
483+
484+
If a value exists for both the instance family and instance type, the instance type value
485+
is chosen.
486+
"""
487+
488+
instance_specific_resource_requirements: dict = (
489+
self.variants.get(instance_type, {})
490+
.get("properties", {})
491+
.get("resource_requirements", {})
492+
)
493+
494+
instance_type_family = get_instance_type_family(instance_type)
495+
496+
instance_family_resource_requirements: dict = (
497+
self.variants.get(instance_type_family, {})
498+
.get("properties", {})
499+
.get("resource_requirements", {})
500+
)
501+
502+
return {**instance_family_resource_requirements, **instance_specific_resource_requirements}
503+
481504
def _get_instance_specific_property(
482505
self, instance_type: str, property_name: str
483506
) -> Optional[str]:

src/sagemaker/local/image.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -860,7 +860,9 @@ def _create_docker_host(
860860
# to setting --runtime=nvidia in the docker commandline.
861861
if self.instance_type == "local_gpu":
862862
host_config["deploy"] = {
863-
"resources": {"reservations": {"devices": [{"capabilities": ["gpu"]}]}}
863+
"resources": {
864+
"reservations": {"devices": [{"count": "all", "capabilities": ["gpu"]}]}
865+
}
864866
}
865867

866868
if not self.is_studio and command == "serve":

src/sagemaker/remote_function/client.py

-5
Original file line numberDiff line numberDiff line change
@@ -694,11 +694,6 @@ def __init__(
694694
encrypt_inter_container_traffic (bool): A flag that specifies whether traffic between
695695
training containers is encrypted for the training job. Defaults to ``False``.
696696
697-
enable_network_isolation (bool): A flag that specifies whether container will run in
698-
network isolation mode. Defaults to ``False``. Network isolation mode restricts the
699-
container access to outside networks (such as the Internet). The container does not
700-
make any inbound or outbound network calls. Also known as Internet-free mode.
701-
702697
spark_config (SparkConfig): Configurations to the Spark application that runs on
703698
Spark image. If ``spark_config`` is specified, a SageMaker Spark image uri
704699
will be used for training. Note that ``image_uri`` can not be specified at the

src/sagemaker/resource_requirements.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import logging
1818
from typing import Optional
19+
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
1920

2021
from sagemaker.jumpstart import utils as jumpstart_utils
2122
from sagemaker.jumpstart import artifacts
@@ -33,7 +34,8 @@ def retrieve_default(
3334
tolerate_vulnerable_model: bool = False,
3435
tolerate_deprecated_model: bool = False,
3536
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
36-
) -> str:
37+
instance_type: Optional[str] = None,
38+
) -> ResourceRequirements:
3739
"""Retrieves the default resource requirements for the model matching the given arguments.
3840
3941
Args:
@@ -56,6 +58,8 @@ def retrieve_default(
5658
object, used for SageMaker interactions. If not
5759
specified, one is created using the default AWS configuration
5860
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
61+
instance_type (str): An instance type to optionally supply in order to get
62+
host requirements specific for the instance type.
5963
Returns:
6064
str: The default resource requirements to use for the model.
6165
@@ -79,4 +83,5 @@ def retrieve_default(
7983
tolerate_vulnerable_model,
8084
tolerate_deprecated_model,
8185
sagemaker_session=sagemaker_session,
86+
instance_type=instance_type,
8287
)

tests/unit/sagemaker/config/test_config.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,23 @@
3434
@pytest.fixture()
3535
def config_file_as_yaml(get_data_dir):
3636
config_file_path = os.path.join(get_data_dir, "config.yaml")
37-
return open(config_file_path, "r").read()
37+
with open(config_file_path, "r") as f:
38+
content = f.read()
39+
return content
3840

3941

4042
@pytest.fixture()
4143
def expected_merged_config(get_data_dir):
4244
expected_merged_config_file_path = os.path.join(
4345
get_data_dir, "expected_output_config_after_merge.yaml"
4446
)
45-
return yaml.safe_load(open(expected_merged_config_file_path, "r").read())
47+
with open(expected_merged_config_file_path, "r") as f:
48+
content = yaml.safe_load(f.read())
49+
return content
50+
51+
52+
def _raise_valueerror(*args):
53+
raise ValueError(args)
4654

4755

4856
def test_config_when_default_config_file_and_user_config_file_is_not_found():
@@ -60,7 +68,8 @@ def test_config_when_overriden_default_config_file_is_not_found(get_data_dir):
6068
def test_invalid_config_file_which_has_python_code(get_data_dir):
6169
invalid_config_file_path = os.path.join(get_data_dir, "config_file_with_code.yaml")
6270
# no exceptions will be thrown with yaml.unsafe_load
63-
yaml.unsafe_load(open(invalid_config_file_path, "r"))
71+
with open(invalid_config_file_path, "r") as f:
72+
yaml.unsafe_load(f)
6473
# PyYAML will throw exceptions for yaml.safe_load. SageMaker Config is using
6574
# yaml.safe_load internally
6675
with pytest.raises(ConstructorError) as exception_info:
@@ -228,7 +237,8 @@ def test_merge_of_s3_default_config_file_and_regular_config_file(
228237
get_data_dir, expected_merged_config, s3_resource_mock
229238
):
230239
config_file_content_path = os.path.join(get_data_dir, "sample_config_for_merge.yaml")
231-
config_file_as_yaml = open(config_file_content_path, "r").read()
240+
with open(config_file_content_path, "r") as f:
241+
config_file_as_yaml = f.read()
232242
config_file_bucket = "config-file-bucket"
233243
config_file_s3_prefix = "config/config.yaml"
234244
config_file_s3_uri = "s3://{}/{}".format(config_file_bucket, config_file_s3_prefix)
@@ -440,8 +450,11 @@ def test_load_local_mode_config(mock_load_config):
440450
mock_load_config.assert_called_with(_DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH)
441451

442452

443-
def test_load_local_mode_config_when_config_file_is_not_found():
453+
@patch("sagemaker.config.config._load_config_from_file", side_effect=_raise_valueerror)
454+
def test_load_local_mode_config_when_config_file_is_not_found(mock_load_config):
455+
# Patch is needed because one might actually have a local config file
444456
assert load_local_mode_config() is None
457+
mock_load_config.assert_called_with(_DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH)
445458

446459

447460
@pytest.mark.parametrize(

tests/unit/sagemaker/jumpstart/constants.py

+20
Original file line numberDiff line numberDiff line change
@@ -840,8 +840,22 @@
840840
"model_package_arn": "$gpu_model_package_arn",
841841
}
842842
},
843+
"g5": {
844+
"properties": {
845+
"resource_requirements": {
846+
"num_accelerators": 888810,
847+
"randon-field-2": 2222,
848+
}
849+
}
850+
},
843851
"m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
844852
"c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
853+
"ml.g5.xlarge": {
854+
"properties": {
855+
"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"},
856+
"resource_requirements": {"num_accelerators": 10},
857+
}
858+
},
845859
"ml.g5.48xlarge": {
846860
"properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}}
847861
},
@@ -857,6 +871,12 @@
857871
"framework_version": "1.5.0",
858872
"py_version": "py3",
859873
},
874+
"dynamic_container_deployment_supported": True,
875+
"hosting_resource_requirements": {
876+
"min_memory_mb": 81999,
877+
"num_accelerators": 1,
878+
"random_field_1": 1,
879+
},
860880
"hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz",
861881
"training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz",
862882
"hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz",

tests/unit/sagemaker/jumpstart/test_types.py

+23
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
"variants": {
3535
"ml.p2.12xlarge": {
3636
"properties": {
37+
"resource_requirements": {"req1": 1, "req2": {"1": 2, "2": 3}, "req3": 9},
3738
"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"},
3839
"supported_inference_instance_types": ["ml.p5.xlarge"],
3940
"default_inference_instance_type": "ml.p5.xlarge",
@@ -60,6 +61,11 @@
6061
"p2": {
6162
"regional_properties": {"image_uri": "$gpu_image_uri"},
6263
"properties": {
64+
"resource_requirements": {
65+
"req2": {"2": 5, "9": 999},
66+
"req3": 999,
67+
"req4": "blah",
68+
},
6369
"supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.xlarge"],
6470
"default_inference_instance_type": "ml.p2.xlarge",
6571
"metrics": [
@@ -879,3 +885,20 @@ def test_jumpstart_training_artifact_key_instance_variants():
879885
)
880886
is None
881887
)
888+
889+
890+
def test_jumpstart_resource_requirements_instance_variants():
891+
assert INSTANCE_TYPE_VARIANT.get_instance_specific_resource_requirements(
892+
instance_type="ml.p2.xlarge"
893+
) == {"req2": {"2": 5, "9": 999}, "req3": 999, "req4": "blah"}
894+
895+
assert INSTANCE_TYPE_VARIANT.get_instance_specific_resource_requirements(
896+
instance_type="ml.p2.12xlarge"
897+
) == {"req1": 1, "req2": {"1": 2, "2": 3}, "req3": 9, "req4": "blah"}
898+
899+
assert (
900+
INSTANCE_TYPE_VARIANT.get_instance_specific_resource_requirements(
901+
instance_type="ml.p99.12xlarge"
902+
)
903+
== {}
904+
)

0 commit comments

Comments
 (0)