From c5af85757c2c620340feb856b28fa04ede670a9c Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Tue, 30 Jan 2024 15:39:57 +0000 Subject: [PATCH 1/4] feat: instance specific jumpstart host requirements --- .../artifacts/resource_requirements.py | 23 ++++++++- src/sagemaker/jumpstart/factory/model.py | 1 + src/sagemaker/jumpstart/types.py | 23 +++++++++ src/sagemaker/resource_requirements.py | 4 ++ tests/unit/sagemaker/jumpstart/constants.py | 20 ++++++++ tests/unit/sagemaker/jumpstart/test_types.py | 23 +++++++++ .../jumpstart/test_resource_requirements.py | 49 +++++++++++++++++++ 7 files changed, 141 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/jumpstart/artifacts/resource_requirements.py b/src/sagemaker/jumpstart/artifacts/resource_requirements.py index 8356d1efac..08ff1b4606 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_requirements.py +++ b/src/sagemaker/jumpstart/artifacts/resource_requirements.py @@ -13,7 +13,7 @@ """This module contains functions for obtaining JumpStart resoure requirements.""" from __future__ import absolute_import -from typing import Optional +from typing import Dict, Optional from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -37,6 +37,7 @@ def _retrieve_default_resources( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + instance_type: Optional[str] = None, ) -> ResourceRequirements: """Retrieves the default resource requirements for the model. @@ -60,6 +61,8 @@ def _retrieve_default_resources( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + instance_type (str): An instance type to optionally supply in order to get + host requirements specific for the instance type. Returns: str: The default resource requirements to use for the model or None. @@ -87,12 +90,28 @@ def _retrieve_default_resources( is_dynamic_container_deployment_supported = ( model_specs.dynamic_container_deployment_supported ) - default_resource_requirements = model_specs.hosting_resource_requirements + default_resource_requirements: Dict[str, int] = ( + model_specs.hosting_resource_requirements or {} + ) else: raise NotImplementedError( f"Unsupported script scope for retrieving default resource requirements: '{scope}'" ) + instance_specific_resource_requirements: Dict[str, int] = ( + model_specs.hosting_instance_type_variants.get_instance_specific_resource_requirements( + instance_type + ) + if instance_type + and getattr(model_specs, "hosting_instance_type_variants", None) is not None + else {} + ) + + default_resource_requirements = { + **default_resource_requirements, + **instance_specific_resource_requirements, + } + if is_dynamic_container_deployment_supported: requests = {} if "num_accelerators" in default_resource_requirements: diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 64e4727116..1b41cad714 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -481,6 +481,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + instance_type=kwargs.instance_type, ) return kwargs diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 49d3e295c5..810d1c4cd3 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -478,6 +478,29 @@ def get_instance_specific_artifact_key(self, instance_type: str) -> Optional[str instance_type=instance_type, property_name="artifact_key" ) + def get_instance_specific_resource_requirements(self, instance_type: str) -> Optional[str]: + """Returns instance specific resource requirements. + + If a value exists for both the instance family and instance type, the instance type value + is chosen. + """ + + instance_specific_resource_requirements: dict = ( + self.variants.get(instance_type, {}) + .get("properties", {}) + .get("resource_requirements", {}) + ) + + instance_type_family = get_instance_type_family(instance_type) + + instance_family_resource_requirements: dict = ( + self.variants.get(instance_type_family, {}) + .get("properties", {}) + .get("resource_requirements", {}) + ) + + return {**instance_family_resource_requirements, **instance_specific_resource_requirements} + def _get_instance_specific_property( self, instance_type: str, property_name: str ) -> Optional[str]: diff --git a/src/sagemaker/resource_requirements.py b/src/sagemaker/resource_requirements.py index 446d034bf3..ff16714b4e 100644 --- a/src/sagemaker/resource_requirements.py +++ b/src/sagemaker/resource_requirements.py @@ -33,6 +33,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + instance_type: Optional[str] = None, ) -> str: """Retrieves the default resource requirements for the model matching the given arguments. @@ -56,6 +57,8 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + instance_type (str): An instance type to optionally supply in order to get + host requirements specific for the instance type. Returns: str: The default resource requirements to use for the model. @@ -79,4 +82,5 @@ def retrieve_default( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + instance_type=instance_type, ) diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index a3c4c747f7..605253466a 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -840,8 +840,22 @@ "model_package_arn": "$gpu_model_package_arn", } }, + "g5": { + "properties": { + "resource_requirements": { + "num_accelerators": 888810, + "randon-field-2": 2222, + } + } + }, "m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, "c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, + "ml.g5.xlarge": { + "properties": { + "environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}, + "resource_requirements": {"num_accelerators": 10}, + } + }, "ml.g5.48xlarge": { "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}} }, @@ -857,6 +871,12 @@ "framework_version": "1.5.0", "py_version": "py3", }, + "dynamic_container_deployment_supported": True, + "hosting_resource_requirements": { + "min_memory_mb": 81999, + "num_accelerators": 1, + "random_field_1": 1, + }, "hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 82e69e1d89..a9daad934d 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -34,6 +34,7 @@ "variants": { "ml.p2.12xlarge": { "properties": { + "resource_requirements": {"req1": 1, "req2": {"1": 2, "2": 3}, "req3": 9}, "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, "supported_inference_instance_types": ["ml.p5.xlarge"], "default_inference_instance_type": "ml.p5.xlarge", @@ -60,6 +61,11 @@ "p2": { "regional_properties": {"image_uri": "$gpu_image_uri"}, "properties": { + "resource_requirements": { + "req2": {"2": 5, "9": 999}, + "req3": 999, + "req4": "blah", + }, "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.xlarge"], "default_inference_instance_type": "ml.p2.xlarge", "metrics": [ @@ -879,3 +885,20 @@ def test_jumpstart_training_artifact_key_instance_variants(): ) is None ) + + +def test_jumpstart_resource_requirements_instance_variants(): + assert INSTANCE_TYPE_VARIANT.get_instance_specific_resource_requirements( + instance_type="ml.p2.xlarge" + ) == {"req2": {"2": 5, "9": 999}, "req3": 999, "req4": "blah"} + + assert INSTANCE_TYPE_VARIANT.get_instance_specific_resource_requirements( + instance_type="ml.p2.12xlarge" + ) == {"req1": 1, "req2": {"1": 2, "2": 3}, "req3": 9, "req4": "blah"} + + assert ( + INSTANCE_TYPE_VARIANT.get_instance_specific_resource_requirements( + instance_type="ml.p99.12xlarge" + ) + == {} + ) diff --git a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py index 28b53270f8..aa9a3dc729 100644 --- a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py +++ b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py @@ -50,6 +50,55 @@ def test_jumpstart_resource_requirements(patched_get_model_specs): patched_get_model_specs.reset_mock() +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_resource_requirements_instance_type_variants(patched_get_model_specs): + + patched_get_model_specs.side_effect = get_special_model_spec + region = "us-west-2" + mock_client = boto3.client("s3") + mock_session = Mock(s3_client=mock_client) + + model_id, model_version = "variant-model", "*" + default_inference_resource_requirements = resource_requirements.retrieve_default( + region=region, + model_id=model_id, + model_version=model_version, + scope="inference", + sagemaker_session=mock_session, + instance_type="ml.g5.xlarge", + ) + assert default_inference_resource_requirements.requests == { + "memory": 81999, + "num_accelerators": 10, + } + + default_inference_resource_requirements = resource_requirements.retrieve_default( + region=region, + model_id=model_id, + model_version=model_version, + scope="inference", + sagemaker_session=mock_session, + instance_type="ml.g5.555xlarge", + ) + assert default_inference_resource_requirements.requests == { + "memory": 81999, + "num_accelerators": 888810, + } + + default_inference_resource_requirements = resource_requirements.retrieve_default( + region=region, + model_id=model_id, + model_version=model_version, + scope="inference", + sagemaker_session=mock_session, + instance_type="ml.f9.555xlarge", + ) + assert default_inference_resource_requirements.requests == { + "memory": 81999, + "num_accelerators": 1, + } + + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs): patched_get_model_specs.side_effect = get_special_model_spec From c96649eda585dcd41c4f543879d1c30c7da197b3 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Tue, 30 Jan 2024 17:41:53 +0000 Subject: [PATCH 2/4] chore: add js support for copies resource requirement, enforce coupling with ResourceRequirements class --- .../artifacts/resource_requirements.py | 42 +++++++++++++------ src/sagemaker/resource_requirements.py | 3 +- .../jumpstart/test_resource_requirements.py | 19 +++++++++ 3 files changed, 51 insertions(+), 13 deletions(-) diff --git a/src/sagemaker/jumpstart/artifacts/resource_requirements.py b/src/sagemaker/jumpstart/artifacts/resource_requirements.py index 08ff1b4606..740f5c0baf 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_requirements.py +++ b/src/sagemaker/jumpstart/artifacts/resource_requirements.py @@ -28,6 +28,18 @@ from sagemaker.session import Session from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements +REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP: Dict[str, Dict[str, str]] = { + "requests": { + "num_accelerators": ("num_accelerators", "num_accelerators"), + "num_cpus": ("num_cpus", "num_cpus"), + "copies": ("copies", "copy_count"), + "min_memory_mb": ("memory", "min_memory"), + }, + "limits": { + "max_memory_mb": ("memory", "max_memory"), + }, +} + def _retrieve_default_resources( model_id: str, @@ -113,16 +125,22 @@ def _retrieve_default_resources( } if is_dynamic_container_deployment_supported: - requests = {} - if "num_accelerators" in default_resource_requirements: - requests["num_accelerators"] = default_resource_requirements["num_accelerators"] - if "min_memory_mb" in default_resource_requirements: - requests["memory"] = default_resource_requirements["min_memory_mb"] - if "num_cpus" in default_resource_requirements: - requests["num_cpus"] = default_resource_requirements["num_cpus"] - - limits = {} - if "max_memory_mb" in default_resource_requirements: - limits["memory"] = default_resource_requirements["max_memory_mb"] - return ResourceRequirements(requests=requests, limits=limits) + + all_resource_requirement_kwargs = {} + + for ( + requirement_type, + spec_field_to_resource_requirement_map, + ) in REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP.items(): + requirement_type + requirement_kwargs = {} + for spec_field, resource_requirement in spec_field_to_resource_requirement_map.items(): + if spec_field in default_resource_requirements: + requirement_kwargs[resource_requirement[0]] = default_resource_requirements[ + spec_field + ] + + all_resource_requirement_kwargs[requirement_type] = requirement_kwargs + + return ResourceRequirements(**all_resource_requirement_kwargs) return None diff --git a/src/sagemaker/resource_requirements.py b/src/sagemaker/resource_requirements.py index ff16714b4e..93b2833a35 100644 --- a/src/sagemaker/resource_requirements.py +++ b/src/sagemaker/resource_requirements.py @@ -16,6 +16,7 @@ import logging from typing import Optional +from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts @@ -34,7 +35,7 @@ def retrieve_default( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, -) -> str: +) -> ResourceRequirements: """Retrieves the default resource requirements for the model matching the given arguments. Args: diff --git a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py index aa9a3dc729..b0cef0e3d4 100644 --- a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py +++ b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py @@ -18,6 +18,10 @@ import pytest from sagemaker import resource_requirements +from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements +from sagemaker.jumpstart.artifacts.resource_requirements import ( + REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP, +) from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec @@ -129,3 +133,18 @@ def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs): resource_requirements.retrieve_default( region=region, model_id=model_id, model_version=model_version, scope="training" ) + + +def test_jumpstart_supports_all_resource_requirement_fields(): + + all_tracked_resource_requirement_fields = { + field + for requirements in REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP.values() + for _, field in requirements.values() + } + + excluded_resource_requirement_fields = {"requests", "limits"} + assert ( + set(ResourceRequirements().__dict__.keys()) - excluded_resource_requirement_fields + == all_tracked_resource_requirement_fields + ) From 8a949e23a574cfc62a040b4178b4c60acd3f21cb Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Tue, 30 Jan 2024 17:42:50 +0000 Subject: [PATCH 3/4] fix: typing --- src/sagemaker/jumpstart/artifacts/resource_requirements.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/jumpstart/artifacts/resource_requirements.py b/src/sagemaker/jumpstart/artifacts/resource_requirements.py index 740f5c0baf..6283b6bcd1 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_requirements.py +++ b/src/sagemaker/jumpstart/artifacts/resource_requirements.py @@ -13,7 +13,7 @@ """This module contains functions for obtaining JumpStart resoure requirements.""" from __future__ import absolute_import -from typing import Dict, Optional +from typing import Dict, Optional, Tuple from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -28,7 +28,9 @@ from sagemaker.session import Session from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements -REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP: Dict[str, Dict[str, str]] = { +REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP: Dict[ + str, Dict[str, Tuple[str, str]] +] = { "requests": { "num_accelerators": ("num_accelerators", "num_accelerators"), "num_cpus": ("num_cpus", "num_cpus"), From 39d3fa628c2bc1a3655e430b872b895ce519d016 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Tue, 30 Jan 2024 18:23:26 +0000 Subject: [PATCH 4/4] fix: pylint --- src/sagemaker/jumpstart/artifacts/resource_requirements.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sagemaker/jumpstart/artifacts/resource_requirements.py b/src/sagemaker/jumpstart/artifacts/resource_requirements.py index 6283b6bcd1..6ee4f31c56 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_requirements.py +++ b/src/sagemaker/jumpstart/artifacts/resource_requirements.py @@ -134,7 +134,6 @@ def _retrieve_default_resources( requirement_type, spec_field_to_resource_requirement_map, ) in REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP.items(): - requirement_type requirement_kwargs = {} for spec_field, resource_requirement in spec_field_to_resource_requirement_map.items(): if spec_field in default_resource_requirements: