From c4ac4803b26fd200597a2ec972d9682f697c8c25 Mon Sep 17 00:00:00 2001 From: Haotian An <33510317+Captainia@users.noreply.github.com> Date: Tue, 23 Apr 2024 12:03:54 -0400 Subject: [PATCH 01/32] fix: mainline alt config parsing (#4602) * fix: parsing * fix: commit tests * fix: types * updated * fix --- src/sagemaker/jumpstart/types.py | 72 +++++++++++----- tests/unit/sagemaker/jumpstart/constants.py | 32 ++++--- tests/unit/sagemaker/jumpstart/test_types.py | 90 ++++++++++++++++---- tests/unit/sagemaker/jumpstart/test_utils.py | 60 ++++++------- 4 files changed, 174 insertions(+), 80 deletions(-) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 8e53bf6f83..05c6a00961 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -744,12 +744,12 @@ def _get_regional_property( class JumpStartBenchmarkStat(JumpStartDataHolderType): - """Data class JumpStart benchmark stats.""" + """Data class JumpStart benchmark stat.""" __slots__ = ["name", "value", "unit"] def __init__(self, spec: Dict[str, Any]): - """Initializes a JumpStartBenchmarkStat object + """Initializes a JumpStartBenchmarkStat object. Args: spec (Dict[str, Any]): Dictionary representation of benchmark stat. @@ -858,7 +858,7 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType): "model_subscription_link", ] - def __init__(self, fields: Optional[Dict[str, Any]]): + def __init__(self, fields: Dict[str, Any]): """Initializes a JumpStartMetadataFields object. Args: @@ -877,7 +877,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.version: str = json_obj.get("version") self.min_sdk_version: str = json_obj.get("min_sdk_version") self.incremental_training_supported: bool = bool( - json_obj.get("incremental_training_supported") + json_obj.get("incremental_training_supported", False) ) self.hosting_ecr_specs: Optional[JumpStartECRSpecs] = ( JumpStartECRSpecs(json_obj["hosting_ecr_specs"]) @@ -1038,7 +1038,7 @@ class JumpStartConfigComponent(JumpStartMetadataBaseFields): __slots__ = slots + JumpStartMetadataBaseFields.__slots__ - def __init__( # pylint: disable=super-init-not-called + def __init__( self, component_name: str, component: Optional[Dict[str, Any]], @@ -1049,7 +1049,10 @@ def __init__( # pylint: disable=super-init-not-called component_name (str): Name of the component. component (Dict[str, Any]): Dictionary representation of the config component. + Raises: + ValueError: If the component field is invalid. """ + super().__init__(component) self.component_name = component_name self.from_json(component) @@ -1080,7 +1083,7 @@ def __init__( self, base_fields: Dict[str, Any], config_components: Dict[str, JumpStartConfigComponent], - benchmark_metrics: Dict[str, JumpStartBenchmarkStat], + benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]], ): """Initializes a JumpStartMetadataConfig object from its json representation. @@ -1089,12 +1092,12 @@ def __init__( The default base fields that are used to construct the final resolved config. config_components (Dict[str, JumpStartConfigComponent]): The list of components that are used to construct the resolved config. - benchmark_metrics (Dict[str, JumpStartBenchmarkStat]): + benchmark_metrics (Dict[str, List[JumpStartBenchmarkStat]]): The dictionary of benchmark metrics with name being the key. 
""" self.base_fields = base_fields self.config_components: Dict[str, JumpStartConfigComponent] = config_components - self.benchmark_metrics: Dict[str, JumpStartBenchmarkStat] = benchmark_metrics + self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = benchmark_metrics self.resolved_metadata_config: Optional[Dict[str, Any]] = None def to_json(self) -> Dict[str, Any]: @@ -1104,7 +1107,7 @@ def to_json(self) -> Dict[str, Any]: @property def resolved_config(self) -> Dict[str, Any]: - """Returns the final config that is resolved from the list of components. + """Returns the final config that is resolved from the components map. Construct the final config by applying the list of configs from list index, and apply to the base default fields in the current model specs. @@ -1139,7 +1142,7 @@ def __init__( Args: configs (Dict[str, JumpStartMetadataConfig]): - List of configs that the current model has. + The map of JumpStartMetadataConfig object, with config name being the key. config_rankings (JumpStartConfigRanking): Config ranking class represents the ranking of the configs in the model. scope (JumpStartScriptScope): @@ -1158,19 +1161,30 @@ def get_top_config_from_ranking( self, ranking_name: str = JumpStartConfigRankingName.DEFAULT, instance_type: Optional[str] = None, - ) -> JumpStartMetadataConfig: - """Gets the best the config based on config ranking.""" + ) -> Optional[JumpStartMetadataConfig]: + """Gets the best the config based on config ranking. + + Args: + ranking_name (str): + The ranking name that config priority is based on. + instance_type (Optional[str]): + The instance type which the config selection is based on. + + Raises: + ValueError: If the config exists but missing config ranking. + NotImplementedError: If the scope is unrecognized. + """ if self.configs and ( not self.config_rankings or not self.config_rankings.get(ranking_name) ): - raise ValueError("Config exists but missing config ranking.") + raise ValueError(f"Config exists but missing config ranking {ranking_name}.") if self.scope == JumpStartScriptScope.INFERENCE: instance_type_attribute = "supported_inference_instance_types" elif self.scope == JumpStartScriptScope.TRAINING: instance_type_attribute = "supported_training_instance_types" else: - raise ValueError(f"Unknown script scope {self.scope}") + raise NotImplementedError(f"Unknown script scope {self.scope}") rankings = self.config_rankings.get(ranking_name) for config_name in rankings.rankings: @@ -1198,12 +1212,13 @@ class JumpStartModelSpecs(JumpStartMetadataBaseFields): __slots__ = JumpStartMetadataBaseFields.__slots__ + slots - def __init__(self, spec: Dict[str, Any]): # pylint: disable=super-init-not-called + def __init__(self, spec: Dict[str, Any]): """Initializes a JumpStartModelSpecs object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of spec. 
""" + super().__init__(spec) self.from_json(spec) if self.inference_configs and self.inference_configs.get_top_config_from_ranking(): super().from_json(self.inference_configs.get_top_config_from_ranking().resolved_config) @@ -1245,8 +1260,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: ), ( { - stat_name: JumpStartBenchmarkStat(stat) - for stat_name, stat in config.get("benchmark_metrics").items() + stat_name: [JumpStartBenchmarkStat(stat) for stat in stats] + for stat_name, stats in config.get("benchmark_metrics").items() } if config and config.get("benchmark_metrics") else None @@ -1297,8 +1312,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: ), ( { - stat_name: JumpStartBenchmarkStat(stat) - for stat_name, stat in config.get("benchmark_metrics").items() + stat_name: [JumpStartBenchmarkStat(stat) for stat in stats] + for stat_name, stats in config.get("benchmark_metrics").items() } if config and config.get("benchmark_metrics") else None @@ -1330,13 +1345,26 @@ def set_config( config_name (str): Name of the config. scope (JumpStartScriptScope, optional): Scope of the config. Defaults to JumpStartScriptScope.INFERENCE. + + Raises: + ValueError: If the scope is not supported, or cannot find config name. """ if scope == JumpStartScriptScope.INFERENCE: - super().from_json(self.inference_configs.configs[config_name].resolved_config) + metadata_configs = self.inference_configs elif scope == JumpStartScriptScope.TRAINING and self.training_supported: - super().from_json(self.training_configs.configs[config_name].resolved_config) + metadata_configs = self.training_configs else: - raise ValueError(f"Unknown Jumpstart Script scope {scope}.") + raise ValueError(f"Unknown Jumpstart script scope {scope}.") + + config_object = metadata_configs.configs.get(config_name) + if not config_object: + error_msg = f"Cannot find Jumpstart config name {config_name}. 
" + config_names = list(metadata_configs.configs.keys()) + if config_names: + error_msg += f"List of config names that is supported by the model: {config_names}" + raise ValueError(error_msg) + + super().from_json(config_object.resolved_config) def supports_prepacked_inference(self) -> bool: """Returns True if the model has a prepacked inference artifact.""" diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 2b6856b1f3..f165a513a9 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -6270,6 +6270,10 @@ "framework_version": "1.5.0", "py_version": "py3", }, + "default_inference_instance_type": "ml.p2.xlarge", + "supported_inference_instance_type": ["ml.p2.xlarge", "ml.p3.xlarge"], + "default_training_instance_type": "ml.p2.xlarge", + "supported_training_instance_type": ["ml.p2.xlarge", "ml.p3.xlarge"], "hosting_artifact_key": "pytorch-infer/infer-pytorch-eqa-bert-base-cased.tar.gz", "hosting_script_key": "source-directory-tarballs/pytorch/inference/eqa/v1.0.0/sourcedir.tar.gz", "inference_vulnerable": False, @@ -7658,25 +7662,25 @@ "inference_configs": { "neuron-inference": { "benchmark_metrics": { - "ml.inf2.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"} + "ml.inf2.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] }, - "component_names": ["neuron-base"], + "component_names": ["neuron-inference"], }, "neuron-inference-budget": { "benchmark_metrics": { - "ml.inf2.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"} + "ml.inf2.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] }, "component_names": ["neuron-base"], }, "gpu-inference-budget": { "benchmark_metrics": { - "ml.p3.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"} + "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] }, "component_names": ["gpu-inference-budget"], }, "gpu-inference": { "benchmark_metrics": { - "ml.p3.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"} + "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] }, "component_names": ["gpu-inference"], }, @@ -7686,7 +7690,13 @@ "supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"] }, "neuron-inference": { + "default_inference_instance_type": "ml.inf2.xlarge", "supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"], + "hosting_ecr_specs": { + "framework": "huggingface-llm-neuronx", + "framework_version": "0.0.17", + "py_version": "py310", + }, "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-inference/model/", "hosting_instance_type_variants": { "regional_aliases": { @@ -7738,27 +7748,27 @@ "training_configs": { "neuron-training": { "benchmark_metrics": { - "ml.tr1n1.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"}, - "ml.tr1n1.4xlarge": {"name": "Latency", "value": "50", "unit": "Tokens/S"}, + "ml.tr1n1.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}], + "ml.tr1n1.4xlarge": [{"name": "Latency", "value": "50", "unit": "Tokens/S"}], }, "component_names": ["neuron-training"], }, "neuron-training-budget": { "benchmark_metrics": { - "ml.tr1n1.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"}, - "ml.tr1n1.4xlarge": {"name": "Latency", "value": "50", "unit": "Tokens/S"}, + "ml.tr1n1.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}], + "ml.tr1n1.4xlarge": [{"name": "Latency", "value": "50", 
"unit": "Tokens/S"}], }, "component_names": ["neuron-training-budget"], }, "gpu-training": { "benchmark_metrics": { - "ml.p3.2xlarge": {"name": "Latency", "value": "200", "unit": "Tokens/S"}, + "ml.p3.2xlarge": [{"name": "Latency", "value": "200", "unit": "Tokens/S"}], }, "component_names": ["gpu-training"], }, "gpu-training-budget": { "benchmark_metrics": { - "ml.p3.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"} + "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] }, "component_names": ["gpu-training-budget"], }, diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 3048bbc320..5ca01c3c52 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import import copy +import pytest from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.types import ( JumpStartBenchmarkStat, @@ -934,9 +935,9 @@ def test_inference_configs_parsing(): assert specs1.incremental_training_supported assert specs1.hosting_ecr_specs == JumpStartECRSpecs( { - "framework": "pytorch", - "framework_version": "1.5.0", - "py_version": "py3", + "framework": "huggingface-llm-neuronx", + "framework_version": "0.0.17", + "py_version": "py310", } ) assert specs1.training_ecr_specs == JumpStartECRSpecs( @@ -946,7 +947,10 @@ def test_inference_configs_parsing(): "py_version": "py3", } ) - assert specs1.hosting_artifact_key == "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz" + assert ( + specs1.hosting_artifact_key + == "artifacts/meta-textgeneration-llama-2-7b/neuron-inference/model/" + ) assert specs1.training_artifact_key == "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz" assert ( specs1.hosting_script_key @@ -1019,16 +1023,58 @@ def test_inference_configs_parsing(): config = specs1.inference_configs.get_top_config_from_ranking() assert config.benchmark_metrics == { - "ml.inf2.2xlarge": JumpStartBenchmarkStat( - {"name": "Latency", "value": "100", "unit": "Tokens/S"} - ) + "ml.inf2.2xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ] } assert len(config.config_components) == 1 - assert config.config_components["neuron-base"] == JumpStartConfigComponent( - "neuron-base", - {"supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"]}, + assert config.config_components["neuron-inference"] == JumpStartConfigComponent( + "neuron-inference", + { + "default_inference_instance_type": "ml.inf2.xlarge", + "supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"], + "hosting_ecr_specs": { + "framework": "huggingface-llm-neuronx", + "framework_version": "0.0.17", + "py_version": "py310", + }, + "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-inference/model/", + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-hosting:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + } + }, + "variants": {"inf2": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, + }, + }, ) - assert list(config.config_components.keys()) == ["neuron-base"] + assert list(config.config_components.keys()) == ["neuron-inference"] + + +def test_set_inference_configs(): + spec = {**BASE_SPEC, **INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS} 
+ specs1 = JumpStartModelSpecs(spec) + + assert list(specs1.inference_config_components.keys()) == [ + "neuron-base", + "neuron-inference", + "neuron-budget", + "gpu-inference", + "gpu-inference-budget", + ] + + with pytest.raises(ValueError) as error: + specs1.set_config("invalid_name") + assert "Cannot find Jumpstart config name invalid_name." + "List of config names that is supported by the model: " + "['neuron-inference', 'neuron-inference-budget', " + "'gpu-inference-budget', 'gpu-inference']" in str(error.value) + + assert specs1.supported_inference_instance_types == ["ml.inf2.xlarge", "ml.inf2.2xlarge"] + specs1.set_config("gpu-inference") + assert specs1.supported_inference_instance_types == ["ml.p2.xlarge", "ml.p3.2xlarge"] def test_training_configs_parsing(): @@ -1133,12 +1179,12 @@ def test_training_configs_parsing(): config = specs1.training_configs.get_top_config_from_ranking() assert config.benchmark_metrics == { - "ml.tr1n1.2xlarge": JumpStartBenchmarkStat( - {"name": "Latency", "value": "100", "unit": "Tokens/S"} - ), - "ml.tr1n1.4xlarge": JumpStartBenchmarkStat( - {"name": "Latency", "value": "50", "unit": "Tokens/S"} - ), + "ml.tr1n1.2xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ], + "ml.tr1n1.4xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "50", "unit": "Tokens/S"}) + ], } assert len(config.config_components) == 1 assert config.config_components["neuron-training"] == JumpStartConfigComponent( @@ -1192,3 +1238,13 @@ def test_set_training_config(): specs1.training_artifact_key == "artifacts/meta-textgeneration-llama-2-7b/gpu-training-budget/model/" ) + + with pytest.raises(ValueError) as error: + specs1.set_config("invalid_name", scope=JumpStartScriptScope.TRAINING) + assert "Cannot find Jumpstart config name invalid_name." 
+ "List of config names that is supported by the model: " + "['neuron-training', 'neuron-training-budget', " + "'gpu-training-budget', 'gpu-training']" in str(error.value) + + with pytest.raises(ValueError) as error: + specs1.set_config("invalid_name", scope="unknown scope") diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index e7a7d522c3..c1ea8abcb8 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -1598,24 +1598,24 @@ def test_get_jumpstart_benchmark_stats_full_list( "mock-region", "mock-model", "mock-model-version", config_names=None ) == { "neuron-inference": { - "ml.inf2.2xlarge": JumpStartBenchmarkStat( - {"name": "Latency", "value": "100", "unit": "Tokens/S"} - ) + "ml.inf2.2xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ] }, "neuron-inference-budget": { - "ml.inf2.2xlarge": JumpStartBenchmarkStat( - {"name": "Latency", "value": "100", "unit": "Tokens/S"} - ) + "ml.inf2.2xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ] }, "gpu-inference-budget": { - "ml.p3.2xlarge": JumpStartBenchmarkStat( - {"name": "Latency", "value": "100", "unit": "Tokens/S"} - ) + "ml.p3.2xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ] }, "gpu-inference": { - "ml.p3.2xlarge": JumpStartBenchmarkStat( - {"name": "Latency", "value": "100", "unit": "Tokens/S"} - ) + "ml.p3.2xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ] }, } @@ -1633,14 +1633,14 @@ def test_get_jumpstart_benchmark_stats_partial_list( config_names=["neuron-inference-budget", "gpu-inference-budget"], ) == { "neuron-inference-budget": { - "ml.inf2.2xlarge": JumpStartBenchmarkStat( - {"name": "Latency", "value": "100", "unit": "Tokens/S"} - ) + "ml.inf2.2xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ] }, "gpu-inference-budget": { - "ml.p3.2xlarge": JumpStartBenchmarkStat( - {"name": "Latency", "value": "100", "unit": "Tokens/S"} - ) + "ml.p3.2xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ] }, } @@ -1658,9 +1658,9 @@ def test_get_jumpstart_benchmark_stats_single_stat( config_names=["neuron-inference-budget"], ) == { "neuron-inference-budget": { - "ml.inf2.2xlarge": JumpStartBenchmarkStat( - {"name": "Latency", "value": "100", "unit": "Tokens/S"} - ) + "ml.inf2.2xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ] } } @@ -1695,16 +1695,16 @@ def test_get_jumpstart_benchmark_stats_training( config_names=["neuron-training", "gpu-training-budget"], ) == { "neuron-training": { - "ml.tr1n1.2xlarge": JumpStartBenchmarkStat( - {"name": "Latency", "value": "100", "unit": "Tokens/S"} - ), - "ml.tr1n1.4xlarge": JumpStartBenchmarkStat( - {"name": "Latency", "value": "50", "unit": "Tokens/S"} - ), + "ml.tr1n1.2xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ], + "ml.tr1n1.4xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "50", "unit": "Tokens/S"}) + ], }, "gpu-training-budget": { - "ml.p3.2xlarge": JumpStartBenchmarkStat( - {"name": "Latency", "value": "100", "unit": "Tokens/S"} - ) + "ml.p3.2xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ] }, } From 30c9bf670cc5acb735376046ee6aeb71d628b0f9 Mon Sep 17 00:00:00 
2001 From: Nikhil Kulkarni Date: Tue, 23 Apr 2024 09:30:46 -0700 Subject: [PATCH 02/32] Add Triton v24.03 URI (#4605) Co-authored-by: Nikhil Kulkarni --- .../sagemaker-tritonserver.json | 36 +++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/image_uri_config/sagemaker-tritonserver.json b/src/sagemaker/image_uri_config/sagemaker-tritonserver.json index 82397d913e..b2257ce803 100644 --- a/src/sagemaker/image_uri_config/sagemaker-tritonserver.json +++ b/src/sagemaker/image_uri_config/sagemaker-tritonserver.json @@ -7,7 +7,7 @@ "inference" ], "versions": { - "23.12": { + "24.03": { "registries": { "af-south-1": "626614931356", "il-central-1": "780543022126", @@ -37,7 +37,7 @@ "ca-west-1": "204538143572" }, "repository": "sagemaker-tritonserver", - "tag_prefix": "23.12-py3" + "tag_prefix": "24.03-py3" }, "24.01": { "registries": { @@ -70,6 +70,38 @@ }, "repository": "sagemaker-tritonserver", "tag_prefix": "24.01-py3" + }, + "23.12": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "sagemaker-tritonserver", + "tag_prefix": "23.12-py3" } } } \ No newline at end of file From fe32d799fa991ab0cca8caedcb07a379e5b6acef Mon Sep 17 00:00:00 2001 From: jessicazhu3 <106775307+jessicazhu3@users.noreply.github.com> Date: Wed, 24 Apr 2024 11:37:29 -0700 Subject: [PATCH 03/32] feature: support session tag chaining for training job (#4596) * feature: support session tag chaining for training job * fix: resolve typo * fix: resolve typo and build failure * fix: resolve typo and unit test failure --------- Co-authored-by: Jessica Zhu --- src/sagemaker/estimator.py | 22 +++++++++++- src/sagemaker/jumpstart/estimator.py | 4 +++ src/sagemaker/jumpstart/factory/estimator.py | 2 ++ src/sagemaker/jumpstart/types.py | 3 ++ src/sagemaker/session.py | 24 ++++++++++++++ tests/unit/test_estimator.py | 35 ++++++++++++++++++++ tests/unit/test_session.py | 3 ++ 7 files changed, 92 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 066846564e..58a5fabc2f 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -181,6 +181,7 @@ def __init__( container_arguments: Optional[List[str]] = None, disable_output_compression: bool = False, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, + enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, **kwargs, ): """Initialize an ``EstimatorBase`` instance. @@ -544,7 +545,9 @@ def __init__( enable_infra_check (bool or PipelineVariable): Optional. Specifies whether it is running Sagemaker built-in infra check jobs. enable_remote_debug (bool or PipelineVariable): Optional. 
- Specifies whether RemoteDebug is enabled for the training job + Specifies whether RemoteDebug is enabled for the training job. + enable_session_tag_chaining (bool or PipelineVariable): Optional. + Specifies whether SessionTagChaining is enabled for the training job. """ instance_count = renamed_kwargs( "train_instance_count", "instance_count", instance_count, kwargs @@ -785,6 +788,8 @@ def __init__( self._enable_remote_debug = enable_remote_debug + self._enable_session_tag_chaining = enable_session_tag_chaining + @abstractmethod def training_image_uri(self): """Return the Docker image to use for training. @@ -2318,6 +2323,14 @@ def get_remote_debug_config(self): else {"EnableRemoteDebug": self._enable_remote_debug} ) + def get_session_chaining_config(self): + """dict: Return the configuration of SessionChaining""" + return ( + None + if self._enable_session_tag_chaining is None + else {"EnableSessionTagChaining": self._enable_session_tag_chaining} + ) + def enable_remote_debug(self): """Enable remote debug for a training job.""" self._update_remote_debug(True) @@ -2574,6 +2587,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config): if estimator.get_remote_debug_config() is not None: train_args["remote_debug_config"] = estimator.get_remote_debug_config() + if estimator.get_session_chaining_config() is not None: + train_args["session_chaining_config"] = estimator.get_session_chaining_config() + return train_args @classmethod @@ -2766,6 +2782,7 @@ def __init__( disable_output_compression: bool = False, enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, + enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, **kwargs, ): """Initialize an ``Estimator`` instance. @@ -3129,6 +3146,8 @@ def __init__( Specifies whether it is running Sagemaker built-in infra check jobs. enable_remote_debug (bool or PipelineVariable): Optional. Specifies whether RemoteDebug is enabled for the training job + enable_session_tag_chaining (bool or PipelineVariable): Optional. + Specifies whether SessionTagChaining is enabled for the training job """ self.image_uri = image_uri self._hyperparameters = hyperparameters.copy() if hyperparameters else {} @@ -3181,6 +3200,7 @@ def __init__( container_arguments=container_arguments, disable_output_compression=disable_output_compression, enable_remote_debug=enable_remote_debug, + enable_session_tag_chaining=enable_session_tag_chaining, **kwargs, ) diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 88927ae931..bade834cc6 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -109,6 +109,7 @@ def __init__( container_arguments: Optional[List[str]] = None, disable_output_compression: Optional[bool] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, + enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, ): """Initializes a ``JumpStartEstimator``. @@ -500,6 +501,8 @@ def __init__( to Amazon S3 without compression after training finishes. enable_remote_debug (bool or PipelineVariable): Optional. Specifies whether RemoteDebug is enabled for the training job + enable_session_tag_chaining (bool or PipelineVariable): Optional. + Specifies whether SessionTagChaining is enabled for the training job Raises: ValueError: If the model ID is not recognized by JumpStart. 
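For context on how the new flag above is consumed: a hedged sketch of the accessor pattern this patch adds to `EstimatorBase`, where an unset flag yields `None` so no request field is emitted (the class name `EstimatorSketch` is a stand-in for illustration, not SDK code):

```python
from typing import Optional

class EstimatorSketch:
    """Stand-in for the EstimatorBase behavior added in this patch."""

    def __init__(self, enable_session_tag_chaining: Optional[bool] = None):
        self._enable_session_tag_chaining = enable_session_tag_chaining

    def get_session_chaining_config(self) -> Optional[dict]:
        # None means "omit SessionChainingConfig from the CreateTrainingJob request"
        if self._enable_session_tag_chaining is None:
            return None
        return {"EnableSessionTagChaining": self._enable_session_tag_chaining}

assert EstimatorSketch().get_session_chaining_config() is None
assert EstimatorSketch(True).get_session_chaining_config() == {"EnableSessionTagChaining": True}
```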
@@ -578,6 +581,7 @@ def _validate_model_id_and_get_type_hook(): disable_output_compression=disable_output_compression, enable_infra_check=enable_infra_check, enable_remote_debug=enable_remote_debug, + enable_session_tag_chaining=enable_session_tag_chaining, ) self.model_id = estimator_init_kwargs.model_id diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 875ec9d003..387a4a843c 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -130,6 +130,7 @@ def get_init_kwargs( disable_output_compression: Optional[bool] = None, enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, + enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, ) -> JumpStartEstimatorInitKwargs: """Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object.""" @@ -188,6 +189,7 @@ def get_init_kwargs( disable_output_compression=disable_output_compression, enable_infra_check=enable_infra_check, enable_remote_debug=enable_remote_debug, + enable_session_tag_chaining=enable_session_tag_chaining, ) estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 05c6a00961..dae879494e 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1751,6 +1751,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "disable_output_compression", "enable_infra_check", "enable_remote_debug", + "enable_session_tag_chaining", ] SERIALIZATION_EXCLUSION_SET = { @@ -1818,6 +1819,7 @@ def __init__( disable_output_compression: Optional[bool] = None, enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, + enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, ) -> None: """Instantiates JumpStartEstimatorInitKwargs object.""" @@ -1877,6 +1879,7 @@ def __init__( self.disable_output_compression = disable_output_compression self.enable_infra_check = enable_infra_check self.enable_remote_debug = enable_remote_debug + self.enable_session_tag_chaining = enable_session_tag_chaining class JumpStartEstimatorFitKwargs(JumpStartKwargs): diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 9e593706c1..5ea3d5f8a1 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -758,6 +758,7 @@ def train( # noqa: C901 environment: Optional[Dict[str, str]] = None, retry_strategy=None, remote_debug_config=None, + session_chaining_config=None, ): """Create an Amazon SageMaker training job. @@ -877,6 +878,15 @@ def train( # noqa: C901 remote_debug_config = { "EnableRemoteDebug": True, } + session_chaining_config(dict): Configuration for SessionChaining. (default: ``None``) + The dict can contain 'EnableSessionTagChaining'(bool). + For example, + + .. code:: python + + session_chaining_config = { + "EnableSessionTagChaining": True, + } environment (dict[str, str]) : Environment variables to be set for use during training job (default: ``None``) retry_strategy(dict): Defines RetryStrategy for InternalServerFailures. 
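The `Session.train` plumbing documented above follows the same convention as `remote_debug_config`: the optional dict is attached to the request only when supplied. A minimal sketch of that assembly (the helper `attach_optional_configs` is hypothetical; the real logic lives inline in `_get_train_request`):

```python
from typing import Optional

def attach_optional_configs(
    train_request: dict,
    remote_debug_config: Optional[dict] = None,
    session_chaining_config: Optional[dict] = None,
) -> dict:
    # Mirrors _get_train_request: fields are added only when configs are present.
    if remote_debug_config is not None:
        train_request["RemoteDebugConfig"] = remote_debug_config
    if session_chaining_config is not None:
        train_request["SessionChainingConfig"] = session_chaining_config
    return train_request

request = attach_optional_configs({}, session_chaining_config={"EnableSessionTagChaining": True})
assert request == {"SessionChainingConfig": {"EnableSessionTagChaining": True}}
```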
@@ -970,6 +980,7 @@ def train( # noqa: C901 profiler_rule_configs=profiler_rule_configs, profiler_config=inferred_profiler_config, remote_debug_config=remote_debug_config, + session_chaining_config=session_chaining_config, environment=environment, retry_strategy=retry_strategy, ) @@ -1013,6 +1024,7 @@ def _get_train_request( # noqa: C901 profiler_rule_configs=None, profiler_config=None, remote_debug_config=None, + session_chaining_config=None, environment=None, retry_strategy=None, ): @@ -1133,6 +1145,15 @@ def _get_train_request( # noqa: C901 remote_debug_config = { "EnableRemoteDebug": True, } + session_chaining_config(dict): Configuration for SessionChaining. (default: ``None``) + The dict can contain 'EnableSessionTagChaining'(bool). + For example, + + .. code:: python + + session_chaining_config = { + "EnableSessionTagChaining": True, + } environment (dict[str, str]) : Environment variables to be set for use during training job (default: ``None``) retry_strategy(dict): Defines RetryStrategy for InternalServerFailures. @@ -1239,6 +1260,9 @@ def _get_train_request( # noqa: C901 if remote_debug_config is not None: train_request["RemoteDebugConfig"] = remote_debug_config + if session_chaining_config is not None: + train_request["SessionChainingConfig"] = session_chaining_config + if retry_strategy is not None: train_request["RetryStrategy"] = retry_strategy diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 382c48fde6..fd45601801 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -2089,6 +2089,41 @@ def test_framework_disable_remote_debug(sagemaker_session): assert len(args) == 2 +def test_framework_with_session_chaining_config(sagemaker_session): + f = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_groups=[ + InstanceGroup("group1", "ml.c4.xlarge", 1), + InstanceGroup("group2", "ml.m4.xlarge", 2), + ], + enable_session_tag_chaining=True, + ) + f.fit("s3://mydata") + sagemaker_session.train.assert_called_once() + _, args = sagemaker_session.train.call_args + assert args["session_chaining_config"]["EnableSessionTagChaining"] + assert f.get_session_chaining_config()["EnableSessionTagChaining"] + + +def test_framework_without_session_chaining_config(sagemaker_session): + f = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_groups=[ + InstanceGroup("group1", "ml.c4.xlarge", 1), + InstanceGroup("group2", "ml.m4.xlarge", 2), + ], + ) + f.fit("s3://mydata") + sagemaker_session.train.assert_called_once() + _, args = sagemaker_session.train.call_args + assert args.get("SessionTagChaining") is None + assert f.get_remote_debug_config() is None + + @patch("time.strftime", return_value=TIMESTAMP) def test_custom_code_bucket(time, sagemaker_session): code_bucket = "codebucket" diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 19f9d0ae3d..944f22acff 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2197,6 +2197,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session): CONTAINER_ENTRY_POINT = ["bin/bash", "test.sh"] CONTAINER_ARGUMENTS = ["--arg1", "value1", "--arg2", "value2"] remote_debug_config = {"EnableRemoteDebug": True} + session_chaining_config = {"EnableSessionTagChaining": True} sagemaker_session.train( image_uri=IMAGE, @@ -2222,6 +2223,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session): 
container_entry_point=CONTAINER_ENTRY_POINT, container_arguments=CONTAINER_ARGUMENTS, remote_debug_config=remote_debug_config, + session_chaining_config=session_chaining_config, ) _, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0] @@ -2245,6 +2247,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session): ) assert actual_train_args["AlgorithmSpecification"]["ContainerArguments"] == CONTAINER_ARGUMENTS assert actual_train_args["RemoteDebugConfig"]["EnableRemoteDebug"] + assert actual_train_args["SessionChainingConfig"]["EnableSessionTagChaining"] def test_create_transform_job_with_sagemaker_config_injection(sagemaker_session): From 8984d9243fd37cc56639889d1c21ded68d868c0b Mon Sep 17 00:00:00 2001 From: ci Date: Wed, 24 Apr 2024 21:35:04 +0000 Subject: [PATCH 04/32] prepare release v2.217.0 --- CHANGELOG.md | 13 +++++++++++++ VERSION | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 546c3438a0..880e5df8c5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ # Changelog +## v2.217.0 (2024-04-24) + +### Features + + * support session tag chaining for training job + +### Bug Fixes and Other Changes + + * Add Triton v24.03 URI + * mainline alt config parsing + * Fix tox installs + * Add PT 2.2 Graviton Inference DLC + ## v2.216.1 (2024-04-22) ### Bug Fixes and Other Changes diff --git a/VERSION b/VERSION index 9558cc93a5..b236067a9e 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.216.2.dev0 +2.217.0 From ed390dddffa95049e39ab45f75f276015cf12ff6 Mon Sep 17 00:00:00 2001 From: ci Date: Wed, 24 Apr 2024 21:35:06 +0000 Subject: [PATCH 05/32] update development version to v2.217.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index b236067a9e..70303736d8 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.217.0 +2.217.1.dev0 From 2a52478e7fd185f197e944568173c58bd2495d78 Mon Sep 17 00:00:00 2001 From: Justin Date: Thu, 25 Apr 2024 11:30:56 -0500 Subject: [PATCH 06/32] fix: properly close files in lineage queries and tests (#4587) Closes #4458 --- src/sagemaker/lineage/query.py | 4 ++-- tests/data/sip/training.py | 3 ++- .../sagemaker/lineage/test_lineage_visualize.py | 4 ++-- tests/integ/sagemaker/workflow/test_workflow.py | 3 ++- tests/integ/test_sagemaker_config.py | 6 ++++-- tests/unit/sagemaker/local/test_local_image.py | 12 ++++++++---- tests/unit/sagemaker/serializers/test_serializers.py | 6 ++++-- 7 files changed, 24 insertions(+), 14 deletions(-) diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 182f117913..3e2003674b 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -335,8 +335,8 @@ def _get_legend_line(self, component_name): def _add_legend(self, path): """Embed legend to html file generated by pyvis.""" - f = open(path, "r") - content = self.BeautifulSoup(f, "html.parser") + with open(path, "r") as f: + content = self.BeautifulSoup(f, "html.parser") legend = """
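The file-handling fix in this patch is the standard context-manager pattern: a hedged sketch of the corrected behavior (assuming `bs4` is installed, as the lineage visualizer requires):

```python
from bs4 import BeautifulSoup  # assumed installed, as in the lineage query module

def read_html(path: str) -> BeautifulSoup:
    # The context manager guarantees the handle is closed even if parsing raises,
    # which is the leak the change above eliminates.
    with open(path, "r") as f:
        return BeautifulSoup(f, "html.parser")
```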
Date: Mon, 29 Apr 2024 11:34:39 -0700 Subject: [PATCH 07/32] feature: set default allow_pickle param to False (#4557) * breaking: set default allow_pickle param to False * breaking: fix unit tests and linting NumpyDeserializer will not allow deserialization unless allow_pickle flag is set to True explicitly * fix: black-check --------- Co-authored-by: Ashwin Krishna --- src/sagemaker/base_deserializers.py | 17 ++++++++++++++--- .../deserializers/test_deserializers.py | 3 ++- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/base_deserializers.py b/src/sagemaker/base_deserializers.py index 7162e5274d..a152f0144d 100644 --- a/src/sagemaker/base_deserializers.py +++ b/src/sagemaker/base_deserializers.py @@ -196,14 +196,14 @@ class NumpyDeserializer(SimpleBaseDeserializer): single array. """ - def __init__(self, dtype=None, accept="application/x-npy", allow_pickle=True): + def __init__(self, dtype=None, accept="application/x-npy", allow_pickle=False): """Initialize a ``NumpyDeserializer`` instance. Args: dtype (str): The dtype of the data (default: None). accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that is expected from the inference endpoint (default: "application/x-npy"). - allow_pickle (bool): Allow loading pickled object arrays (default: True). + allow_pickle (bool): Allow loading pickled object arrays (default: False). """ super(NumpyDeserializer, self).__init__(accept=accept) self.dtype = dtype @@ -227,10 +227,21 @@ def deserialize(self, stream, content_type): if content_type == "application/json": return np.array(json.load(codecs.getreader("utf-8")(stream)), dtype=self.dtype) if content_type == "application/x-npy": - return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle) + try: + return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle) + except ValueError as ve: + raise ValueError( + "Please set the param allow_pickle=True \ + to deserialize pickle objects in NumpyDeserializer" + ).with_traceback(ve.__traceback__) if content_type == "application/x-npz": try: return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle) + except ValueError as ve: + raise ValueError( + "Please set the param allow_pickle=True \ + to deserialize pickle objectsin NumpyDeserializer" + ).with_traceback(ve.__traceback__) finally: stream.close() finally: diff --git a/tests/unit/sagemaker/deserializers/test_deserializers.py b/tests/unit/sagemaker/deserializers/test_deserializers.py index b8ede11ba5..cb1923a094 100644 --- a/tests/unit/sagemaker/deserializers/test_deserializers.py +++ b/tests/unit/sagemaker/deserializers/test_deserializers.py @@ -142,7 +142,8 @@ def test_numpy_deserializer_from_npy(numpy_deserializer): assert np.array_equal(array, result) -def test_numpy_deserializer_from_npy_object_array(numpy_deserializer): +def test_numpy_deserializer_from_npy_object_array(): + numpy_deserializer = NumpyDeserializer(allow_pickle=True) array = np.array([{"a": "", "b": ""}, {"c": "", "d": ""}]) stream = io.BytesIO() np.save(stream, array) From b17d332a5e4542d57d2039d08b124edc6042f9fb Mon Sep 17 00:00:00 2001 From: Haotian An <33510317+Captainia@users.noreply.github.com> Date: Tue, 30 Apr 2024 17:58:41 -0400 Subject: [PATCH 08/32] Fix:invalid component error with new metadata (#4634) * fix: invalid component name * tests * format * fix vulnerable model integ tests llama 2 * updated * fix: training dataset location --- src/sagemaker/jumpstart/estimator.py | 7 ++++++- src/sagemaker/jumpstart/types.py | 5 
++--- tests/integ/sagemaker/jumpstart/constants.py | 1 + .../jumpstart/estimator/test_jumpstart_estimator.py | 3 ++- .../unit/sagemaker/jumpstart/estimator/test_estimator.py | 4 ++++ tests/unit/sagemaker/jumpstart/test_types.py | 8 ++++++++ 6 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index bade834cc6..f53d109dc8 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -734,7 +734,12 @@ def attach( model_version = model_version or "*" - additional_kwargs = {"model_id": model_id, "model_version": model_version} + additional_kwargs = { + "model_id": model_id, + "model_version": model_version, + "tolerate_vulnerable_model": True, # model is already trained + "tolerate_deprecated_model": True, # model is already trained + } model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index dae879494e..05c38da266 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1064,9 +1064,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: Dictionary representation of the config component. """ for field in json_obj.keys(): - if field not in self.__slots__: - raise ValueError(f"Invalid component field: {field}") - setattr(self, field, json_obj[field]) + if field in self.__slots__: + setattr(self, field, json_obj[field]) class JumpStartMetadataConfig(JumpStartDataHolderType): diff --git a/tests/integ/sagemaker/jumpstart/constants.py b/tests/integ/sagemaker/jumpstart/constants.py index f5ffbf7a3a..b839866b1f 100644 --- a/tests/integ/sagemaker/jumpstart/constants.py +++ b/tests/integ/sagemaker/jumpstart/constants.py @@ -48,6 +48,7 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str: ("meta-textgeneration-llama-2-7b", "*"): ("training-datasets/sec_amazon/"), ("meta-textgeneration-llama-2-7b", "2.*"): ("training-datasets/sec_amazon/"), ("meta-textgeneration-llama-2-7b", "3.*"): ("training-datasets/sec_amazon/"), + ("meta-textgeneration-llama-2-7b", "4.*"): ("training-datasets/sec_amazon/"), ("meta-textgenerationneuron-llama-2-7b", "*"): ("training-datasets/sec_amazon/"), } diff --git a/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py b/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py index a839a293c5..0da64ecf05 100644 --- a/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py +++ b/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py @@ -140,7 +140,7 @@ def test_gated_model_training_v1(setup): def test_gated_model_training_v2(setup): model_id = "meta-textgeneration-llama-2-7b" - model_version = "3.*" # model artifacts retrieved from jumpstart-private-cache-* buckets + model_version = "4.*" # model artifacts retrieved from jumpstart-private-cache-* buckets estimator = JumpStartEstimator( model_id=model_id, @@ -150,6 +150,7 @@ def test_gated_model_training_v2(setup): tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], environment={"accept_eula": "true"}, max_run=259200, # avoid exceeding resource limits + tolerate_vulnerable_model=True, # tolerate old version of model ) # uses ml.g5.12xlarge instance diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index ce5f15b287..36d8b11fab 100644 --- 
a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1010,6 +1010,8 @@ def test_jumpstart_estimator_attach_eula_model( "model_id": "gemma-model", "model_version": "*", "environment": {"accept_eula": "true"}, + "tolerate_vulnerable_model": True, + "tolerate_deprecated_model": True, }, ) @@ -1053,6 +1055,8 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( additional_kwargs={ "model_id": "js-trainable-model-prepacked", "model_version": "1.0.0", + "tolerate_vulnerable_model": True, + "tolerate_deprecated_model": True, }, ) diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 5ca01c3c52..b2758c73ef 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -1052,6 +1052,14 @@ def test_inference_configs_parsing(): ) assert list(config.config_components.keys()) == ["neuron-inference"] + spec = { + **BASE_SPEC, + **INFERENCE_CONFIGS, + **INFERENCE_CONFIG_RANKINGS, + "unrecognized-field": "blah", # New fields in base metadata fields should be ignored + } + specs1 = JumpStartModelSpecs(spec) + def test_set_inference_configs(): spec = {**BASE_SPEC, **INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS} From 15094ee208ec2b84f9ca7a53bd1afb291406b8e3 Mon Sep 17 00:00:00 2001 From: ci Date: Wed, 1 May 2024 21:14:30 +0000 Subject: [PATCH 09/32] prepare release v2.218.0 --- CHANGELOG.md | 10 ++++++++++ VERSION | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 880e5df8c5..99416fe44a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,15 @@ # Changelog +## v2.218.0 (2024-05-01) + +### Features + + * set default allow_pickle param to False + +### Bug Fixes and Other Changes + + * properly close files in lineage queries and tests + ## v2.217.0 (2024-04-24) ### Features diff --git a/VERSION b/VERSION index 70303736d8..45aef98018 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.217.1.dev0 +2.218.0 From 7c49f5d43d31e1543cb01cb59e72735c0cb901de Mon Sep 17 00:00:00 2001 From: ci Date: Wed, 1 May 2024 21:14:32 +0000 Subject: [PATCH 10/32] update development version to v2.218.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 45aef98018..c611a0a1ab 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.218.0 +2.218.1.dev0 From 45e31921994cb11612e7c44e262b48d8ff5f4d9c Mon Sep 17 00:00:00 2001 From: Haotian An <33510317+Captainia@users.noreply.github.com> Date: Thu, 2 May 2024 13:03:32 -0400 Subject: [PATCH 11/32] chore: update skipped flaky tests (#4644) * Update skipped flaky tests * flake8 * format * format --- src/sagemaker/jumpstart/notebook_utils.py | 9 ++-- src/sagemaker/jumpstart/payload_utils.py | 6 ++- .../jumpstart/test_notebook_utils.py | 52 +++++++++---------- 3 files changed, 35 insertions(+), 32 deletions(-) diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 732493ce3b..83613cd71b 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -329,9 +329,12 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin return sorted(list(model_id_version_dict.keys())) if not list_old_models: - model_id_version_dict = { - model_id: set([max(versions)]) for model_id, versions in model_id_version_dict.items() - } + for model_id, versions in model_id_version_dict.items(): + try: + 
model_id_version_dict.update({model_id: set([max(versions)])}) + except TypeError: + versions = [str(v) for v in versions] + model_id_version_dict.update({model_id: set([max(versions)])}) model_id_version_set: Set[Tuple[str, str]] = set() for model_id in model_id_version_dict: diff --git a/src/sagemaker/jumpstart/payload_utils.py b/src/sagemaker/jumpstart/payload_utils.py index 595f801598..e4d31e9c83 100644 --- a/src/sagemaker/jumpstart/payload_utils.py +++ b/src/sagemaker/jumpstart/payload_utils.py @@ -23,7 +23,7 @@ from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -from sagemaker.jumpstart.enums import MIMEType +from sagemaker.jumpstart.enums import JumpStartModelType, MIMEType from sagemaker.jumpstart.types import JumpStartSerializablePayload from sagemaker.jumpstart.utils import ( get_jumpstart_content_bucket, @@ -61,6 +61,7 @@ def _construct_payload( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Optional[JumpStartSerializablePayload]: """Returns example payload from prompt. @@ -83,6 +84,8 @@ def _construct_payload( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + model_type (JumpStartModelType): The type of the model, can be open weights model or + proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: Optional[JumpStartSerializablePayload]: serializable payload with prompt, or None if this feature is unavailable for the specified model. @@ -94,6 +97,7 @@ def _construct_payload( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) if payloads is None or len(payloads) == 0: return None diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index c00d271ef1..6544c59019 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -3,7 +3,6 @@ from unittest import TestCase from unittest.mock import Mock, patch -import datetime import pytest from sagemaker.jumpstart.constants import ( @@ -17,7 +16,6 @@ get_prototype_manifest, get_prototype_model_spec, ) -from tests.unit.sagemaker.jumpstart.constants import BASE_PROPRIETARY_MANIFEST from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.notebook_utils import ( _generate_jumpstart_model_versions, @@ -227,10 +225,6 @@ def test_list_jumpstart_models_simple_case( patched_get_manifest.assert_called() patched_get_model_specs.assert_not_called() - @pytest.mark.skipif( - datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1), - reason="Contact JumpStart team to fix flaky test.", - ) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") def test_list_jumpstart_models_script_filter( @@ -246,23 +240,25 @@ def test_list_jumpstart_models_script_filter( manifest_length = len(get_prototype_manifest()) vals = [True, False] for val in vals: - kwargs = {"filter": f"training_supported == {val}"} + kwargs = {"filter": And(f"training_supported == {val}", "model_type is open_weights")} 
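The `try/except TypeError` added above guards `max()` against version sets whose elements are not mutually comparable; when that happens, versions are compared as strings instead. A standalone sketch of the fallback (the helper `collapse_to_latest` is hypothetical):

```python
def collapse_to_latest(versions: set) -> set:
    """Keep only the newest version, falling back to string comparison as in the patch."""
    try:
        return {max(versions)}
    except TypeError:
        return {max(str(v) for v in versions)}

assert collapse_to_latest({"1.0.0", "2.0.0"}) == {"2.0.0"}
# Mixed element types are not orderable, so the string fallback kicks in:
assert collapse_to_latest({1, "2.0.0"}) == {"2.0.0"}
```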
list_jumpstart_models(**kwargs) assert patched_read_s3_file.call_count == manifest_length - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() - kwargs = {"filter": f"training_supported != {val}"} + kwargs = {"filter": And(f"training_supported != {val}", "model_type is open_weights")} list_jumpstart_models(**kwargs) assert patched_read_s3_file.call_count == manifest_length assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() - - kwargs = {"filter": f"training_supported in {vals}", "list_versions": True} + kwargs = { + "filter": And(f"training_supported != {val}", "model_type is open_weights"), + "list_versions": True, + } assert list_jumpstart_models(**kwargs) == [ ("catboost-classification-model", "1.0.0"), ("huggingface-spc-bert-base-cased", "1.0.0"), @@ -279,7 +275,7 @@ def test_list_jumpstart_models_script_filter( patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() - kwargs = {"filter": f"training_supported not in {vals}"} + kwargs = {"filter": And(f"training_supported not in {vals}", "model_type is open_weights")} models = list_jumpstart_models(**kwargs) assert [] == models assert patched_read_s3_file.call_count == manifest_length @@ -518,10 +514,6 @@ def get_manifest_more_versions(region: str = JUMPSTART_DEFAULT_REGION_NAME): list_old_models=False, list_versions=True ) == list_jumpstart_models(list_versions=True) - @pytest.mark.skipif( - datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1), - reason="Contact JumpStart team to fix flaky test.", - ) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") def test_list_jumpstart_models_vulnerable_models( @@ -547,12 +539,15 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs): patched_read_s3_file.side_effect = vulnerable_inference_model_spec num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT) - num_prop_specs = len(BASE_PROPRIETARY_MANIFEST) assert [] == list_jumpstart_models( - And("inference_vulnerable is false", "training_vulnerable is false") + And( + "inference_vulnerable is false", + "training_vulnerable is false", + "model_type is open_weights", + ) ) - assert patched_read_s3_file.call_count == num_specs + num_prop_specs + assert patched_read_s3_file.call_count == num_specs assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() @@ -561,10 +556,14 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs): patched_read_s3_file.side_effect = vulnerable_training_model_spec assert [] == list_jumpstart_models( - And("inference_vulnerable is false", "training_vulnerable is false") + And( + "inference_vulnerable is false", + "training_vulnerable is false", + "model_type is open_weights", + ) ) - assert patched_read_s3_file.call_count == num_specs + num_prop_specs + assert patched_read_s3_file.call_count == num_specs assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() @@ -574,10 +573,6 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs): assert patched_read_s3_file.call_count == 0 - @pytest.mark.skipif( - datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1), - reason="Contact JumpStart team to fix flaky test.", - ) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") 
@patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") def test_list_jumpstart_models_deprecated_models( @@ -598,10 +593,11 @@ def deprecated_model_spec(bucket, key, *args, **kwargs) -> str: patched_read_s3_file.side_effect = deprecated_model_spec num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT) - num_prop_specs = len(BASE_PROPRIETARY_MANIFEST) - assert [] == list_jumpstart_models("deprecated equals false") + assert [] == list_jumpstart_models( + And("deprecated equals false", "model_type is open_weights") + ) - assert patched_read_s3_file.call_count == num_specs + num_prop_specs + assert patched_read_s3_file.call_count == num_specs assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() From c751dbd9050b9efb58eca4bd33d81e54f02e4a81 Mon Sep 17 00:00:00 2001 From: Haixin Wang <98612668+haixiw@users.noreply.github.com> Date: Thu, 2 May 2024 14:04:18 -0700 Subject: [PATCH 12/32] chore: release tgi 2.0.1 (#4642) * chore: release tgi 2.0.1 * minor fix --------- Co-authored-by: Zhaoqi <52220743+zhaoqizqwang@users.noreply.github.com> --- .../image_uri_config/huggingface-llm.json | 49 ++++++++++++++++++- .../image_uris/test_huggingface_llm.py | 1 + 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/image_uri_config/huggingface-llm.json b/src/sagemaker/image_uri_config/huggingface-llm.json index 10073338e7..d357367e6e 100644 --- a/src/sagemaker/image_uri_config/huggingface-llm.json +++ b/src/sagemaker/image_uri_config/huggingface-llm.json @@ -12,7 +12,7 @@ "1.2": "1.2.0", "1.3": "1.3.3", "1.4": "1.4.5", - "2.0": "2.0.0" + "2.0": "2.0.1" }, "versions": { "0.6.0": { @@ -578,6 +578,53 @@ "container_version": { "gpu": "cu121-ubuntu22.04" } + }, + "2.0.1": { + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "tag_prefix": "2.1.1-tgi2.0.1", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "gpu": "cu121-ubuntu22.04" + } } } } diff --git a/tests/unit/sagemaker/image_uris/test_huggingface_llm.py b/tests/unit/sagemaker/image_uris/test_huggingface_llm.py index 582e5cf82d..2ef981a109 100644 --- a/tests/unit/sagemaker/image_uris/test_huggingface_llm.py +++ b/tests/unit/sagemaker/image_uris/test_huggingface_llm.py @@ -32,6 +32,7 @@ "1.4.2": "2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04", "1.4.5": "2.1.1-tgi1.4.5-gpu-py310-cu121-ubuntu22.04", "2.0.0": 
"2.1.1-tgi2.0.0-gpu-py310-cu121-ubuntu22.04", + "2.0.1": "2.1.1-tgi2.0.1-gpu-py310-cu121-ubuntu22.04", }, "inf2": { "0.0.16": "1.13.1-optimum0.0.16-neuronx-py310-ubuntu22.04", From 0f7e6780fea2ae9ba304d2223e3b932f6c7d7ef8 Mon Sep 17 00:00:00 2001 From: Kalyani Nikure <110067132+knikure@users.noreply.github.com> Date: Fri, 3 May 2024 11:22:32 -0700 Subject: [PATCH 13/32] fix: Fix UserAgent logging in Python SDK (#4647) --- src/sagemaker/session.py | 29 +++++++++++---- src/sagemaker/user_agent.py | 45 ++++------------------ tests/unit/test_session.py | 70 +++++++++++++---------------------- tests/unit/test_user_agent.py | 64 +++++++++----------------------- 4 files changed, 71 insertions(+), 137 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 5ea3d5f8a1..bf2a736871 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -121,7 +121,7 @@ from sagemaker.deprecations import deprecated_class from sagemaker.enums import EndpointType from sagemaker.inputs import ShuffleConfig, TrainingInput, BatchDataCaptureConfig -from sagemaker.user_agent import prepend_user_agent +from sagemaker.user_agent import get_user_agent_extra_suffix from sagemaker.utils import ( name_from_image, secondary_training_status_changed, @@ -285,6 +285,7 @@ def _initialize( Creates or uses a boto_session, sagemaker_client and sagemaker_runtime_client. Sets the region_name. """ + self.boto_session = boto_session or boto3.DEFAULT_SESSION or boto3.Session() self._region_name = self.boto_session.region_name @@ -293,19 +294,30 @@ def _initialize( "Must setup local AWS configuration with a region supported by SageMaker." ) - self.sagemaker_client = sagemaker_client or self.boto_session.client("sagemaker") - prepend_user_agent(self.sagemaker_client) + # Make use of user_agent_extra field of the botocore_config object + # to append SageMaker Python SDK specific user_agent suffix + # to the current User-Agent header value from boto3 + # This config will also make sure that user_agent never fails to log the User-Agent string + # even if boto User-Agent header format is updated in the future + # Ref: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + botocore_config = botocore.config.Config(user_agent_extra=get_user_agent_extra_suffix()) + + # Create sagemaker_client with the botocore_config object + # This config is customized to append SageMaker Python SDK specific user_agent suffix + self.sagemaker_client = sagemaker_client or self.boto_session.client( + "sagemaker", config=botocore_config + ) if sagemaker_runtime_client is not None: self.sagemaker_runtime_client = sagemaker_runtime_client else: - config = botocore.config.Config(read_timeout=80) + config = botocore.config.Config( + read_timeout=80, user_agent_extra=get_user_agent_extra_suffix() + ) self.sagemaker_runtime_client = self.boto_session.client( "runtime.sagemaker", config=config ) - prepend_user_agent(self.sagemaker_runtime_client) - if sagemaker_featurestore_runtime_client: self.sagemaker_featurestore_runtime_client = sagemaker_featurestore_runtime_client else: @@ -316,8 +328,9 @@ def _initialize( if sagemaker_metrics_client: self.sagemaker_metrics_client = sagemaker_metrics_client else: - self.sagemaker_metrics_client = self.boto_session.client("sagemaker-metrics") - prepend_user_agent(self.sagemaker_metrics_client) + self.sagemaker_metrics_client = self.boto_session.client( + "sagemaker-metrics", config=botocore_config + ) self.s3_client = self.boto_session.client("s3", 
region_name=self.boto_region_name)
         self.s3_resource = self.boto_session.resource("s3", region_name=self.boto_region_name)
 
diff --git a/src/sagemaker/user_agent.py b/src/sagemaker/user_agent.py
index 8af89696c2..c1b2bcac07 100644
--- a/src/sagemaker/user_agent.py
+++ b/src/sagemaker/user_agent.py
@@ -13,8 +13,6 @@
 """Placeholder docstring"""
 from __future__ import absolute_import
 
-import platform
-import sys
 import json
 import os
 
@@ -28,12 +26,6 @@
 STUDIO_METADATA_FILE = "/opt/ml/metadata/resource-metadata.json"
 
 SDK_VERSION = importlib_metadata.version("sagemaker")
-OS_NAME = platform.system() or "UnresolvedOS"
-OS_VERSION = platform.release() or "UnresolvedOSVersion"
-OS_NAME_VERSION = "{}/{}".format(OS_NAME, OS_VERSION)
-PYTHON_VERSION = "Python/{}.{}.{}".format(
-    sys.version_info.major, sys.version_info.minor, sys.version_info.micro
-)
 
 
 def process_notebook_metadata_file():
@@ -63,45 +55,24 @@ def process_studio_metadata_file():
     return None
 
 
-def determine_prefix(user_agent=""):
-    """Determines the prefix for the user agent string.
+def get_user_agent_extra_suffix():
+    """Get the user agent extra suffix string specific to SageMaker Python SDK.
 
-    Args:
-        user_agent (str): The user agent string to prepend the prefix to.
+    Adheres to the new boto recommended User-Agent 2.0 header format.
 
     Returns:
-        str: The user agent string with the prefix prepended.
+        str: The user agent extra suffix string to be appended
     """
-    prefix = "{}/{}".format(SDK_PREFIX, SDK_VERSION)
-
-    if PYTHON_VERSION not in user_agent:
-        prefix = "{} {}".format(prefix, PYTHON_VERSION)
-
-    if OS_NAME_VERSION not in user_agent:
-        prefix = "{} {}".format(prefix, OS_NAME_VERSION)
+    suffix = "lib/{}#{}".format(SDK_PREFIX, SDK_VERSION)
 
     # Get the notebook instance type and prepend it to the user agent string if exists
    notebook_instance_type = process_notebook_metadata_file()
     if notebook_instance_type:
-        prefix = "{} {}/{}".format(prefix, NOTEBOOK_PREFIX, notebook_instance_type)
+        suffix = "{} md/{}#{}".format(suffix, NOTEBOOK_PREFIX, notebook_instance_type)
 
     # Get the studio app type and prepend it to the user agent string if exists
     studio_app_type = process_studio_metadata_file()
     if studio_app_type:
-        prefix = "{} {}/{}".format(prefix, STUDIO_PREFIX, studio_app_type)
-
-    return prefix
-
-
-def prepend_user_agent(client):
-    """Prepends the user agent string with the SageMaker Python SDK version.
-
-    Args:
-        client (botocore.client.BaseClient): The client to prepend the user agent string for.
- """ - prefix = determine_prefix(client._client_config.user_agent) + suffix = "{} md/{}#{}".format(suffix, STUDIO_PREFIX, studio_app_type) - if client._client_config.user_agent is None: - client._client_config.user_agent = prefix - else: - client._client_config.user_agent = "{} {}".format(prefix, client._client_config.user_agent) + return suffix diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 944f22acff..f7dede1ce9 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -43,8 +43,6 @@ from sagemaker.utils import update_list_of_dicts_with_values_from_config from sagemaker.user_agent import ( SDK_PREFIX, - STUDIO_PREFIX, - NOTEBOOK_PREFIX, ) from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from tests.unit import ( @@ -87,15 +85,20 @@ limits={}, ) +SDK_DEFAULT_SUFFIX = f"lib/{SDK_PREFIX}#2.218.0" +NOTEBOOK_SUFFIX = f"{SDK_DEFAULT_SUFFIX} md/AWS-SageMaker-Notebook-Instance#instance_type" +STUDIO_SUFFIX = f"{SDK_DEFAULT_SUFFIX} md/AWS-SageMaker-Studio#app_type" -@pytest.fixture() -def boto_session(): - boto_mock = Mock(name="boto_session", region_name=REGION) +@pytest.fixture +def boto_session(request): + boto_user_agent = "Boto3/1.33.9 md/Botocore#1.33.9 ua/2.0 os/linux#linux-ver md/arch#x86_64 lang/python#3.10.6" + user_agent_suffix = getattr(request, "param", "") + boto_mock = Mock(name="boto_session", region_name=REGION) client_mock = Mock() - client_mock._client_config.user_agent = ( - "Boto3/1.9.69 Python/3.6.5 Linux/4.14.77-70.82.amzn1.x86_64 Botocore/1.12.69 Resource" - ) + user_agent = f"{boto_user_agent} {SDK_DEFAULT_SUFFIX} {user_agent_suffix}" + with patch("sagemaker.user_agent.get_user_agent_extra_suffix", return_value=user_agent_suffix): + client_mock._client_config.user_agent = user_agent boto_mock.client.return_value = client_mock return boto_mock @@ -887,65 +890,42 @@ def test_delete_model(boto_session): boto_session.client().delete_model.assert_called_with(ModelName=model_name) +@pytest.mark.parametrize("boto_session", [""], indirect=True) def test_user_agent_injected(boto_session): - assert SDK_PREFIX not in boto_session.client("sagemaker")._client_config.user_agent - sess = Session(boto_session) - + expected_user_agent_suffix = "lib/AWS-SageMaker-Python-SDK#2.218.0" for client in [ sess.sagemaker_client, sess.sagemaker_runtime_client, sess.sagemaker_metrics_client, ]: - assert SDK_PREFIX in client._client_config.user_agent - assert NOTEBOOK_PREFIX not in client._client_config.user_agent - assert STUDIO_PREFIX not in client._client_config.user_agent + assert expected_user_agent_suffix in client._client_config.user_agent -@patch("sagemaker.user_agent.process_notebook_metadata_file", return_value="ml.t3.medium") -def test_user_agent_injected_with_nbi( - mock_process_notebook_metadata_file, - boto_session, -): - assert SDK_PREFIX not in boto_session.client("sagemaker")._client_config.user_agent - - sess = Session( - boto_session=boto_session, +@pytest.mark.parametrize("boto_session", [f"{NOTEBOOK_SUFFIX}"], indirect=True) +def test_user_agent_with_notebook_instance_type(boto_session): + sess = Session(boto_session) + expected_user_agent_suffix = ( + "lib/AWS-SageMaker-Python-SDK#2.218.0 md/AWS-SageMaker-Notebook-Instance#instance_type" ) - for client in [ sess.sagemaker_client, sess.sagemaker_runtime_client, sess.sagemaker_metrics_client, ]: - mock_process_notebook_metadata_file.assert_called() - - assert SDK_PREFIX in client._client_config.user_agent - assert NOTEBOOK_PREFIX in 
client._client_config.user_agent - assert STUDIO_PREFIX not in client._client_config.user_agent + assert expected_user_agent_suffix in client._client_config.user_agent -@patch("sagemaker.user_agent.process_studio_metadata_file", return_value="dymmy-app-type") -def test_user_agent_injected_with_studio_app_type( - mock_process_studio_metadata_file, - boto_session, -): - assert SDK_PREFIX not in boto_session.client("sagemaker")._client_config.user_agent - - sess = Session( - boto_session=boto_session, - ) - +@pytest.mark.parametrize("boto_session", [f"{STUDIO_SUFFIX}"], indirect=True) +def test_user_agent_with_studio_app_type(boto_session): + sess = Session(boto_session) + expected_user_agent = "lib/AWS-SageMaker-Python-SDK#2.218.0 md/AWS-SageMaker-Studio#app_type" for client in [ sess.sagemaker_client, sess.sagemaker_runtime_client, sess.sagemaker_metrics_client, ]: - mock_process_studio_metadata_file.assert_called() - - assert SDK_PREFIX in client._client_config.user_agent - assert NOTEBOOK_PREFIX not in client._client_config.user_agent - assert STUDIO_PREFIX in client._client_config.user_agent + assert expected_user_agent in client._client_config.user_agent def test_training_input_all_defaults(): diff --git a/tests/unit/test_user_agent.py b/tests/unit/test_user_agent.py index c116fef951..fb46988e7b 100644 --- a/tests/unit/test_user_agent.py +++ b/tests/unit/test_user_agent.py @@ -13,20 +13,17 @@ from __future__ import absolute_import import json -from mock import MagicMock, patch, mock_open +from mock import patch, mock_open from sagemaker.user_agent import ( SDK_PREFIX, SDK_VERSION, - PYTHON_VERSION, - OS_NAME_VERSION, NOTEBOOK_PREFIX, STUDIO_PREFIX, process_notebook_metadata_file, process_studio_metadata_file, - determine_prefix, - prepend_user_agent, + get_user_agent_extra_suffix, ) @@ -60,45 +57,18 @@ def test_process_studio_metadata_file_not_exists(tmp_path): assert process_studio_metadata_file() is None -# Test determine_prefix function -def test_determine_prefix_notebook_instance_type(monkeypatch): - monkeypatch.setattr( - "sagemaker.user_agent.process_notebook_metadata_file", lambda: "instance_type" - ) - assert ( - determine_prefix() - == f"{SDK_PREFIX}/{SDK_VERSION} {PYTHON_VERSION} {OS_NAME_VERSION} {NOTEBOOK_PREFIX}/instance_type" - ) - - -def test_determine_prefix_studio_app_type(monkeypatch): - monkeypatch.setattr( - "sagemaker.user_agent.process_studio_metadata_file", lambda: "studio_app_type" - ) - assert ( - determine_prefix() - == f"{SDK_PREFIX}/{SDK_VERSION} {PYTHON_VERSION} {OS_NAME_VERSION} {STUDIO_PREFIX}/studio_app_type" - ) - - -def test_determine_prefix_no_metadata(monkeypatch): - monkeypatch.setattr("sagemaker.user_agent.process_notebook_metadata_file", lambda: None) - monkeypatch.setattr("sagemaker.user_agent.process_studio_metadata_file", lambda: None) - assert determine_prefix() == f"{SDK_PREFIX}/{SDK_VERSION} {PYTHON_VERSION} {OS_NAME_VERSION}" - - -# Test prepend_user_agent function -def test_prepend_user_agent_existing_user_agent(monkeypatch): - client = MagicMock() - client._client_config.user_agent = "existing_user_agent" - monkeypatch.setattr("sagemaker.user_agent.determine_prefix", lambda _: "prefix") - prepend_user_agent(client) - assert client._client_config.user_agent == "prefix existing_user_agent" - - -def test_prepend_user_agent_no_user_agent(monkeypatch): - client = MagicMock() - client._client_config.user_agent = None - monkeypatch.setattr("sagemaker.user_agent.determine_prefix", lambda _: "prefix") - prepend_user_agent(client) - assert 
client._client_config.user_agent == "prefix" +# Test get_user_agent_extra_suffix function +def test_get_user_agent_extra_suffix(): + assert get_user_agent_extra_suffix() == f"lib/{SDK_PREFIX}#{SDK_VERSION}" + + with patch("sagemaker.user_agent.process_notebook_metadata_file", return_value="instance_type"): + assert ( + get_user_agent_extra_suffix() + == f"lib/{SDK_PREFIX}#{SDK_VERSION} md/{NOTEBOOK_PREFIX}#instance_type" + ) + + with patch("sagemaker.user_agent.process_studio_metadata_file", return_value="studio_type"): + assert ( + get_user_agent_extra_suffix() + == f"lib/{SDK_PREFIX}#{SDK_VERSION} md/{STUDIO_PREFIX}#studio_type" + ) From fa1a8bf5dc91e9fa64fb3cd3c699824cee33886a Mon Sep 17 00:00:00 2001 From: ci Date: Fri, 3 May 2024 20:28:25 +0000 Subject: [PATCH 14/32] prepare release v2.218.1 --- CHANGELOG.md | 8 ++++++++ VERSION | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 99416fe44a..38092bf59e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +## v2.218.1 (2024-05-03) + +### Bug Fixes and Other Changes + + * Fix UserAgent logging in Python SDK + * chore: release tgi 2.0.1 + * chore: update skipped flaky tests + ## v2.218.0 (2024-05-01) ### Features diff --git a/VERSION b/VERSION index c611a0a1ab..a80e33fcf7 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.218.1.dev0 +2.218.1 From 0075fb3bea06fd9eac36bccf2e9f99be802b9aa1 Mon Sep 17 00:00:00 2001 From: ci Date: Fri, 3 May 2024 20:28:27 +0000 Subject: [PATCH 15/32] update development version to v2.218.2.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index a80e33fcf7..b298acdcc9 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.218.1 +2.218.2.dev0 From 49e09c3bbc9976b35d86c499884378a6b9cf5285 Mon Sep 17 00:00:00 2001 From: Keerthan Vasist Date: Thu, 2 May 2024 10:38:24 -0700 Subject: [PATCH 16/32] feature: allow choosing js payload by alias in private method --- src/sagemaker/jumpstart/payload_utils.py | 5 ++- .../sagemaker/jumpstart/test_payload_utils.py | 34 ++++++++++++++++--- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/src/sagemaker/jumpstart/payload_utils.py b/src/sagemaker/jumpstart/payload_utils.py index e4d31e9c83..9c6716dc64 100644 --- a/src/sagemaker/jumpstart/payload_utils.py +++ b/src/sagemaker/jumpstart/payload_utils.py @@ -62,6 +62,7 @@ def _construct_payload( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + alias: Optional[str] = None, ) -> Optional[JumpStartSerializablePayload]: """Returns example payload from prompt. 
@@ -102,7 +103,9 @@ def _construct_payload( if payloads is None or len(payloads) == 0: return None - payload_to_use: JumpStartSerializablePayload = list(payloads.values())[0] + payload_to_use: JumpStartSerializablePayload = ( + payloads[alias] if alias else list(payloads.values())[0] + ) prompt_key: Optional[str] = payload_to_use.prompt_key if prompt_key is None: diff --git a/tests/unit/sagemaker/jumpstart/test_payload_utils.py b/tests/unit/sagemaker/jumpstart/test_payload_utils.py index afc955e2f3..3c339c9b95 100644 --- a/tests/unit/sagemaker/jumpstart/test_payload_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_payload_utils.py @@ -32,10 +32,36 @@ def test_construct_payload(self, patched_get_model_specs): region = "us-west-2" constructed_payload_body = _construct_payload( - prompt="kobebryant", - model_id=model_id, - model_version="*", - region=region, + prompt="kobebryant", model_id=model_id, model_version="*", region=region + ).body + + self.assertEqual( + { + "hello": {"prompt": "kobebryant"}, + "seed": 43, + }, + constructed_payload_body, + ) + + # Unsupported model + self.assertIsNone( + _construct_payload( + prompt="blah", + model_id="default_payloads", + model_version="*", + region=region, + ) + ) + + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + def test_construct_payload_with_specific_alias(self, patched_get_model_specs): + patched_get_model_specs.side_effect = get_special_model_spec + + model_id = "prompt-key" + region = "us-west-2" + + constructed_payload_body = _construct_payload( + prompt="kobebryant", model_id=model_id, model_version="*", region=region, alias="Dog" ).body self.assertEqual( From ab4d1c5575153a4058014f13f0740605938af01e Mon Sep 17 00:00:00 2001 From: Haotian An <33510317+Captainia@users.noreply.github.com> Date: Tue, 23 Apr 2024 13:34:36 -0400 Subject: [PATCH 17/32] Merge Master. 
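This merge threads an optional config_name argument through the JumpStart retrieval stack (accept/content types, serializers, environment variables, hyperparameters, image/model/script URIs, instance types, kwargs, and resource requirements), so callers can pin a named model config instead of the top-ranked default. A minimal sketch of the forwarding pattern follows; the config names, instance types, and helpers (CONFIGS, _resolve_config, retrieve_default_instance_type) are illustrative stand-ins, not the SDK's real internals.

from typing import Dict, Optional

# Illustrative stand-ins; the real SDK resolves configs from model specs.
CONFIGS: Dict[str, Dict[str, str]] = {
    "neuron-inference": {"default_inference_instance_type": "ml.inf2.xlarge"},
    "gpu-inference": {"default_inference_instance_type": "ml.g5.2xlarge"},
}
DEFAULT_RANKING = ["neuron-inference", "gpu-inference"]


def _resolve_config(config_name: Optional[str] = None) -> Dict[str, str]:
    """Return the named config, falling back to the top-ranked one."""
    if config_name is None:
        return CONFIGS[DEFAULT_RANKING[0]]
    if config_name not in CONFIGS:
        raise ValueError(f"Unknown config name: {config_name}")
    return CONFIGS[config_name]


def retrieve_default_instance_type(config_name: Optional[str] = None) -> str:
    """Public retriever: forwards config_name unchanged, as this patch does."""
    return _resolve_config(config_name=config_name)["default_inference_instance_type"]


assert retrieve_default_instance_type() == "ml.inf2.xlarge"
assert retrieve_default_instance_type("gpu-inference") == "ml.g5.2xlarge"

Defaulting config_name to None at every layer keeps the change backward compatible: each retriever falls back to the ranked default exactly as before.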
--- src/sagemaker/accept_types.py | 3 + src/sagemaker/content_types.py | 3 + src/sagemaker/deserializers.py | 3 + src/sagemaker/environment_variables.py | 3 + src/sagemaker/hyperparameters.py | 3 + src/sagemaker/image_uris.py | 3 + src/sagemaker/instance_types.py | 3 + .../artifacts/environment_variables.py | 7 + .../jumpstart/artifacts/hyperparameters.py | 3 + .../jumpstart/artifacts/image_uris.py | 4 + .../artifacts/incremental_training.py | 3 + .../jumpstart/artifacts/instance_types.py | 6 + src/sagemaker/jumpstart/artifacts/kwargs.py | 12 ++ .../jumpstart/artifacts/metric_definitions.py | 3 + .../jumpstart/artifacts/model_packages.py | 6 + .../jumpstart/artifacts/model_uris.py | 7 + src/sagemaker/jumpstart/artifacts/payloads.py | 3 + .../jumpstart/artifacts/predictors.py | 24 +++ .../jumpstart/artifacts/resource_names.py | 6 +- .../artifacts/resource_requirements.py | 3 + .../jumpstart/artifacts/script_uris.py | 5 + src/sagemaker/jumpstart/estimator.py | 49 ++++-- src/sagemaker/jumpstart/factory/estimator.py | 25 ++- src/sagemaker/jumpstart/factory/model.py | 27 +++- src/sagemaker/jumpstart/model.py | 22 +++ src/sagemaker/jumpstart/notebook_utils.py | 2 + src/sagemaker/jumpstart/types.py | 36 ++++- src/sagemaker/jumpstart/utils.py | 33 +++- src/sagemaker/jumpstart/validators.py | 3 + src/sagemaker/metric_definitions.py | 3 + src/sagemaker/model_uris.py | 4 + src/sagemaker/resource_requirements.py | 3 + src/sagemaker/script_uris.py | 3 + src/sagemaker/serializers.py | 6 + .../jumpstart/estimator/test_estimator.py | 54 +++++++ .../sagemaker/jumpstart/model/test_model.py | 150 ++++++++++++++++++ tests/unit/sagemaker/jumpstart/test_types.py | 8 - tests/unit/sagemaker/jumpstart/utils.py | 17 ++ 38 files changed, 521 insertions(+), 37 deletions(-) diff --git a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py index 78aa655e04..7541425868 100644 --- a/src/sagemaker/accept_types.py +++ b/src/sagemaker/accept_types.py @@ -77,6 +77,7 @@ def retrieve_default( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> str: """Retrieves the default accept type for the model matching the given arguments. @@ -98,6 +99,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The default accept type to use for the model. @@ -117,4 +119,5 @@ def retrieve_default( tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) diff --git a/src/sagemaker/content_types.py b/src/sagemaker/content_types.py index 46d0361f67..627feca0d6 100644 --- a/src/sagemaker/content_types.py +++ b/src/sagemaker/content_types.py @@ -77,6 +77,7 @@ def retrieve_default( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> str: """Retrieves the default content type for the model matching the given arguments. @@ -98,6 +99,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. 
(Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The default content type to use for the model. @@ -117,6 +119,7 @@ def retrieve_default( tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 1a4be43897..02e61149ec 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -97,6 +97,7 @@ def retrieve_default( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> BaseDeserializer: """Retrieves the default deserializer for the model matching the given arguments. @@ -118,6 +119,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: BaseDeserializer: The default deserializer to use for the model. @@ -138,4 +140,5 @@ def retrieve_default( tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) diff --git a/src/sagemaker/environment_variables.py b/src/sagemaker/environment_variables.py index b67066fcde..8fa52c3ec8 100644 --- a/src/sagemaker/environment_variables.py +++ b/src/sagemaker/environment_variables.py @@ -36,6 +36,7 @@ def retrieve_default( sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE, + config_name: Optional[str] = None, ) -> Dict[str, str]: """Retrieves the default container environment variables for the model matching the arguments. @@ -65,6 +66,7 @@ def retrieve_default( variables specific for the instance type. script (JumpStartScriptScope): The JumpStart script for which to retrieve environment variables. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: The variables to use for the model. @@ -87,4 +89,5 @@ def retrieve_default( sagemaker_session=sagemaker_session, instance_type=instance_type, script=script, + config_name=config_name, ) diff --git a/src/sagemaker/hyperparameters.py b/src/sagemaker/hyperparameters.py index 5873e37b9f..5c22409c50 100644 --- a/src/sagemaker/hyperparameters.py +++ b/src/sagemaker/hyperparameters.py @@ -36,6 +36,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> Dict[str, str]: """Retrieves the default training hyperparameters for the model matching the given arguments. @@ -66,6 +67,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: The hyperparameters to use for the model. 
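With the new parameter in place, pinning a config from the public helpers looks roughly like the sketch below; the model ID and config name are placeholders, not values taken from this diff.

from sagemaker import hyperparameters

# Placeholder identifiers for illustration only.
hp = hyperparameters.retrieve_default(
    model_id="example-model-id",
    model_version="*",
    config_name="example-config",
)
print(hp)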
@@ -86,6 +88,7 @@ def retrieve_default( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 143ecc9bdb..97471f2c41 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -68,6 +68,7 @@ def retrieve( inference_tool=None, serverless_inference_config=None, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name=None, ) -> str: """Retrieves the ECR URI for the Docker image matching the given arguments. @@ -121,6 +122,7 @@ def retrieve( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The ECR URI for the corresponding SageMaker Docker image. @@ -160,6 +162,7 @@ def retrieve( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]): diff --git a/src/sagemaker/instance_types.py b/src/sagemaker/instance_types.py index 48aaab0ac8..c4af4b2036 100644 --- a/src/sagemaker/instance_types.py +++ b/src/sagemaker/instance_types.py @@ -36,6 +36,7 @@ def retrieve_default( sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> str: """Retrieves the default instance type for the model matching the given arguments. @@ -64,6 +65,7 @@ def retrieve_default( Optionally supply this to get a inference instance type conditioned on the training instance, to ensure compatability of training artifact to inference instance. (Default: None). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The default instance type to use for the model. @@ -88,6 +90,7 @@ def retrieve_default( sagemaker_session=sagemaker_session, training_instance_type=training_instance_type, model_type=model_type, + config_name=config_name, ) diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index c28c27ed4e..fcb3ce3bf2 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -39,6 +39,7 @@ def _retrieve_default_environment_variables( sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE, + config_name: Optional[str] = None, ) -> Dict[str, str]: """Retrieves the inference environment variables for the model matching the given arguments. @@ -68,6 +69,7 @@ def _retrieve_default_environment_variables( environment variables specific for the instance type. script (JumpStartScriptScope): The JumpStart script for which to retrieve environment variables. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: the inference environment variables to use for the model. 
""" @@ -84,6 +86,7 @@ def _retrieve_default_environment_variables( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) default_environment_variables: Dict[str, str] = {} @@ -121,6 +124,7 @@ def _retrieve_default_environment_variables( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, instance_type=instance_type, + config_name=config_name, ) ) @@ -167,6 +171,7 @@ def _retrieve_gated_model_uri_env_var_value( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, + config_name: Optional[str] = None, ) -> Optional[str]: """Retrieves the gated model env var URI matching the given arguments. @@ -190,6 +195,7 @@ def _retrieve_gated_model_uri_env_var_value( chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). instance_type (str): An instance type to optionally supply in order to get environment variables specific for the instance type. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: Optional[str]: the s3 URI to use for the environment variable, or None if the model does not @@ -211,6 +217,7 @@ def _retrieve_gated_model_uri_env_var_value( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) s3_key: Optional[str] = ( diff --git a/src/sagemaker/jumpstart/artifacts/hyperparameters.py b/src/sagemaker/jumpstart/artifacts/hyperparameters.py index d19530ecfb..67db7d260f 100644 --- a/src/sagemaker/jumpstart/artifacts/hyperparameters.py +++ b/src/sagemaker/jumpstart/artifacts/hyperparameters.py @@ -36,6 +36,7 @@ def _retrieve_default_hyperparameters( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, + config_name: Optional[str] = None, ): """Retrieves the training hyperparameters for the model matching the given arguments. @@ -66,6 +67,7 @@ def _retrieve_default_hyperparameters( chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). instance_type (str): An instance type to optionally supply in order to get hyperparameters specific for the instance type. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: the hyperparameters to use for the model. """ @@ -82,6 +84,7 @@ def _retrieve_default_hyperparameters( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) default_hyperparameters: Dict[str, str] = {} diff --git a/src/sagemaker/jumpstart/artifacts/image_uris.py b/src/sagemaker/jumpstart/artifacts/image_uris.py index 9d19d5e069..72633320f5 100644 --- a/src/sagemaker/jumpstart/artifacts/image_uris.py +++ b/src/sagemaker/jumpstart/artifacts/image_uris.py @@ -46,6 +46,7 @@ def _retrieve_image_uri( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ): """Retrieves the container image URI for JumpStart models. @@ -95,6 +96,7 @@ def _retrieve_image_uri( object, used for SageMaker interactions. 
If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the ECR URI for the corresponding SageMaker Docker image. @@ -116,6 +118,7 @@ def _retrieve_image_uri( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) if image_scope == JumpStartScriptScope.INFERENCE: @@ -200,4 +203,5 @@ def _retrieve_image_uri( distribution=distribution, base_framework_version=base_framework_version_override or base_framework_version, training_compiler_config=training_compiler_config, + config_name=config_name, ) diff --git a/src/sagemaker/jumpstart/artifacts/incremental_training.py b/src/sagemaker/jumpstart/artifacts/incremental_training.py index 1b3c6f4b29..8bbe089354 100644 --- a/src/sagemaker/jumpstart/artifacts/incremental_training.py +++ b/src/sagemaker/jumpstart/artifacts/incremental_training.py @@ -33,6 +33,7 @@ def _model_supports_incremental_training( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> bool: """Returns True if the model supports incremental training. @@ -54,6 +55,7 @@ def _model_supports_incremental_training( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: bool: the support status for incremental training. """ @@ -70,6 +72,7 @@ def _model_supports_incremental_training( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) return model_specs.supports_incremental_training() diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index e7c9c5911d..f4bf212c1c 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -40,6 +40,7 @@ def _retrieve_default_instance_type( sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> str: """Retrieves the default instance type for the model. @@ -68,6 +69,7 @@ def _retrieve_default_instance_type( Optionally supply this to get a inference instance type conditioned on the training instance, to ensure compatability of training artifact to inference instance. (Default: None). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the default instance type to use for the model or None. 
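The artifacts-level retriever above dispatches on scope to pick which spec attribute supplies the default instance type. A self-contained sketch of that dispatch, using a simplified stand-in for the specs object rather than the SDK's real JumpStartModelSpecs:

from enum import Enum


class Scope(str, Enum):
    INFERENCE = "inference"
    TRAINING = "training"


class FakeSpecs:
    """Simplified stand-in for JumpStart model specs."""

    default_inference_instance_type = "ml.g5.2xlarge"
    default_training_instance_type = "ml.p3.2xlarge"


def default_instance_type(specs: FakeSpecs, scope: Scope) -> str:
    # Mirrors the scope dispatch in _retrieve_default_instance_type.
    if scope == Scope.INFERENCE:
        return specs.default_inference_instance_type
    if scope == Scope.TRAINING:
        return specs.default_training_instance_type
    raise NotImplementedError(f"Unknown script scope {scope}")


assert default_instance_type(FakeSpecs(), Scope.INFERENCE) == "ml.g5.2xlarge"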
@@ -89,6 +91,7 @@ def _retrieve_default_instance_type( tolerate_deprecated_model=tolerate_deprecated_model, model_type=model_type, sagemaker_session=sagemaker_session, + config_name=config_name, ) if scope == JumpStartScriptScope.INFERENCE: @@ -128,6 +131,7 @@ def _retrieve_instance_types( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, + config_name: Optional[str] = None, ) -> List[str]: """Retrieves the supported instance types for the model. @@ -156,6 +160,7 @@ def _retrieve_instance_types( Optionally supply this to get a inference instance type conditioned on the training instance, to ensure compatability of training artifact to inference instance. (Default: None). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: list: the supported instance types to use for the model or None. @@ -176,6 +181,7 @@ def _retrieve_instance_types( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) if scope == JumpStartScriptScope.INFERENCE: diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index 9cd152b0bb..ceb88d9b26 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -37,6 +37,7 @@ def _retrieve_model_init_kwargs( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> dict: """Retrieves kwargs for `Model`. @@ -58,6 +59,7 @@ def _retrieve_model_init_kwargs( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: the kwargs to use for the use case. """ @@ -75,6 +77,7 @@ def _retrieve_model_init_kwargs( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) kwargs = deepcopy(model_specs.model_kwargs) @@ -94,6 +97,7 @@ def _retrieve_model_deploy_kwargs( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> dict: """Retrieves kwargs for `Model.deploy`. @@ -117,6 +121,7 @@ def _retrieve_model_deploy_kwargs( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: the kwargs to use for the use case. 
@@ -135,6 +140,7 @@ def _retrieve_model_deploy_kwargs( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) if volume_size_supported(instance_type) and model_specs.inference_volume_size is not None: @@ -151,6 +157,7 @@ def _retrieve_estimator_init_kwargs( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> dict: """Retrieves kwargs for `Estimator`. @@ -174,6 +181,7 @@ def _retrieve_estimator_init_kwargs( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: the kwargs to use for the use case. """ @@ -190,6 +198,7 @@ def _retrieve_estimator_init_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) kwargs = deepcopy(model_specs.estimator_kwargs) @@ -210,6 +219,7 @@ def _retrieve_estimator_fit_kwargs( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> dict: """Retrieves kwargs for `Estimator.fit`. @@ -231,6 +241,7 @@ def _retrieve_estimator_fit_kwargs( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: the kwargs to use for the use case. @@ -248,6 +259,7 @@ def _retrieve_estimator_fit_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) return model_specs.fit_kwargs diff --git a/src/sagemaker/jumpstart/artifacts/metric_definitions.py b/src/sagemaker/jumpstart/artifacts/metric_definitions.py index 57f66155c7..f23b66aed4 100644 --- a/src/sagemaker/jumpstart/artifacts/metric_definitions.py +++ b/src/sagemaker/jumpstart/artifacts/metric_definitions.py @@ -35,6 +35,7 @@ def _retrieve_default_training_metric_definitions( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, + config_name: Optional[str] = None, ) -> Optional[List[Dict[str, str]]]: """Retrieves the default training metric definitions for the model. @@ -58,6 +59,7 @@ def _retrieve_default_training_metric_definitions( chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). instance_type (str): An instance type to optionally supply in order to get metric definitions specific for the instance type. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: list: the default training metric definitions to use for the model or None. 
""" @@ -74,6 +76,7 @@ def _retrieve_default_training_metric_definitions( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) default_metric_definitions = ( diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index aa22351771..12166b1a76 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -37,6 +37,7 @@ def _retrieve_model_package_arn( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> Optional[str]: """Retrieves associated model pacakge arn for the model. @@ -60,6 +61,7 @@ def _retrieve_model_package_arn( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the model package arn to use for the model or None. @@ -78,6 +80,7 @@ def _retrieve_model_package_arn( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) if scope == JumpStartScriptScope.INFERENCE: @@ -118,6 +121,7 @@ def _retrieve_model_package_model_artifact_s3_uri( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> Optional[str]: """Retrieves s3 artifact uri associated with model package. @@ -141,6 +145,7 @@ def _retrieve_model_package_model_artifact_s3_uri( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the model package artifact uri to use for the model or None. @@ -162,6 +167,7 @@ def _retrieve_model_package_model_artifact_s3_uri( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) if model_specs.training_model_package_artifact_uris is None: diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index 6bb2e576fc..00c6d8b9aa 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -95,6 +95,7 @@ def _retrieve_model_uri( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ): """Retrieves the model artifact S3 URI for the model matching the given arguments. @@ -120,6 +121,8 @@ def _retrieve_model_uri( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). 
+ Returns: str: the model artifact S3 URI for the corresponding model. @@ -141,6 +144,7 @@ def _retrieve_model_uri( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) model_artifact_key: str @@ -182,6 +186,7 @@ def _model_supports_training_model_uri( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> bool: """Returns True if the model supports training with model uri field. @@ -203,6 +208,7 @@ def _model_supports_training_model_uri( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: bool: the support status for model uri with training. """ @@ -219,6 +225,7 @@ def _model_supports_training_model_uri( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) return model_specs.use_training_model_artifact() diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py index 3359e32732..2f4a8bb0ac 100644 --- a/src/sagemaker/jumpstart/artifacts/payloads.py +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -37,6 +37,7 @@ def _retrieve_example_payloads( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> Optional[Dict[str, JumpStartSerializablePayload]]: """Returns example payloads. @@ -58,6 +59,7 @@ def _retrieve_example_payloads( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: Optional[Dict[str, JumpStartSerializablePayload]]: dictionary mapping payload aliases to the serializable payload object. @@ -76,6 +78,7 @@ def _retrieve_example_payloads( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) default_payloads = model_specs.default_payloads diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py index 4f6dfe1fe3..635f063e05 100644 --- a/src/sagemaker/jumpstart/artifacts/predictors.py +++ b/src/sagemaker/jumpstart/artifacts/predictors.py @@ -78,6 +78,7 @@ def _retrieve_default_deserializer( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> BaseDeserializer: """Retrieves the default deserializer for the model. @@ -98,6 +99,7 @@ def _retrieve_default_deserializer( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. 
(Default: None). Returns: BaseDeserializer: the default deserializer to use for the model. @@ -111,6 +113,7 @@ def _retrieve_default_deserializer( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) return _retrieve_deserializer_from_accept_type(MIMEType.from_suffixed_type(default_accept_type)) @@ -124,6 +127,7 @@ def _retrieve_default_serializer( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> BaseSerializer: """Retrieves the default serializer for the model. @@ -144,6 +148,7 @@ def _retrieve_default_serializer( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: BaseSerializer: the default serializer to use for the model. """ @@ -156,6 +161,7 @@ def _retrieve_default_serializer( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) return _retrieve_serializer_from_content_type(MIMEType.from_suffixed_type(default_content_type)) @@ -169,6 +175,7 @@ def _retrieve_deserializer_options( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> List[BaseDeserializer]: """Retrieves the supported deserializers for the model. @@ -189,6 +196,7 @@ def _retrieve_deserializer_options( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: List[BaseDeserializer]: the supported deserializers to use for the model. """ @@ -201,6 +209,7 @@ def _retrieve_deserializer_options( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) seen_classes: Set[Type] = set() @@ -227,6 +236,7 @@ def _retrieve_serializer_options( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> List[BaseSerializer]: """Retrieves the supported serializers for the model. @@ -247,6 +257,7 @@ def _retrieve_serializer_options( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: List[BaseSerializer]: the supported serializers to use for the model. 
""" @@ -258,6 +269,7 @@ def _retrieve_serializer_options( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) seen_classes: Set[Type] = set() @@ -285,6 +297,7 @@ def _retrieve_default_content_type( tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> str: """Retrieves the default content type for the model. @@ -305,6 +318,7 @@ def _retrieve_default_content_type( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the default content type to use for the model. """ @@ -322,6 +336,7 @@ def _retrieve_default_content_type( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) default_content_type = model_specs.predictor_specs.default_content_type @@ -336,6 +351,7 @@ def _retrieve_default_accept_type( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> str: """Retrieves the default accept type for the model. @@ -356,6 +372,7 @@ def _retrieve_default_accept_type( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the default accept type to use for the model. """ @@ -373,6 +390,7 @@ def _retrieve_default_accept_type( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) default_accept_type = model_specs.predictor_specs.default_accept_type @@ -388,6 +406,7 @@ def _retrieve_supported_accept_types( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> List[str]: """Retrieves the supported accept types for the model. @@ -408,6 +427,7 @@ def _retrieve_supported_accept_types( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: list: the supported accept types to use for the model. 
""" @@ -425,6 +445,7 @@ def _retrieve_supported_accept_types( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) supported_accept_types = model_specs.predictor_specs.supported_accept_types @@ -440,6 +461,7 @@ def _retrieve_supported_content_types( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> List[str]: """Retrieves the supported content types for the model. @@ -460,6 +482,7 @@ def _retrieve_supported_content_types( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: list: the supported content types to use for the model. """ @@ -477,6 +500,7 @@ def _retrieve_supported_content_types( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) supported_content_types = model_specs.predictor_specs.supported_content_types diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index cffd46d043..b4fdac770b 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -35,6 +35,8 @@ def _retrieve_resource_name_base( tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + scope: JumpStartScriptScope = JumpStartScriptScope.INFERENCE, + config_name: Optional[str] = None, ) -> bool: """Returns default resource name. @@ -56,6 +58,7 @@ def _retrieve_resource_name_base( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config. (Default: None). Returns: str: the default resource name. """ @@ -67,12 +70,13 @@ def _retrieve_resource_name_base( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, - scope=JumpStartScriptScope.INFERENCE, + scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, model_type=model_type, sagemaker_session=sagemaker_session, + config_name=config_name, ) return model_specs.resource_name_base diff --git a/src/sagemaker/jumpstart/artifacts/resource_requirements.py b/src/sagemaker/jumpstart/artifacts/resource_requirements.py index 369acac85f..49126da336 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_requirements.py +++ b/src/sagemaker/jumpstart/artifacts/resource_requirements.py @@ -54,6 +54,7 @@ def _retrieve_default_resources( model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, + config_name: Optional[str] = None, ) -> ResourceRequirements: """Retrieves the default resource requirements for the model. @@ -79,6 +80,7 @@ def _retrieve_default_resources( chain. 
(Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). instance_type (str): An instance type to optionally supply in order to get host requirements specific for the instance type. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: ResourceRequirements: The default resource requirements to use for the model, or None.
@@ -102,6 +104,7 @@ def _retrieve_default_resources( tolerate_deprecated_model=tolerate_deprecated_model, model_type=model_type, sagemaker_session=sagemaker_session, + config_name=config_name, ) if scope == JumpStartScriptScope.INFERENCE:
diff --git a/src/sagemaker/jumpstart/artifacts/script_uris.py b/src/sagemaker/jumpstart/artifacts/script_uris.py index f69732d2e0..97313ec626 100644 --- a/src/sagemaker/jumpstart/artifacts/script_uris.py +++ b/src/sagemaker/jumpstart/artifacts/script_uris.py
@@ -37,6 +37,7 @@ def _retrieve_script_uri( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ): """Retrieves the script S3 URI associated with the model matching the given arguments.
@@ -62,6 +63,7 @@ def _retrieve_script_uri( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the model script URI for the corresponding model.
@@ -83,6 +85,7 @@ def _retrieve_script_uri( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) if script_scope == JumpStartScriptScope.INFERENCE:
@@ -108,6 +111,7 @@ def _model_supports_inference_script_uri( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> bool: """Returns True if the model supports inference with a script URI field.
@@ -145,6 +149,7 @@ def _model_supports_inference_script_uri( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) return model_specs.use_inference_script_uri()
diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index f53d109dc8..cf9b720607 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py
@@ -34,7 +34,9 @@ from sagemaker.jumpstart.factory.estimator import get_deploy_kwargs, get_fit_kwargs, get_init_kwargs from sagemaker.jumpstart.factory.model import get_default_predictor from sagemaker.jumpstart.session_utils import get_model_id_version_from_training_job +from sagemaker.jumpstart.types import JumpStartMetadataConfig from sagemaker.jumpstart.utils import ( + get_jumpstart_configs, validate_model_id_and_get_type, resolve_model_sagemaker_config_field, verify_model_region_and_return_specs,
@@ -109,7 +111,7 @@ def __init__( container_arguments: Optional[List[str]] = None, disable_output_compression: Optional[bool] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, - enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, + config_name: Optional[str] = None, ): """Initializes a ``JumpStartEstimator``.
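Taken together, these estimator changes swap the short-lived enable_session_tag_chaining argument for config_name, so a JumpStart config can be selected at construction time. A minimal usage sketch follows; the model ID and config name are borrowed from this patch's unit tests and are illustrative only, and any real call needs a valid JumpStart model ID plus a config name that get_config_names() reports for that model:

    from sagemaker.jumpstart.estimator import JumpStartEstimator

    # Hypothetical values taken from the test fixtures below.
    estimator = JumpStartEstimator(
        model_id="pytorch-eqa-bert-base-cased",
        config_name="neuron-training",  # resolve model specs against this config
    )
    estimator.fit()

The list_training_configs() and set_training_config() methods added later in this file cover discovering and switching configs after construction.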
@@ -501,8 +503,8 @@ def __init__( to Amazon S3 without compression after training finishes. enable_remote_debug (bool or PipelineVariable): Optional. Specifies whether RemoteDebug is enabled for the training job - enable_session_tag_chaining (bool or PipelineVariable): Optional. - Specifies whether SessionTagChaining is enabled for the training job + config_name (Optional[str]): + Name of the JumpStart Model config to apply. (Default: None). Raises: ValueError: If the model ID is not recognized by JumpStart. @@ -581,7 +583,7 @@ def _validate_model_id_and_get_type_hook(): disable_output_compression=disable_output_compression, enable_infra_check=enable_infra_check, enable_remote_debug=enable_remote_debug, - enable_session_tag_chaining=enable_session_tag_chaining, + config_name=config_name, ) self.model_id = estimator_init_kwargs.model_id @@ -595,6 +597,8 @@ def _validate_model_id_and_get_type_hook(): self.role = estimator_init_kwargs.role self.sagemaker_session = estimator_init_kwargs.sagemaker_session self._enable_network_isolation = estimator_init_kwargs.enable_network_isolation + self.config_name = estimator_init_kwargs.config_name + self.init_kwargs = estimator_init_kwargs.to_kwargs_dict(False) super(JumpStartEstimator, self).__init__(**estimator_init_kwargs.to_kwargs_dict()) @@ -669,6 +673,7 @@ def fit( tolerate_vulnerable_model=self.tolerate_vulnerable_model, tolerate_deprecated_model=self.tolerate_deprecated_model, sagemaker_session=self.sagemaker_session, + config_name=self.config_name, ) return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict()) @@ -734,12 +739,7 @@ def attach( model_version = model_version or "*" - additional_kwargs = { - "model_id": model_id, - "model_version": model_version, - "tolerate_vulnerable_model": True, # model is already trained - "tolerate_deprecated_model": True, # model is already trained - } + additional_kwargs = {"model_id": model_id, "model_version": model_version} model_specs = verify_model_region_and_return_specs( model_id=model_id, @@ -1085,6 +1085,7 @@ def deploy( git_config=git_config, use_compiled_model=use_compiled_model, training_instance_type=self.instance_type, + config_name=self.config_name, ) predictor = super(JumpStartEstimator, self).deploy( @@ -1101,11 +1102,39 @@ def deploy( tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, sagemaker_session=self.sagemaker_session, + # config_name=self.config_name, ) # If a predictor class was passed, do not mutate predictor return predictor + def list_training_configs(self) -> List[JumpStartMetadataConfig]: + """Returns a list of configs associated with the estimator. + + Raises: + ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. + """ + configs_dict = get_jumpstart_configs( + model_id=self.model_id, + model_version=self.model_version, + model_type=self.model_type, + region=self.region, + scope=JumpStartScriptScope.TRAINING, + sagemaker_session=self.sagemaker_session, + ) + return list(configs_dict.values()) + + def set_training_config(self, config_name: str) -> None: + """Sets the config to apply to the model. + + Args: + config_name (str): The name of the config. 
+ """ + self.__init__(**self.init_kwargs, config_name=config_name) + def __str__(self) -> str: """Overriding str(*) method to make more human-readable.""" return stringify_object(self) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 387a4a843c..926f313b68 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -130,7 +130,7 @@ def get_init_kwargs( disable_output_compression: Optional[bool] = None, enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, - enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, + config_name: Optional[str] = None, ) -> JumpStartEstimatorInitKwargs: """Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object.""" @@ -189,7 +189,7 @@ def get_init_kwargs( disable_output_compression=disable_output_compression, enable_infra_check=enable_infra_check, enable_remote_debug=enable_remote_debug, - enable_session_tag_chaining=enable_session_tag_chaining, + config_name=config_name, ) estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs) @@ -223,6 +223,7 @@ def get_fit_kwargs( tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, sagemaker_session: Optional[Session] = None, + config_name: Optional[str] = None, ) -> JumpStartEstimatorFitKwargs: """Returns kwargs required call `fit` on `sagemaker.estimator.Estimator` object.""" @@ -238,6 +239,7 @@ def get_fit_kwargs( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) estimator_fit_kwargs = _add_model_version_to_kwargs(estimator_fit_kwargs) @@ -289,6 +291,7 @@ def get_deploy_kwargs( use_compiled_model: Optional[bool] = None, model_name: Optional[str] = None, training_instance_type: Optional[str] = None, + config_name: Optional[str] = None, ) -> JumpStartEstimatorDeployKwargs: """Returns kwargs required to call `deploy` on `sagemaker.estimator.Estimator` object.""" @@ -316,6 +319,7 @@ def get_deploy_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) model_init_kwargs: JumpStartModelInitKwargs = model.get_init_kwargs( @@ -344,6 +348,7 @@ def get_deploy_kwargs( tolerate_deprecated_model=tolerate_deprecated_model, training_instance_type=training_instance_type, disable_instance_type_logging=True, + config_name=config_name, ) estimator_deploy_kwargs: JumpStartEstimatorDeployKwargs = JumpStartEstimatorDeployKwargs( @@ -388,6 +393,7 @@ def get_deploy_kwargs( tolerate_vulnerable_model=model_init_kwargs.tolerate_vulnerable_model, tolerate_deprecated_model=model_init_kwargs.tolerate_deprecated_model, use_compiled_model=use_compiled_model, + config_name=config_name, ) return estimator_deploy_kwargs @@ -443,6 +449,7 @@ def _add_instance_type_and_count_to_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) kwargs.instance_count = kwargs.instance_count or 1 @@ -466,6 +473,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, 
tolerate_deprecated_model=kwargs.tolerate_deprecated_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ).version if kwargs.sagemaker_session.settings.include_jumpstart_tags: @@ -488,6 +496,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) return kwargs @@ -513,6 +522,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE sagemaker_session=kwargs.sagemaker_session, region=kwargs.region, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, ) if ( @@ -525,6 +535,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) ): JUMPSTART_LOGGER.warning( @@ -560,6 +571,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStart tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, region=kwargs.region, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) return kwargs @@ -580,6 +592,7 @@ def _add_env_to_kwargs( sagemaker_session=kwargs.sagemaker_session, script=JumpStartScriptScope.TRAINING, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, ) model_package_artifact_uri = _retrieve_model_package_model_artifact_s3_uri( @@ -590,6 +603,7 @@ def _add_env_to_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) if model_package_artifact_uri: @@ -617,6 +631,7 @@ def _add_env_to_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) if model_specs.is_gated_model(): raise ValueError( @@ -646,9 +661,11 @@ def _add_training_job_name_to_kwargs( model_id=kwargs.model_id, model_version=kwargs.model_version, region=kwargs.region, + scope=JumpStartScriptScope.TRAINING, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) kwargs.job_name = kwargs.job_name or ( @@ -675,6 +692,7 @@ def _add_hyperparameters_to_kwargs( tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, ) for key, value in default_hyperparameters.items(): @@ -708,6 +726,7 @@ def _add_metric_definitions_to_kwargs( tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, ) or [] ) @@ -737,6 +756,7 @@ def _add_estimator_extra_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) for key, value in estimator_kwargs_to_add.items(): @@ -761,6 +781,7 @@ def _add_fit_extra_kwargs(kwargs: JumpStartEstimatorFitKwargs) 
-> JumpStartEstim tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) for key, value in fit_kwargs_to_add.items(): diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 28746990e3..25a1d63215 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -72,6 +72,7 @@ def get_default_predictor( tolerate_deprecated_model: bool, sagemaker_session: Session, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> Predictor: """Converts predictor returned from ``Model.deploy()`` into a JumpStart-specific one. @@ -94,6 +95,7 @@ def get_default_predictor( tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) predictor.deserializer = deserializers.retrieve_default( model_id=model_id, @@ -103,6 +105,7 @@ def get_default_predictor( tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) predictor.accept = accept_types.retrieve_default( model_id=model_id, @@ -112,6 +115,7 @@ def get_default_predictor( tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) predictor.content_type = content_types.retrieve_default( model_id=model_id, @@ -121,6 +125,7 @@ def get_default_predictor( tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) return predictor @@ -184,7 +189,6 @@ def _add_instance_type_to_kwargs( """Sets instance type based on default or override, returns full kwargs.""" orig_instance_type = kwargs.instance_type - kwargs.instance_type = kwargs.instance_type or instance_types.retrieve_default( region=kwargs.region, model_id=kwargs.model_id, @@ -195,6 +199,7 @@ def _add_instance_type_to_kwargs( sagemaker_session=kwargs.sagemaker_session, training_instance_type=kwargs.training_instance_type, model_type=kwargs.model_type, + config_name=kwargs.config_name, ) if not disable_instance_type_logging and orig_instance_type is None: @@ -226,6 +231,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) return kwargs @@ -247,6 +253,7 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, ) if isinstance(model_data, str) and model_data.startswith("s3://") and model_data.endswith("/"): @@ -287,6 +294,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ): source_dir = source_dir or script_uris.retrieve( script_scope=JumpStartScriptScope.INFERENCE, @@ -296,6 +304,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> 
JumpStartMode tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) kwargs.source_dir = source_dir @@ -319,6 +328,7 @@ def _add_entry_point_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ): entry_point = entry_point or INFERENCE_ENTRY_POINT_SCRIPT_NAME @@ -350,6 +360,7 @@ def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKw sagemaker_session=kwargs.sagemaker_session, script=JumpStartScriptScope.INFERENCE, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, ) for key, value in extra_env_vars.items(): @@ -380,6 +391,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, + config_name=kwargs.config_name, ) kwargs.model_package_arn = model_package_arn @@ -397,6 +409,7 @@ def _add_extra_model_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelI tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, + config_name=kwargs.config_name, ) for key, value in model_kwargs_to_add.items(): @@ -433,6 +446,7 @@ def _add_endpoint_name_to_kwargs( tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, + config_name=kwargs.config_name, ) kwargs.endpoint_name = kwargs.endpoint_name or ( @@ -455,6 +469,7 @@ def _add_model_name_to_kwargs( tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, + config_name=kwargs.config_name, ) kwargs.name = kwargs.name or ( @@ -476,6 +491,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: tolerate_deprecated_model=kwargs.tolerate_deprecated_model, sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, + config_name=kwargs.config_name, ).version if kwargs.sagemaker_session.settings.include_jumpstart_tags: @@ -498,6 +514,7 @@ def _add_deploy_extra_kwargs(kwargs: JumpStartModelInitKwargs) -> Dict[str, Any] tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, + config_name=kwargs.config_name, ) for key, value in deploy_kwargs_to_add.items(): @@ -520,6 +537,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, ) return kwargs @@ -555,6 +573,7 @@ def get_deploy_kwargs( resources: Optional[ResourceRequirements] = None, managed_instance_scaling: Optional[str] = None, endpoint_type: Optional[EndpointType] = None, + config_name: Optional[str] = None, ) -> JumpStartModelDeployKwargs: """Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object.""" @@ -586,6 +605,7 @@ def get_deploy_kwargs( accept_eula=accept_eula, endpoint_logging=endpoint_logging, resources=resources, + config_name=config_name, ) deploy_kwargs = _add_sagemaker_session_to_kwargs(kwargs=deploy_kwargs) @@ 
-639,6 +659,7 @@ def get_register_kwargs( data_input_configuration: Optional[str] = None, skip_model_validation: Optional[str] = None, source_uri: Optional[str] = None, + config_name: Optional[str] = None, ) -> JumpStartModelRegisterKwargs: """Returns kwargs required to call `register` on `sagemaker.estimator.Model` object.""" @@ -671,6 +692,7 @@ def get_register_kwargs( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + config_name=config_name, ) model_specs = verify_model_region_and_return_specs( @@ -681,6 +703,7 @@ def get_register_kwargs( sagemaker_session=sagemaker_session, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, + config_name=config_name, ) register_kwargs.content_types = ( @@ -723,6 +746,7 @@ def get_init_kwargs( training_instance_type: Optional[str] = None, disable_instance_type_logging: bool = False, resources: Optional[ResourceRequirements] = None, + config_name: Optional[str] = None, ) -> JumpStartModelInitKwargs: """Returns kwargs required to instantiate `sagemaker.estimator.Model` object.""" @@ -754,6 +778,7 @@ def get_init_kwargs( model_package_arn=model_package_arn, training_instance_type=training_instance_type, resources=resources, + config_name=config_name, ) model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 4529bc11b9..ad70ffa805 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -92,6 +92,7 @@ def __init__( git_config: Optional[Dict[str, str]] = None, model_package_arn: Optional[str] = None, resources: Optional[ResourceRequirements] = None, + config_name: Optional[str] = None, ): """Initializes a ``JumpStartModel``. @@ -277,6 +278,8 @@ def __init__( for a model to be deployed to an endpoint. Only EndpointType.INFERENCE_COMPONENT_BASED supports this feature. (Default: None). + config_name (Optional[str]): The name of the JumpStartConfig that can be + optionally applied to the model and override corresponding fields. Raises: ValueError: If the model ID is not recognized by JumpStart. """ @@ -326,6 +329,7 @@ def _validate_model_id_and_type(): git_config=git_config, model_package_arn=model_package_arn, resources=resources, + config_name=config_name, ) self.orig_predictor_cls = predictor_cls @@ -338,6 +342,7 @@ def _validate_model_id_and_type(): self.tolerate_deprecated_model = model_init_kwargs.tolerate_deprecated_model self.region = model_init_kwargs.region self.sagemaker_session = model_init_kwargs.sagemaker_session + self.config_name = config_name if self.model_type == JumpStartModelType.PROPRIETARY: self.log_subscription_warning() @@ -345,6 +350,7 @@ def _validate_model_id_and_type(): super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict()) self.model_package_arn = model_init_kwargs.model_package_arn + self.init_kwargs = model_init_kwargs.to_kwargs_dict(False) def log_subscription_warning(self) -> None: """Log message prompting the customer to subscribe to the proprietary model.""" @@ -402,6 +408,18 @@ def retrieve_example_payload(self) -> JumpStartSerializablePayload: sagemaker_session=self.sagemaker_session, ) + def set_deployment_config(self, config_name: Optional[str]) -> None: + """Sets the deployment config to apply to the model. + + Args: + config_name (Optional[str]): + The name of the deployment config. 
Set to None to unset + any existing config that is applied to the model. + """ + self.__init__( + model_id=self.model_id, model_version=self.model_version, config_name=config_name + ) + def _create_sagemaker_model( self, instance_type=None, @@ -625,6 +643,7 @@ def deploy( managed_instance_scaling=managed_instance_scaling, endpoint_type=endpoint_type, model_type=self.model_type, + config_name=self.config_name, ) if ( self.model_type == JumpStartModelType.PROPRIETARY @@ -644,6 +663,7 @@ def deploy( model_type=self.model_type, scope=JumpStartScriptScope.INFERENCE, sagemaker_session=self.sagemaker_session, + config_name=self.config_name, ).model_subscription_link get_proprietary_model_subscription_error(e, subscription_link) raise @@ -659,6 +679,7 @@ def deploy( tolerate_vulnerable_model=self.tolerate_vulnerable_model, sagemaker_session=self.sagemaker_session, model_type=self.model_type, + config_name=self.config_name, ) # If a predictor class was passed, do not mutate predictor @@ -769,6 +790,7 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + config_name=self.config_name, ) model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict()) diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 83613cd71b..781548b42a 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -535,6 +535,7 @@ def get_model_url( model_version: str, region: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> str: """Retrieve web url describing pretrained model. @@ -563,5 +564,6 @@ def get_model_url( sagemaker_session=sagemaker_session, scope=JumpStartScriptScope.INFERENCE, model_type=model_type, + config_name=config_name, ) return model_specs.url diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 05c38da266..1de0f662da 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1064,8 +1064,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: Dictionary representation of the config component. 
""" for field in json_obj.keys(): - if field in self.__slots__: - setattr(self, field, json_obj[field]) + if field not in self.__slots__: + raise ValueError(f"Invalid component field: {field}") + setattr(self, field, json_obj[field]) class JumpStartMetadataConfig(JumpStartDataHolderType): @@ -1464,11 +1465,11 @@ class JumpStartKwargs(JumpStartDataHolderType): SERIALIZATION_EXCLUSION_SET: Set[str] = set() - def to_kwargs_dict(self): + def to_kwargs_dict(self, exclude_keys: bool = True): """Serializes object to dictionary to be used for kwargs for method arguments.""" kwargs_dict = {} for field in self.__slots__: - if field not in self.SERIALIZATION_EXCLUSION_SET: + if exclude_keys and field not in self.SERIALIZATION_EXCLUSION_SET or not exclude_keys: att_value = getattr(self, field) if att_value is not None: kwargs_dict[field] = getattr(self, field) @@ -1506,6 +1507,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "model_package_arn", "training_instance_type", "resources", + "config_name", ] SERIALIZATION_EXCLUSION_SET = { @@ -1518,6 +1520,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "region", "model_package_arn", "training_instance_type", + "config_name", } def __init__( @@ -1549,6 +1552,7 @@ def __init__( model_package_arn: Optional[str] = None, training_instance_type: Optional[str] = None, resources: Optional[ResourceRequirements] = None, + config_name: Optional[str] = None, ) -> None: """Instantiates JumpStartModelInitKwargs object.""" @@ -1579,6 +1583,7 @@ def __init__( self.model_package_arn = model_package_arn self.training_instance_type = training_instance_type self.resources = resources + self.config_name = config_name class JumpStartModelDeployKwargs(JumpStartKwargs): @@ -1614,6 +1619,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "endpoint_logging", "resources", "endpoint_type", + "config_name", ] SERIALIZATION_EXCLUSION_SET = { @@ -1625,6 +1631,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "tolerate_vulnerable_model", "sagemaker_session", "training_instance_type", + "config_name", } def __init__( @@ -1658,6 +1665,7 @@ def __init__( endpoint_logging: Optional[bool] = None, resources: Optional[ResourceRequirements] = None, endpoint_type: Optional[EndpointType] = None, + config_name: Optional[str] = None, ) -> None: """Instantiates JumpStartModelDeployKwargs object.""" @@ -1690,6 +1698,7 @@ def __init__( self.endpoint_logging = endpoint_logging self.resources = resources self.endpoint_type = endpoint_type + self.config_name = config_name class JumpStartEstimatorInitKwargs(JumpStartKwargs): @@ -1750,7 +1759,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "disable_output_compression", "enable_infra_check", "enable_remote_debug", - "enable_session_tag_chaining", + "config_name", ] SERIALIZATION_EXCLUSION_SET = { @@ -1760,6 +1769,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "model_id", "model_version", "model_type", + "config_name", } def __init__( @@ -1818,7 +1828,7 @@ def __init__( disable_output_compression: Optional[bool] = None, enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, - enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, + config_name: Optional[str] = None, ) -> None: """Instantiates JumpStartEstimatorInitKwargs object.""" @@ -1878,7 +1888,7 @@ def __init__( self.disable_output_compression = disable_output_compression self.enable_infra_check = enable_infra_check self.enable_remote_debug = 
enable_remote_debug - self.enable_session_tag_chaining = enable_session_tag_chaining + self.config_name = config_name class JumpStartEstimatorFitKwargs(JumpStartKwargs):
@@ -1897,6 +1907,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): "tolerate_deprecated_model", "tolerate_vulnerable_model", "sagemaker_session", + "config_name", ] SERIALIZATION_EXCLUSION_SET = {
@@ -1907,6 +1918,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): "tolerate_deprecated_model", "tolerate_vulnerable_model", "sagemaker_session", + "config_name", } def __init__(
@@ -1923,6 +1935,7 @@ def __init__( tolerate_deprecated_model: Optional[bool] = None, tolerate_vulnerable_model: Optional[bool] = None, sagemaker_session: Optional[Session] = None, + config_name: Optional[str] = None, ) -> None: """Instantiates JumpStartEstimatorFitKwargs object."""
@@ -1938,6 +1951,7 @@ def __init__( self.tolerate_deprecated_model = tolerate_deprecated_model self.tolerate_vulnerable_model = tolerate_vulnerable_model self.sagemaker_session = sagemaker_session + self.config_name = config_name class JumpStartEstimatorDeployKwargs(JumpStartKwargs):
@@ -1983,6 +1997,7 @@ class JumpStartEstimatorDeployKwargs(JumpStartKwargs): "tolerate_vulnerable_model", "model_name", "use_compiled_model", + "config_name", ] SERIALIZATION_EXCLUSION_SET = {
@@ -1992,6 +2007,7 @@ class JumpStartEstimatorDeployKwargs(JumpStartKwargs): "model_id", "model_version", "sagemaker_session", + "config_name", } def __init__(
@@ -2035,6 +2051,7 @@ def __init__( tolerate_deprecated_model: Optional[bool] = None, tolerate_vulnerable_model: Optional[bool] = None, use_compiled_model: bool = False, + config_name: Optional[str] = None, ) -> None: """Instantiates JumpStartEstimatorDeployKwargs object."""
@@ -2077,6 +2094,7 @@ def __init__( self.tolerate_deprecated_model = tolerate_deprecated_model self.tolerate_vulnerable_model = tolerate_vulnerable_model self.use_compiled_model = use_compiled_model + self.config_name = config_name class JumpStartModelRegisterKwargs(JumpStartKwargs):
@@ -2111,6 +2129,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs): "data_input_configuration", "skip_model_validation", "source_uri", + "config_name", ] SERIALIZATION_EXCLUSION_SET = {
@@ -2120,6 +2139,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs): "model_id", "model_version", "sagemaker_session", + "config_name", } def __init__(
@@ -2152,6 +2172,7 @@ def __init__( data_input_configuration: Optional[str] = None, skip_model_validation: Optional[str] = None, source_uri: Optional[str] = None, + config_name: Optional[str] = None, ) -> None: """Instantiates JumpStartModelRegisterKwargs object."""
@@ -2184,3 +2205,4 @@ def __init__( self.data_input_configuration = data_input_configuration self.skip_model_validation = skip_model_validation self.source_uri = source_uri + self.config_name = config_name
diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 63cfac0939..1459594faa 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py
@@ -547,6 +547,7 @@ def verify_model_region_and_return_specs( tolerate_deprecated_model: bool = False, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> JumpStartModelSpecs: """Verifies that an acceptable model_id, version, scope, and region combination is provided.
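The interaction between SERIALIZATION_EXCLUSION_SET and the new exclude_keys flag is what makes config switching work: to_kwargs_dict() keeps JumpStart-internal fields such as config_name out of the kwargs forwarded to the parent Estimator/Model constructors, while to_kwargs_dict(False) snapshots every field for the init_kwargs used by set_training_config(). A self-contained sketch of that filtering, with made-up field names (illustrative only, not the SDK classes themselves):

    from typing import Optional, Set

    class KwargsHolder:
        """Stand-in for the JumpStartKwargs subclasses in the patch above."""

        __slots__ = ["model_id", "instance_type", "config_name"]
        SERIALIZATION_EXCLUSION_SET: Set[str] = {"config_name"}

        def __init__(self, model_id: str, instance_type: Optional[str] = None,
                     config_name: Optional[str] = None) -> None:
            self.model_id = model_id
            self.instance_type = instance_type
            self.config_name = config_name

        def to_kwargs_dict(self, exclude_keys: bool = True) -> dict:
            kwargs_dict = {}
            for field in self.__slots__:
                # Drop excluded fields only when exclude_keys is True.
                if not exclude_keys or field not in self.SERIALIZATION_EXCLUSION_SET:
                    att_value = getattr(self, field)
                    if att_value is not None:
                        kwargs_dict[field] = att_value
            return kwargs_dict

    holder = KwargsHolder("some-model-id", config_name="neuron-inference")
    assert "config_name" not in holder.to_kwargs_dict()   # kwargs forwarded to the parent class
    assert "config_name" in holder.to_kwargs_dict(False)  # snapshot kept for re-initialization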
@@ -569,6 +570,7 @@ def verify_model_region_and_return_specs( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Raises: NotImplementedError: If the scope is not supported. @@ -634,6 +636,9 @@ def verify_model_region_and_return_specs( scope=constants.JumpStartScriptScope.TRAINING, ) + if model_specs and config_name: + model_specs.set_config(config_name, scope) + return model_specs @@ -890,7 +895,11 @@ def get_config_names( scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, ) -> List[str]: - """Returns a list of config names for the given model ID and region.""" + """Returns a list of config names for the given model ID and region. + + Raises: + ValueError: If the script scope is not supported by JumpStart. + """ model_specs = verify_model_region_and_return_specs( region=region, model_id=model_id, @@ -905,7 +914,7 @@ def get_config_names( elif scope == enums.JumpStartScriptScope.TRAINING: metadata_configs = model_specs.training_configs else: - raise ValueError(f"Unknown script scope {scope}.") + raise ValueError(f"Unknown script scope: {scope}.") return list(metadata_configs.configs.keys()) if metadata_configs else [] @@ -919,7 +928,11 @@ def get_benchmark_stats( scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, ) -> Dict[str, List[JumpStartBenchmarkStat]]: - """Returns benchmark stats for the given model ID and region.""" + """Returns benchmark stats for the given model ID and region. + + Raises: + ValueError: If the script scope is not supported by JumpStart. + """ model_specs = verify_model_region_and_return_specs( region=region, model_id=model_id, @@ -934,7 +947,7 @@ def get_benchmark_stats( elif scope == enums.JumpStartScriptScope.TRAINING: metadata_configs = model_specs.training_configs else: - raise ValueError(f"Unknown script scope {scope}.") + raise ValueError(f"Unknown script scope: {scope}.") if not config_names: config_names = metadata_configs.configs.keys() if metadata_configs else [] @@ -942,7 +955,7 @@ def get_benchmark_stats( benchmark_stats = {} for config_name in config_names: if config_name not in metadata_configs.configs: - raise ValueError(f"Unknown config name: '{config_name}'") + raise ValueError(f"Unknown config name: {config_name}") benchmark_stats[config_name] = metadata_configs.configs.get(config_name).benchmark_metrics return benchmark_stats @@ -956,8 +969,12 @@ def get_jumpstart_configs( sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, -) -> Dict[str, List[JumpStartMetadataConfig]]: - """Returns metadata configs for the given model ID and region.""" +) -> Dict[str, JumpStartMetadataConfig]: + """Returns metadata configs for the given model ID and region. + + Raises: + ValueError: If the script scope is not supported by JumpStart. 
+ """ model_specs = verify_model_region_and_return_specs( region=region, model_id=model_id, @@ -972,7 +989,7 @@ def get_jumpstart_configs( elif scope == enums.JumpStartScriptScope.TRAINING: metadata_configs = model_specs.training_configs else: - raise ValueError(f"Unknown script scope {scope}.") + raise ValueError(f"Unknown script scope: {scope}.") if not config_names: config_names = metadata_configs.configs.keys() if metadata_configs else [] diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index c7098a1185..bcb0365f7b 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -171,6 +171,7 @@ def validate_hyperparameters( sagemaker_session: Optional[session.Session] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + config_name: Optional[str] = None, ) -> None: """Validate hyperparameters for JumpStart models. @@ -193,6 +194,7 @@ def validate_hyperparameters( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Raises: JumpStartHyperparametersError: If the hyperparameters are not formatted correctly, @@ -218,6 +220,7 @@ def validate_hyperparameters( sagemaker_session=sagemaker_session, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, + config_name=config_name, ) hyperparameters_specs = model_specs.hyperparameters diff --git a/src/sagemaker/metric_definitions.py b/src/sagemaker/metric_definitions.py index 71dd26db45..0c066ff801 100644 --- a/src/sagemaker/metric_definitions.py +++ b/src/sagemaker/metric_definitions.py @@ -33,6 +33,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> Optional[List[Dict[str, str]]]: """Retrieves the default training metric definitions for the model matching the given arguments. @@ -56,6 +57,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: list: The default metric definitions to use for the model or None. @@ -76,4 +78,5 @@ def retrieve_default( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py index 937180bd44..122647e536 100644 --- a/src/sagemaker/model_uris.py +++ b/src/sagemaker/model_uris.py @@ -34,6 +34,7 @@ def retrieve( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> str: """Retrieves the model artifact Amazon S3 URI for the model matching the given arguments. @@ -57,6 +58,8 @@ def retrieve( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). 
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + Returns: str: The model artifact S3 URI for the corresponding model.
@@ -81,4 +84,5 @@ def retrieve( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, )
diff --git a/src/sagemaker/resource_requirements.py b/src/sagemaker/resource_requirements.py index df14ac558f..7808d0172a 100644 --- a/src/sagemaker/resource_requirements.py +++ b/src/sagemaker/resource_requirements.py
@@ -37,6 +37,7 @@ def retrieve_default( model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, + config_name: Optional[str] = None, ) -> ResourceRequirements: """Retrieves the default resource requirements for the model matching the given arguments.
@@ -62,6 +63,7 @@ def retrieve_default( chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). instance_type (str): An instance type to optionally supply in order to get host requirements specific for the instance type. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: ResourceRequirements: The default resource requirements to use for the model.
@@ -87,4 +89,5 @@ def retrieve_default( model_type=model_type, sagemaker_session=sagemaker_session, instance_type=instance_type, + config_name=config_name, )
diff --git a/src/sagemaker/script_uris.py b/src/sagemaker/script_uris.py index 9a1c4933d2..6e10785498 100644 --- a/src/sagemaker/script_uris.py +++ b/src/sagemaker/script_uris.py
@@ -33,6 +33,7 @@ def retrieve( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> str: """Retrieves the script S3 URI associated with the model matching the given arguments.
@@ -55,6 +56,7 @@ def retrieve( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The model script URI for the corresponding model.
@@ -78,4 +80,5 @@ def retrieve( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, )
diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py index aefb52bd97..d197df731c 100644 --- a/src/sagemaker/serializers.py +++ b/src/sagemaker/serializers.py
@@ -45,6 +45,7 @@ def retrieve_options( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> List[BaseSerializer]: """Retrieves the supported serializers for the model matching the given arguments.
@@ -66,6 +67,7 @@ def retrieve_options( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: List[SimpleBaseSerializer]: The supported serializers to use for the model.
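With config_name threaded through the public retrieval helpers, a caller can resolve artifacts and predictor defaults against a specific JumpStart config rather than the top-ranked one. A hedged sketch follows; the model ID and config name are placeholders, and a real call needs values reported by the model's own metadata:

    from sagemaker import model_uris, serializers

    uri = model_uris.retrieve(
        model_id="pytorch-eqa-bert-base-cased",  # placeholder model ID
        model_version="*",
        model_scope="inference",
        config_name="neuron-inference",  # placeholder config name
    )
    serializer = serializers.retrieve_default(
        model_id="pytorch-eqa-bert-base-cased",
        model_version="*",
        config_name="neuron-inference",
    )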
@@ -85,6 +87,7 @@ def retrieve_options( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) @@ -96,6 +99,7 @@ def retrieve_default( tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> BaseSerializer: """Retrieves the default serializer for the model matching the given arguments. @@ -117,6 +121,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: SimpleBaseSerializer: The default serializer to use for the model. @@ -137,4 +142,5 @@ def retrieve_default( tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 36d8b11fab..f07bb44ba1 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -46,6 +46,8 @@ from sagemaker.model import Model from sagemaker.predictor import Predictor from tests.unit.sagemaker.jumpstart.utils import ( + get_prototype_manifest, + get_prototype_spec_with_configs, get_special_model_spec, overwrite_dictionary, ) @@ -1113,6 +1115,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): "region", "tolerate_vulnerable_model", "tolerate_deprecated_model", + "config_name", } assert parent_class_init_args - js_class_init_args == init_args_to_skip @@ -1370,6 +1373,7 @@ def test_incremental_training_with_unsupported_model_logs_warning( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=sagemaker_session, + config_name=None, ) @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @@ -1421,6 +1425,7 @@ def test_incremental_training_with_supported_model_doesnt_log_warning( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=sagemaker_session, + config_name=None, ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @@ -1852,6 +1857,55 @@ def test_jumpstart_estimator_session( assert len(s3_clients) == 1 assert list(s3_clients)[0] == session.s3_client + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_estimator_initialization_with_config_name( + self, + mock_estimator_init: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + ): + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_estimator_fit.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" 
+ + mock_session.return_value = sagemaker_session + + estimator = JumpStartEstimator(model_id=model_id, config_name="neuron-training") + + mock_estimator_init.assert_called_once_with( + instance_type="ml.p2.xlarge", + instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3", + model_uri="s3://jumpstart-cache-prod-us-west-2/artifacts/meta-textgeneration-llama-2-7b/" + "neuron-training/model/", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/" + "transfer_learning/eqa/v1.0.0/sourcedir.tar.gz", + entry_point="transfer_learning.py", + hyperparameters={"epochs": "3", "adam-learning-rate": "2e-05", "batch-size": "4"}, + role="fake role! do not use!", + sagemaker_session=sagemaker_session, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + ], + enable_network_isolation=False, + ) + + estimator.fit() + + mock_estimator_fit.assert_called_once_with(wait=True) + def test_jumpstart_estimator_requires_model_id(): with pytest.raises(ValueError): diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 8b00eb5bcd..cb7b602fbf 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -40,6 +40,7 @@ from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from tests.unit.sagemaker.jumpstart.utils import ( + get_prototype_spec_with_configs, get_spec_from_base_spec, get_special_model_spec, overwrite_dictionary, @@ -715,6 +716,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self): "tolerate_deprecated_model", "instance_type", "model_package_arn", + "config_name", } assert parent_class_init_args - js_class_init_args == init_args_to_skip @@ -786,6 +788,7 @@ def test_no_predictor_returns_default_predictor( tolerate_vulnerable_model=False, sagemaker_session=model.sagemaker_session, model_type=JumpStartModelType.OPEN_WEIGHTS, + config_name=None, ) self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) @@ -1414,6 +1417,153 @@ def test_model_local_mode( endpoint_logging=False, ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_initialization_with_config_name( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + ): + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id, config_name="neuron-inference") + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.inf2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, 
+ ], + wait=True, + endpoint_logging=False, + ) + + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_set_deployment_config( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + ): + mock_get_model_specs.side_effect = get_prototype_model_spec + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.p2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + ], + wait=True, + endpoint_logging=False, + ) + + mock_get_model_specs.reset_mock() + mock_model_deploy.reset_mock() + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + model.set_deployment_config("neuron-inference") + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.inf2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + ], + wait=True, + endpoint_logging=False, + ) + + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_unset_deployment_config( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + ): + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id, config_name="neuron-inference") + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.inf2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + ], + wait=True, + endpoint_logging=False, + ) + + mock_get_model_specs.reset_mock() + mock_model_deploy.reset_mock() + mock_get_model_specs.side_effect = get_prototype_model_spec + model.set_deployment_config(None) + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.p2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": 
"1.0.0"}, + ], + wait=True, + endpoint_logging=False, + ) + def test_jumpstart_model_requires_model_id(): with pytest.raises(ValueError): diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index b2758c73ef..5ca01c3c52 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -1052,14 +1052,6 @@ def test_inference_configs_parsing(): ) assert list(config.config_components.keys()) == ["neuron-inference"] - spec = { - **BASE_SPEC, - **INFERENCE_CONFIGS, - **INFERENCE_CONFIG_RANKINGS, - "unrecognized-field": "blah", # New fields in base metadata fields should be ignored - } - specs1 = JumpStartModelSpecs(spec) - def test_set_inference_configs(): spec = {**BASE_SPEC, **INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS} diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index e102251060..aee1497ec9 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -222,6 +222,23 @@ def get_base_spec_with_prototype_configs( return JumpStartModelSpecs(spec) +def get_prototype_spec_with_configs( + region: str = None, + model_id: str = None, + version: str = None, + s3_client: boto3.client = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, +) -> JumpStartModelSpecs: + spec = copy.deepcopy(PROTOTYPICAL_MODEL_SPECS_DICT[model_id]) + inference_configs = {**INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS} + training_configs = {**TRAINING_CONFIGS, **TRAINING_CONFIG_RANKINGS} + + spec.update(inference_configs) + spec.update(training_configs) + + return JumpStartModelSpecs(spec) + + def patched_retrieval_function( _modelCacheObj: JumpStartModelsCache, key: JumpStartCachedS3ContentKey, From d717dcde31d38bfaf78108ef9958ee2e44a6aaba Mon Sep 17 00:00:00 2001 From: Jonathan Makunga <54963715+makungaj1@users.noreply.github.com> Date: Wed, 24 Apr 2024 10:22:35 -0700 Subject: [PATCH 18/32] Add ReadOnly APIs (#4606) * Add ReadOnly APIs * Resolving PR review comments * Resolve PR review comments * Refactoring * Refactoring * Add Caching * Refactore * Resolving conflicts * Add Unit Tests * Fix Unit Tests * Fix unit tests * Fix UT * Refactoring * Fix Integ tests * refactoring after Notebook testing * Fix code styles --------- Co-authored-by: Jonathan Makunga --- src/sagemaker/jumpstart/model.py | 109 ++++++- src/sagemaker/jumpstart/types.py | 96 ++++++ src/sagemaker/jumpstart/utils.py | 41 +++ .../serve/builder/jumpstart_builder.py | 14 +- src/sagemaker/utils.py | 69 +++++ tests/unit/sagemaker/jumpstart/constants.py | 157 ++++++++++ .../sagemaker/jumpstart/model/test_model.py | 285 ++++++++++++++++++ .../jumpstart/model/test_sagemaker_config.py | 32 ++ .../sagemaker/jumpstart/test_predictor.py | 4 + tests/unit/sagemaker/jumpstart/test_utils.py | 50 +++ tests/unit/sagemaker/jumpstart/utils.py | 22 +- .../serve/builder/test_js_builder.py | 85 ++++++ tests/unit/sagemaker/serve/constants.py | 150 +++++++++ tests/unit/test_utils.py | 76 +++++ 14 files changed, 1185 insertions(+), 5 deletions(-) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index ad70ffa805..2addb0a044 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -14,7 +14,9 @@ from __future__ import absolute_import -from typing import Dict, List, Optional, Union +from functools import lru_cache +from typing import Dict, List, Optional, Union, Any +import pandas as pd from botocore.exceptions 
import ClientError from sagemaker import payloads @@ -36,14 +38,21 @@ get_init_kwargs, get_register_kwargs, ) -from sagemaker.jumpstart.types import JumpStartSerializablePayload +from sagemaker.jumpstart.types import ( + JumpStartSerializablePayload, + DeploymentConfigMetadata, + JumpStartBenchmarkStat, + JumpStartMetadataConfig, +) from sagemaker.jumpstart.utils import ( validate_model_id_and_get_type, verify_model_region_and_return_specs, + get_jumpstart_configs, + extract_metrics_from_deployment_configs, ) from sagemaker.jumpstart.constants import JUMPSTART_LOGGER from sagemaker.jumpstart.enums import JumpStartModelType -from sagemaker.utils import stringify_object, format_tags, Tags +from sagemaker.utils import stringify_object, format_tags, Tags, get_instance_rate_per_hour from sagemaker.model import ( Model, ModelPackage, @@ -352,6 +361,18 @@ def _validate_model_id_and_type(): self.model_package_arn = model_init_kwargs.model_package_arn self.init_kwargs = model_init_kwargs.to_kwargs_dict(False) + metadata_configs = get_jumpstart_configs( + region=self.region, + model_id=self.model_id, + model_version=self.model_version, + sagemaker_session=self.sagemaker_session, + model_type=self.model_type, + ) + self._deployment_configs = [ + self._convert_to_deployment_config_metadata(config_name, config) + for config_name, config in metadata_configs.items() + ] + def log_subscription_warning(self) -> None: """Log message prompting the customer to subscribe to the proprietary model.""" subscription_link = verify_model_region_and_return_specs( @@ -420,6 +441,27 @@ def set_deployment_config(self, config_name: Optional[str]) -> None: model_id=self.model_id, model_version=self.model_version, config_name=config_name ) + @property + def benchmark_metrics(self) -> pd.DataFrame: + """Benchmark metrics for the deployment configs. + + Returns: + pd.DataFrame: Benchmark metrics, one row per deployment config. + """ + return pd.DataFrame(self._get_benchmark_data(self.config_name)) + + def display_benchmark_metrics(self) -> None: + """Display benchmark metrics for the deployment configs as a Markdown table.""" + print(self.benchmark_metrics.to_markdown()) + + def list_deployment_configs(self) -> List[Dict[str, Any]]: + """List deployment configs for this model. + + Returns: + List[Dict[str, Any]]: A list of deployment configs. + """ + return self._deployment_configs + def _create_sagemaker_model( self, instance_type=None, @@ -808,6 +850,67 @@ def register_deploy_wrapper(*args, **kwargs): return model_package + @lru_cache + def _get_benchmark_data(self, config_name: str) -> Dict[str, List[str]]: + """Constructs benchmark data for the deployment configs. + + Args: + config_name (str): The name of the selected deployment config. + Returns: + Dict[str, List[str]]: Deployment config benchmark data, keyed by column name. + """ + return extract_metrics_from_deployment_configs( + self._deployment_configs, + config_name, + ) + + def _convert_to_deployment_config_metadata( + self, config_name: str, metadata_config: JumpStartMetadataConfig + ) -> Dict[str, Any]: + """Converts the given metadata config into deployment config metadata. + + Args: + config_name (str): Name of the deployment config. + metadata_config (JumpStartMetadataConfig): Metadata config for the deployment config. + Returns: + Dict[str, Any]: Deployment config metadata for the config name.
+ """ + default_inference_instance_type = metadata_config.resolved_config.get( + "default_inference_instance_type" + ) + + instance_rate = get_instance_rate_per_hour( + instance_type=default_inference_instance_type, region=self.region + ) + + benchmark_metrics = ( + metadata_config.benchmark_metrics.get(default_inference_instance_type) + if metadata_config.benchmark_metrics is not None + else None + ) + if instance_rate is not None: + if benchmark_metrics is not None: + benchmark_metrics.append(JumpStartBenchmarkStat(instance_rate)) + else: + benchmark_metrics = [JumpStartBenchmarkStat(instance_rate)] + + init_kwargs = get_init_kwargs( + model_id=self.model_id, + instance_type=default_inference_instance_type, + sagemaker_session=self.sagemaker_session, + ) + deploy_kwargs = get_deploy_kwargs( + model_id=self.model_id, + instance_type=default_inference_instance_type, + sagemaker_session=self.sagemaker_session, + ) + + deployment_config_metadata = DeploymentConfigMetadata( + config_name, benchmark_metrics, init_kwargs, deploy_kwargs + ) + + return deployment_config_metadata.to_json() + def __str__(self) -> str: """Overriding str(*) method to make more human-readable.""" return stringify_object(self) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 1de0f662da..07bd769054 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -2206,3 +2206,99 @@ def __init__( self.skip_model_validation = skip_model_validation self.source_uri = source_uri self.config_name = config_name + + +class BaseDeploymentConfigDataHolder(JumpStartDataHolderType): + """Base class for Deployment Config Data.""" + + def _convert_to_pascal_case(self, attr_name: str) -> str: + """Converts a snake_case attribute name into a camelCased string. + + Args: + attr_name (str): The snake_case attribute name. + Returns: + str: The PascalCased attribute name. 
+ """ + return attr_name.replace("_", " ").title().replace(" ", "") + + def to_json(self) -> Dict[str, Any]: + """Represents ``This`` object as JSON.""" + json_obj = {} + for att in self.__slots__: + if hasattr(self, att): + cur_val = getattr(self, att) + att = self._convert_to_pascal_case(att) + if issubclass(type(cur_val), JumpStartDataHolderType): + json_obj[att] = cur_val.to_json() + elif isinstance(cur_val, list): + json_obj[att] = [] + for obj in cur_val: + if issubclass(type(obj), JumpStartDataHolderType): + json_obj[att].append(obj.to_json()) + else: + json_obj[att].append(obj) + elif isinstance(cur_val, dict): + json_obj[att] = {} + for key, val in cur_val.items(): + if issubclass(type(val), JumpStartDataHolderType): + json_obj[att][self._convert_to_pascal_case(key)] = val.to_json() + else: + json_obj[att][key] = val + else: + json_obj[att] = cur_val + return json_obj + + +class DeploymentConfig(BaseDeploymentConfigDataHolder): + """Dataclass representing a Deployment Config.""" + + __slots__ = [ + "model_data_download_timeout", + "container_startup_health_check_timeout", + "image_uri", + "model_data", + "instance_type", + "environment", + "compute_resource_requirements", + ] + + def __init__( + self, init_kwargs: JumpStartModelInitKwargs, deploy_kwargs: JumpStartModelDeployKwargs + ): + """Instantiates DeploymentConfig object.""" + if init_kwargs is not None: + self.image_uri = init_kwargs.image_uri + self.model_data = init_kwargs.model_data + self.instance_type = init_kwargs.instance_type + self.environment = init_kwargs.env + if init_kwargs.resources is not None: + self.compute_resource_requirements = ( + init_kwargs.resources.get_compute_resource_requirements() + ) + if deploy_kwargs is not None: + self.model_data_download_timeout = deploy_kwargs.model_data_download_timeout + self.container_startup_health_check_timeout = ( + deploy_kwargs.container_startup_health_check_timeout + ) + + +class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder): + """Dataclass representing a Deployment Config Metadata""" + + __slots__ = [ + "config_name", + "benchmark_metrics", + "deployment_config", + ] + + def __init__( + self, + config_name: str, + benchmark_metrics: List[JumpStartBenchmarkStat], + init_kwargs: JumpStartModelInitKwargs, + deploy_kwargs: JumpStartModelDeployKwargs, + ): + """Instantiates DeploymentConfigMetadata object.""" + self.config_name = config_name + self.benchmark_metrics = benchmark_metrics + self.deployment_config = DeploymentConfig(init_kwargs, deploy_kwargs) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 1459594faa..905f2a18d5 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -999,3 +999,44 @@ def get_jumpstart_configs( if metadata_configs else {} ) + + +def extract_metrics_from_deployment_configs( + deployment_configs: List[Dict[str, Any]], config_name: str +) -> Dict[str, List[str]]: + """Extracts metrics from deployment configs. + + Args: + deployment_configs (list[dict[str, Any]]): List of deployment configs. + config_name (str): The name of the deployment config use by the model. 
+ """ + + data = {"Config Name": [], "Instance Type": [], "Selected": []} + + for index, deployment_config in enumerate(deployment_configs): + if deployment_config.get("DeploymentConfig") is None: + continue + + benchmark_metrics = deployment_config.get("BenchmarkMetrics") + if benchmark_metrics is not None: + data["Config Name"].append(deployment_config.get("ConfigName")) + data["Instance Type"].append( + deployment_config.get("DeploymentConfig").get("InstanceType") + ) + data["Selected"].append( + "Yes" + if (config_name is not None and config_name == deployment_config.get("ConfigName")) + else "No" + ) + + if index == 0: + for benchmark_metric in benchmark_metrics: + column_name = f"{benchmark_metric.get('name')} ({benchmark_metric.get('unit')})" + data[column_name] = [] + + for benchmark_metric in benchmark_metrics: + column_name = f"{benchmark_metric.get('name')} ({benchmark_metric.get('unit')})" + if column_name in data.keys(): + data[column_name].append(benchmark_metric.get("value")) + + return data diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index e3368869fe..c1760311e7 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -16,7 +16,7 @@ import copy from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import Type +from typing import Type, Any, List, Dict import logging from sagemaker.model import Model @@ -431,6 +431,18 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800): sharded_supported=sharded_supported, max_tuning_duration=max_tuning_duration ) + def display_benchmark_metrics(self): + """Display Markdown Benchmark Metrics for deployment configs.""" + self.pysdk_model.display_benchmark_metrics() + + def list_deployment_configs(self) -> List[Dict[str, Any]]: + """List deployment configs for ``This`` model in the current region. + + Returns: + List[Dict[str, Any]]: A list of deployment configs. + """ + return self.pysdk_model.list_deployment_configs() + def _build_for_jumpstart(self): """Placeholder docstring""" # we do not pickle for jumpstart. set to none diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 0436c0afea..35f60b37e1 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -25,6 +25,7 @@ import tarfile import tempfile import time +from functools import lru_cache from typing import Union, Any, List, Optional, Dict import json import abc @@ -33,6 +34,8 @@ from os.path import abspath, realpath, dirname, normpath, join as joinpath from importlib import import_module + +import boto3 import botocore from botocore.utils import merge_dicts from six.moves.urllib import parse @@ -1655,3 +1658,69 @@ def deep_override_dict( ) flattened_dict1.update(flattened_dict2) return unflatten_dict(flattened_dict1) if flattened_dict1 else {} + + +@lru_cache +def get_instance_rate_per_hour( + instance_type: str, + region: str, +) -> Union[Dict[str, str], None]: + """Gets instance rate per hour for the given instance type. + + Args: + instance_type (str): The instance type. + region (str): The region. + Returns: + Union[Dict[str, str], None]: Instance rate per hour. + Example: {'name': 'Instance Rate', 'unit': 'USD/Hrs', 'value': '1.1250000000'}}. 
+ """ + + region_name = "us-east-1" + if region.startswith("eu") or region.startswith("af"): + region_name = "eu-central-1" + elif region.startswith("ap") or region.startswith("cn"): + region_name = "ap-south-1" + + pricing_client: boto3.client = boto3.client("pricing", region_name=region_name) + try: + res = pricing_client.get_products( + ServiceCode="AmazonSageMaker", + Filters=[ + {"Type": "TERM_MATCH", "Field": "instanceName", "Value": instance_type}, + {"Type": "TERM_MATCH", "Field": "locationType", "Value": "AWS Region"}, + {"Type": "TERM_MATCH", "Field": "regionCode", "Value": region}, + ], + ) + + price_list = res.get("PriceList", []) + if len(price_list) > 0: + price_data = price_list[0] + if isinstance(price_data, str): + price_data = json.loads(price_data) + + return extract_instance_rate_per_hour(price_data) + except Exception as e: # pylint: disable=W0703 + logging.exception("Error getting instance rate: %s", e) + return None + + +def extract_instance_rate_per_hour(price_data: Dict[str, Any]) -> Union[Dict[str, str], None]: + """Extract instance rate per hour for the given Price JSON data. + + Args: + price_data (Dict[str, Any]): The Price JSON data. + Returns: + Union[Dict[str, str], None]: Instance rate per hour. + """ + + if price_data is not None: + price_dimensions = price_data.get("terms", {}).get("OnDemand", {}).values() + for dimension in price_dimensions: + for price in dimension.get("priceDimensions", {}).values(): + for currency in price.get("pricePerUnit", {}).keys(): + return { + "unit": f"{currency}/{price.get('unit', 'Hrs')}", + "value": price.get("pricePerUnit", {}).get(currency), + "name": "Instance Rate", + } + return None diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index f165a513a9..b83f85ffde 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -7907,3 +7907,160 @@ }, } } + + +DEPLOYMENT_CONFIGS = [ + { + "ConfigName": "neuron-inference", + "BenchmarkMetrics": [{"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs"}], + "DeploymentConfig": { + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" + "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + }, + }, + { + "ConfigName": "neuron-inference-budget", + "BenchmarkMetrics": [{"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs"}], + "DeploymentConfig": { + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": 
"s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" + "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + }, + }, + { + "ConfigName": "gpu-inference-budget", + "BenchmarkMetrics": [{"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs"}], + "DeploymentConfig": { + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" + "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + }, + }, + { + "ConfigName": "gpu-inference", + "BenchmarkMetrics": [{"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs"}], + "DeploymentConfig": { + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" + "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + }, + }, +] + + +INIT_KWARGS = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu" + "-py310-cu121-ubuntu20.04", + "model_data": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface-textgeneration" + "-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "instance_type": "ml.p2.xlarge", + "env": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + 
"SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "role": "arn:aws:iam::312206380606:role/service-role/AmazonSageMaker-ExecutionRole-20230707T131628", + "name": "hf-textgeneration-bloom-1b1-2024-04-22-20-23-48-799", + "enable_network_isolation": True, +} diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index cb7b602fbf..2df904dce2 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -15,6 +15,7 @@ from typing import Optional, Set from unittest import mock import unittest +import pandas as pd from mock import MagicMock, Mock import pytest from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig @@ -47,6 +48,9 @@ get_special_model_spec_for_inference_component_based_endpoint, get_prototype_manifest, get_prototype_model_spec, + get_base_spec_with_prototype_configs, + get_mock_init_kwargs, + get_base_deployment_configs, ) import boto3 @@ -64,6 +68,9 @@ class ModelTest(unittest.TestCase): mock_session_empty_config = MagicMock(sagemaker_config={}) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER") @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @@ -81,6 +88,7 @@ def test_non_prepacked( mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, mock_jumpstart_model_factory_logger: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -140,6 +148,9 @@ def test_non_prepacked( endpoint_logging=False, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -155,6 +166,7 @@ def test_non_prepacked_inference_component_based_endpoint( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -220,6 +232,9 @@ def test_non_prepacked_inference_component_based_endpoint( endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -235,6 +250,7 @@ def test_non_prepacked_inference_component_based_endpoint_no_default_pass_custom mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -295,6 +311,9 @@ def test_non_prepacked_inference_component_based_endpoint_no_default_pass_custom endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -308,6 
+327,7 @@ def test_prepacked( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -354,6 +374,9 @@ def test_prepacked( endpoint_logging=False, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.model.LOGGER.warning") @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.session.Session.endpoint_from_production_variants") @@ -371,6 +394,7 @@ def test_no_compiled_model_warning_log_js_models( mock_endpoint_from_production_variants: mock.Mock, mock_timestamp: mock.Mock, mock_warning: mock.Mock(), + mock_get_jumpstart_configs: mock.Mock, ): mock_timestamp.return_value = "1234" @@ -391,6 +415,9 @@ def test_no_compiled_model_warning_log_js_models( mock_warning.assert_not_called() + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.session.Session.endpoint_from_production_variants") @mock.patch("sagemaker.session.Session.create_model") @@ -406,6 +433,7 @@ def test_eula_gated_conditional_s3_prefix_metadata_model( mock_create_model: mock.Mock, mock_endpoint_from_production_variants: mock.Mock, mock_timestamp: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_timestamp.return_value = "1234" @@ -453,6 +481,7 @@ def test_eula_gated_conditional_s3_prefix_metadata_model( ], ) + @mock.patch("sagemaker.jumpstart.model.get_jumpstart_configs") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @@ -470,7 +499,9 @@ def test_proprietary_model_endpoint( mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): + mock_get_jumpstart_configs.side_effect = lambda *args, **kwargs: {} mock_get_manifest.side_effect = ( lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) @@ -510,6 +541,7 @@ def test_proprietary_model_endpoint( container_startup_health_check_timeout=600, ) + @mock.patch("sagemaker.jumpstart.model.get_jumpstart_configs") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -521,7 +553,9 @@ def test_deprecated( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): + mock_get_jumpstart_configs.side_effect = lambda *args, **kwargs: {} mock_model_deploy.return_value = default_predictor mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -537,6 +571,9 @@ def test_deprecated( JumpStartModel(model_id=model_id, tolerate_deprecated_model=True).deploy() + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -548,6 +585,7 @@ def test_vulnerable( mock_model_init: mock.Mock, 
mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -611,6 +649,9 @@ def test_model_use_kwargs(self): deploy_kwargs=all_deploy_kwargs_used, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.factory.model.environment_variables.retrieve_default") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -626,6 +667,7 @@ def evaluate_model_workflow_with_kwargs( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_retrieve_environment_variables: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, init_kwargs: Optional[dict] = None, deploy_kwargs: Optional[dict] = None, ): @@ -729,6 +771,9 @@ def test_jumpstart_model_kwargs_match_parent_class(self): assert js_class_deploy_args - parent_class_deploy_args == set() assert parent_class_deploy_args - js_class_deploy_args == deploy_args_to_skip + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @@ -737,6 +782,7 @@ def test_validate_model_id_and_get_type( mock_validate_model_id_and_get_type: mock.Mock, mock_init: mock.Mock, mock_get_init_kwargs: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS JumpStartModel(model_id="valid_model_id") @@ -745,6 +791,9 @@ def test_validate_model_id_and_get_type( with pytest.raises(ValueError): JumpStartModel(model_id="invalid_model_id") + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.get_default_predictor") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -760,6 +809,7 @@ def test_no_predictor_returns_default_predictor( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets @@ -793,6 +843,9 @@ def test_no_predictor_returns_default_predictor( self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.get_default_predictor") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -808,6 +861,7 @@ def test_no_predictor_yes_async_inference_config( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets @@ -829,6 +883,9 @@ def test_no_predictor_yes_async_inference_config( mock_get_default_predictor.assert_not_called() + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) 
@mock.patch("sagemaker.jumpstart.model.get_default_predictor") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -844,6 +901,7 @@ def test_yes_predictor_returns_default_predictor( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets @@ -865,6 +923,9 @@ def test_yes_predictor_returns_default_predictor( self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -880,6 +941,7 @@ def test_model_id_not_found_refeshes_cache_inference( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.side_effect = [False, False] @@ -948,6 +1010,9 @@ def test_model_id_not_found_refeshes_cache_inference( ] ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -955,6 +1020,7 @@ def test_jumpstart_model_tags( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -984,6 +1050,9 @@ def test_jumpstart_model_tags( [{"Key": "blah", "Value": "blahagain"}] + js_tags, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -991,6 +1060,7 @@ def test_jumpstart_model_tags_disabled( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -1018,6 +1088,9 @@ def test_jumpstart_model_tags_disabled( [{"Key": "blah", "Value": "blahagain"}], ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1025,6 +1098,7 @@ def test_jumpstart_model_package_arn( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -1052,6 +1126,9 @@ def test_jumpstart_model_package_arn( self.assertIn(tag, 
mock_session.create_model.call_args[1]["tags"]) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1059,6 +1136,7 @@ def test_jumpstart_model_package_arn_override( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -1094,6 +1172,9 @@ def test_jumpstart_model_package_arn_override( }, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1103,6 +1184,7 @@ def test_jumpstart_model_package_arn_unsupported_region( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -1120,6 +1202,9 @@ def test_jumpstart_model_package_arn_unsupported_region( "us-east-2. Please try one of the following regions: us-west-2, us-east-1." ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -1137,6 +1222,7 @@ def test_model_data_s3_prefix_override( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -1186,6 +1272,9 @@ def test_model_data_s3_prefix_override( '"S3DataType": "S3Prefix", "CompressionType": "None"}}', ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1201,6 +1290,7 @@ def test_model_data_s3_prefix_model( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -1230,6 +1320,9 @@ def test_model_data_s3_prefix_model( mock_js_info_logger.assert_not_called() + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1245,6 +1338,7 @@ def test_model_artifact_variant_model( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -1295,6 +1389,9 @@ def 
test_model_artifact_variant_model( enable_network_isolation=True, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1308,6 +1405,7 @@ def test_model_registry_accept_and_response_types( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -1327,6 +1425,9 @@ def test_model_registry_accept_and_response_types( response_types=["application/json;verbose", "application/json"], ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.get_default_predictor") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -1340,6 +1441,7 @@ def test_jumpstart_model_session( mock_deploy, mock_init, get_default_predictor, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = True @@ -1373,6 +1475,9 @@ def test_jumpstart_model_session( assert len(s3_clients) == 1 assert list(s3_clients)[0] == session.s3_client + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch.dict( "sagemaker.jumpstart.cache.os.environ", { @@ -1391,6 +1496,7 @@ def test_model_local_mode( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_get_model_specs.side_effect = get_prototype_model_spec mock_get_manifest.side_effect = ( @@ -1417,6 +1523,9 @@ def test_model_local_mode( endpoint_logging=False, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1428,6 +1537,7 @@ def test_model_initialization_with_config_name( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_get_model_specs.side_effect = get_prototype_spec_with_configs mock_get_manifest.side_effect = ( @@ -1454,6 +1564,9 @@ def test_model_initialization_with_config_name( endpoint_logging=False, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1465,6 +1578,7 @@ def test_model_set_deployment_config( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_get_model_specs.side_effect = get_prototype_model_spec mock_get_manifest.side_effect = ( @@ -1509,6 +1623,9 @@ def test_model_set_deployment_config( endpoint_logging=False, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) 
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1520,6 +1637,7 @@ def test_model_unset_deployment_config( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_get_model_specs.side_effect = get_prototype_spec_with_configs mock_get_manifest.side_effect = ( @@ -1564,6 +1682,173 @@ def test_model_unset_deployment_config( endpoint_logging=False, ) + @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") + @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") + @mock.patch("sagemaker.jumpstart.model.get_instance_rate_per_hour") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_list_deployment_configs( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_instance_rate_per_hour: mock.Mock, + mock_verify_model_region_and_return_specs: mock.Mock, + mock_get_init_kwargs: mock.Mock, + ): + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id) + mock_verify_model_region_and_return_specs.side_effect = ( + lambda *args, **kwargs: get_base_spec_with_prototype_configs() + ) + mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: { + "name": "Instance Rate", + "unit": "USD/Hrs", + "value": "0.0083000000", + } + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + configs = model.list_deployment_configs() + + self.assertEqual(configs, get_base_deployment_configs()) + + @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") + @mock.patch("sagemaker.jumpstart.model.get_instance_rate_per_hour") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_list_deployment_configs_empty( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_instance_rate_per_hour: mock.Mock, + mock_verify_model_region_and_return_specs: mock.Mock, + ): + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_verify_model_region_and_return_specs.side_effect = ( + lambda *args, **kwargs: get_special_model_spec(model_id="gemma-model") + ) + mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: { + "name": "Instance Rate", + "unit": "USD/Hrs", + "value": "0.0083000000", + } 
+ mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + configs = model.list_deployment_configs() + + self.assertTrue(len(configs) == 0) + + @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") + @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") + @mock.patch("sagemaker.jumpstart.model.get_instance_rate_per_hour") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_display_benchmark_metrics( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_instance_rate_per_hour: mock.Mock, + mock_verify_model_region_and_return_specs: mock.Mock, + mock_get_init_kwargs: mock.Mock, + ): + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id) + mock_verify_model_region_and_return_specs.side_effect = ( + lambda *args, **kwargs: get_base_spec_with_prototype_configs() + ) + mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: { + "name": "Instance Rate", + "unit": "USD/Hrs", + "value": "0.0083000000", + } + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + model.display_benchmark_metrics() + + @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") + @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") + @mock.patch("sagemaker.jumpstart.model.get_instance_rate_per_hour") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_benchmark_metrics( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_instance_rate_per_hour: mock.Mock, + mock_verify_model_region_and_return_specs: mock.Mock, + mock_get_init_kwargs: mock.Mock, + ): + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id) + mock_verify_model_region_and_return_specs.side_effect = ( + lambda *args, **kwargs: get_base_spec_with_prototype_configs() + ) + mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: { + "name": "Instance Rate", + "unit": "USD/Hrs", + "value": "0.0083000000", + } + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + 
mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + df = model.benchmark_metrics + + self.assertTrue(isinstance(df, pd.DataFrame)) + def test_jumpstart_model_requires_model_id(): with pytest.raises(ValueError): diff --git a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py index 70409704e6..2be4bde7e4 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py +++ b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py @@ -60,6 +60,9 @@ class IntelligentDefaultsModelTest(unittest.TestCase): region = "us-west-2" sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -77,6 +80,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -101,6 +105,9 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( assert "enable_network_isolation" not in mock_model_init.call_args[1] + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -118,6 +125,7 @@ def test_all_arg_overwrites_without_kwarg_collisions_with_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -147,6 +155,9 @@ def test_all_arg_overwrites_without_kwarg_collisions_with_config( override_enable_network_isolation, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -164,6 +175,7 @@ def test_without_arg_overwrites_all_kwarg_collisions_with_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -193,6 +205,9 @@ def test_without_arg_overwrites_all_kwarg_collisions_with_config( config_enable_network_isolation, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -210,6 +225,7 
@@ def test_with_arg_overwrites_all_kwarg_collisions_with_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -241,6 +257,9 @@ def test_with_arg_overwrites_all_kwarg_collisions_with_config( override_enable_network_isolation, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -258,6 +277,7 @@ def test_without_arg_overwrites_all_kwarg_collisions_without_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -287,6 +307,9 @@ def test_without_arg_overwrites_all_kwarg_collisions_without_config( metadata_enable_network_isolation, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -304,6 +327,7 @@ def test_with_arg_overwrites_all_kwarg_collisions_without_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -334,6 +358,9 @@ def test_with_arg_overwrites_all_kwarg_collisions_without_config( override_enable_network_isolation, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -351,6 +378,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -375,6 +403,9 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( self.assertEquals(mock_model_init.call_args[1].get("role"), execution_role) assert "enable_network_isolation" not in mock_model_init.call_args[1] + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -392,6 +423,7 @@ def test_with_arg_overwrites_without_kwarg_collisions_without_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py 
b/tests/unit/sagemaker/jumpstart/test_predictor.py index 52f28f2da1..abce2bf687 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -159,6 +159,9 @@ def test_jumpstart_predictor_support_no_model_id_supplied_sad_case( patched_get_default_predictor.assert_not_called() +@mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} +) @patch("sagemaker.predictor.get_model_id_version_from_endpoint") @patch("sagemaker.jumpstart.payload_utils.JumpStartS3PayloadAccessor.get_object_cached") @patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @@ -170,6 +173,7 @@ def test_jumpstart_serializable_payload_with_predictor( patched_validate_model_id_and_get_type, patched_get_object_cached, patched_get_model_id_version_from_endpoint, + patched_get_jumpstart_configs, ): patched_get_object_cached.return_value = base64.b64decode("encodedimage") diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index c1ea8abcb8..85911a2854 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -49,6 +49,7 @@ get_spec_from_base_spec, get_special_model_spec, get_prototype_manifest, + get_base_deployment_configs, ) from mock import MagicMock @@ -1708,3 +1709,52 @@ def test_get_jumpstart_benchmark_stats_training( ] }, } + + +@pytest.mark.parametrize( + "config_name, expected", + [ + ( + None, + { + "Config Name": [ + "neuron-inference", + "neuron-inference-budget", + "gpu-inference-budget", + "gpu-inference", + ], + "Instance Type": ["ml.p2.xlarge", "ml.p2.xlarge", "ml.p2.xlarge", "ml.p2.xlarge"], + "Selected": ["No", "No", "No", "No"], + "Instance Rate (USD/Hrs)": [ + "0.0083000000", + "0.0083000000", + "0.0083000000", + "0.0083000000", + ], + }, + ), + ( + "neuron-inference", + { + "Config Name": [ + "neuron-inference", + "neuron-inference-budget", + "gpu-inference-budget", + "gpu-inference", + ], + "Instance Type": ["ml.p2.xlarge", "ml.p2.xlarge", "ml.p2.xlarge", "ml.p2.xlarge"], + "Selected": ["Yes", "No", "No", "No"], + "Instance Rate (USD/Hrs)": [ + "0.0083000000", + "0.0083000000", + "0.0083000000", + "0.0083000000", + ], + }, + ), + ], +) +def test_extract_metrics_from_deployment_configs(config_name, expected): + data = utils.extract_metrics_from_deployment_configs(get_base_deployment_configs(), config_name) + + assert data == expected diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index aee1497ec9..e0d6f645a8 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -12,9 +12,10 @@ # language governing permissions and limitations under the License. 
from __future__ import absolute_import import copy -from typing import List +from typing import List, Dict, Any import boto3 +from sagemaker.compute_resource_requirements import ResourceRequirements from sagemaker.jumpstart.cache import JumpStartModelsCache from sagemaker.jumpstart.constants import ( JUMPSTART_DEFAULT_REGION_NAME, @@ -27,6 +28,7 @@ JumpStartModelSpecs, JumpStartS3FileType, JumpStartModelHeader, + JumpStartModelInitKwargs, ) from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.utils import get_formatted_manifest @@ -43,6 +45,8 @@ SPECIAL_MODEL_SPECS_DICT, TRAINING_CONFIG_RANKINGS, TRAINING_CONFIGS, + DEPLOYMENT_CONFIGS, + INIT_KWARGS, ) @@ -297,3 +301,19 @@ def overwrite_dictionary( base_dictionary[key] = value return base_dictionary + + +def get_base_deployment_configs() -> List[Dict[str, Any]]: + return DEPLOYMENT_CONFIGS + + +def get_mock_init_kwargs(model_id) -> JumpStartModelInitKwargs: + return JumpStartModelInitKwargs( + model_id=model_id, + model_type=JumpStartModelType.OPEN_WEIGHTS, + model_data=INIT_KWARGS.get("model_data"), + image_uri=INIT_KWARGS.get("image_uri"), + instance_type=INIT_KWARGS.get("instance_type"), + env=INIT_KWARGS.get("env"), + resources=ResourceRequirements(), + ) diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py index 2a0c791215..3d5148772e 100644 --- a/tests/unit/sagemaker/serve/builder/test_js_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py @@ -23,6 +23,7 @@ LocalModelOutOfMemoryException, LocalModelInvocationException, ) +from tests.unit.sagemaker.serve.constants import DEPLOYMENT_CONFIGS mock_model_id = "huggingface-llm-amazon-falconlite" mock_t5_model_id = "google/flan-t5-xxl" @@ -638,3 +639,87 @@ def test_js_gated_model_ex( ValueError, lambda: builder.build(), ) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test_list_deployment_configs( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/galactica-mock-model-id", + schema_builder=mock_schema_builder, + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + mock_pre_trained_model.return_value.list_deployment_configs.side_effect = ( + lambda: DEPLOYMENT_CONFIGS + ) + + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + + configs = model.list_deployment_configs() + + self.assertEqual(configs, DEPLOYMENT_CONFIGS) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + 
@patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test_display_benchmark_metrics( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/galactica-mock-model-id", + schema_builder=mock_schema_builder, + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + mock_pre_trained_model.return_value.display_benchmark_metrics.side_effect = ( + lambda *args, **kwargs: "metric data" + ) + + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + + model.display_benchmark_metrics() diff --git a/tests/unit/sagemaker/serve/constants.py b/tests/unit/sagemaker/serve/constants.py index db9dd623d8..5c40c1bf64 100644 --- a/tests/unit/sagemaker/serve/constants.py +++ b/tests/unit/sagemaker/serve/constants.py @@ -15,3 +15,153 @@ MOCK_IMAGE_CONFIG = {"RepositoryAccessMode": "Vpc"} MOCK_VPC_CONFIG = {"Subnets": ["subnet-1234"], "SecurityGroupIds": ["sg123"]} +DEPLOYMENT_CONFIGS = [ + { + "ConfigName": "neuron-inference", + "BenchmarkMetrics": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S"}, + {"name": "Throughput", "value": "1867", "unit": "Tokens/S"}, + ], + "DeploymentConfig": { + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration" + "-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "MAX_INPUT_LENGTH": "4095", + "MAX_TOTAL_TOKENS": "4096", + "SM_NUM_GPUS": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": { + "MinMemoryRequiredInMb": 16384, + "NumberOfAcceleratorDevicesRequired": 1, + }, + }, + }, + { + "ConfigName": "neuron-inference-budget", + "BenchmarkMetrics": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S"}, + {"name": "Throughput", "value": "1867", "unit": "Tokens/S"}, + ], + "DeploymentConfig": { + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration" + "-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "MAX_INPUT_LENGTH": "4095", + "MAX_TOTAL_TOKENS": "4096", + "SM_NUM_GPUS": "1", + 
"SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": { + "MinMemoryRequiredInMb": 16384, + "NumberOfAcceleratorDevicesRequired": 1, + }, + }, + }, + { + "ConfigName": "gpu-inference-budget", + "BenchmarkMetrics": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S"}, + {"name": "Throughput", "value": "1867", "unit": "Tokens/S"}, + ], + "DeploymentConfig": { + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration" + "-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "MAX_INPUT_LENGTH": "4095", + "MAX_TOTAL_TOKENS": "4096", + "SM_NUM_GPUS": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": { + "MinMemoryRequiredInMb": 16384, + "NumberOfAcceleratorDevicesRequired": 1, + }, + }, + }, + { + "ConfigName": "gpu-inference", + "BenchmarkMetrics": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S"}, + {"name": "Throughput", "value": "1867", "unit": "Tokens/S"}, + ], + "DeploymentConfig": { + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration" + "-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "MAX_INPUT_LENGTH": "4095", + "MAX_TOTAL_TOKENS": "4096", + "SM_NUM_GPUS": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": { + "MinMemoryRequiredInMb": 16384, + "NumberOfAcceleratorDevicesRequired": 1, + }, + }, + }, +] diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 81d8279e6d..bf6a7cb09f 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -50,6 +50,8 @@ _is_bad_link, custom_extractall_tarfile, can_model_package_source_uri_autopopulate, + get_instance_rate_per_hour, + extract_instance_rate_per_hour, ) from tests.unit.sagemaker.workflow.helpers import CustomStep from sagemaker.workflow.parameters import ParameterString, ParameterInteger @@ -1866,3 +1868,77 @@ def test_deep_override_skip_keys(self): expected_result = {"a": 1, "b": {"x": 20, "y": 3, "z": 30}, "c": [4, 5]} self.assertEqual(deep_override_dict(dict1, dict2, skip_keys=["c", "d"]), expected_result) + + +@pytest.mark.parametrize( + "instance, region", + [ + ("t4g.nano", "us-west-2"), + ("t4g.nano", "eu-central-1"), + ("t4g.nano", "af-south-1"), + ("t4g.nano", "ap-northeast-2"), + ("t4g.nano", "cn-north-1"), + ], +) +@patch("boto3.client") +def test_get_instance_rate_per_hour(mock_client, instance, region): + 
amazon_sagemaker_price_result = { + "PriceList": [ + '{"terms": {"OnDemand": {"22VNQ3N6GZGZMXYM.JRTCKXETXF": {"priceDimensions":{' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7": {"unit": "Hrs", "endRange": "Inf", "description": "$0.0083 per ' + "On" + 'Demand Ubuntu Pro t4g.nano Instance Hour", "appliesTo": [], "rateCode": ' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7", "beginRange": "0", "pricePerUnit":{"USD": "0.0083000000"}}}, ' + '"sku": "22VNQ3N6GZGZMXYM", "effectiveDate": "2024-04-01T00:00:00Z", "offerTermCode": "JRTCKXETXF", ' + '"termAttributes": {}}}}}' + ] + } + + mock_client.return_value.get_products.side_effect = ( + lambda *args, **kwargs: amazon_sagemaker_price_result + ) + instance_rate = get_instance_rate_per_hour(instance_type=instance, region=region) + + assert instance_rate == {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.0083000000"} + + +@patch("boto3.client") +def test_get_instance_rate_per_hour_ex(mock_client): + mock_client.return_value.get_products.side_effect = lambda *args, **kwargs: Exception() + instance_rate = get_instance_rate_per_hour(instance_type="ml.t4g.nano", region="us-west-2") + + assert instance_rate is None + + +@pytest.mark.parametrize( + "price_data, expected_result", + [ + (None, None), + ( + { + "terms": { + "OnDemand": { + "3WK7G7WSYVS3K492.JRTCKXETXF": { + "priceDimensions": { + "3WK7G7WSYVS3K492.JRTCKXETXF.6YS6EN2CT7": { + "unit": "Hrs", + "endRange": "Inf", + "description": "$0.9 per Unused Reservation Linux p2.xlarge Instance Hour", + "appliesTo": [], + "rateCode": "3WK7G7WSYVS3K492.JRTCKXETXF.6YS6EN2CT7", + "beginRange": "0", + "pricePerUnit": {"USD": "0.9000000000"}, + } + } + } + } + } + }, + {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.9000000000"}, + ), + ], +) +def test_extract_instance_rate_per_hour(price_data, expected_result): + out = extract_instance_rate_per_hour(price_data) + + assert out == expected_result From eead6a0966e03217c681f7306cf9e747fabbfc6c Mon Sep 17 00:00:00 2001 From: Haotian An <33510317+Captainia@users.noreply.github.com> Date: Thu, 25 Apr 2024 11:51:10 -0400 Subject: [PATCH 19/32] feat: tag JumpStart resource with config names (#4608) * tag config name * format * resolving comments * format * format * update * fix * format * updates inference component config name * fix: tests --- src/sagemaker/jumpstart/enums.py | 1 + src/sagemaker/jumpstart/estimator.py | 9 +- src/sagemaker/jumpstart/factory/estimator.py | 5 +- src/sagemaker/jumpstart/factory/model.py | 2 +- src/sagemaker/jumpstart/session_utils.py | 56 +++--- src/sagemaker/jumpstart/types.py | 22 ++- src/sagemaker/jumpstart/utils.py | 95 ++++++---- src/sagemaker/predictor.py | 10 +- .../jumpstart/estimator/test_estimator.py | 19 +- .../sagemaker/jumpstart/model/test_model.py | 3 + .../sagemaker/jumpstart/test_predictor.py | 20 +- .../sagemaker/jumpstart/test_session_utils.py | 173 ++++++++++++------ tests/unit/sagemaker/jumpstart/test_utils.py | 65 ++++++- 13 files changed, 315 insertions(+), 165 deletions(-) diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index ca49fd41a3..0c192084ec 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -92,6 +92,7 @@ class JumpStartTag(str, Enum): MODEL_ID = "sagemaker-sdk:jumpstart-model-id" MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version" MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type" + MODEL_CONFIG_NAME = "sagemaker-sdk:jumpstart-model-config-name" class SerializerType(str, Enum): diff --git 
a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index cf9b720607..33bb73a83c 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -33,7 +33,7 @@ from sagemaker.jumpstart.factory.estimator import get_deploy_kwargs, get_fit_kwargs, get_init_kwargs from sagemaker.jumpstart.factory.model import get_default_predictor -from sagemaker.jumpstart.session_utils import get_model_id_version_from_training_job +from sagemaker.jumpstart.session_utils import get_model_info_from_training_job from sagemaker.jumpstart.types import JumpStartMetadataConfig from sagemaker.jumpstart.utils import ( get_jumpstart_configs, @@ -730,10 +730,10 @@ def attach( ValueError: if the model ID or version cannot be inferred from the training job. """ - + config_name = None if model_id is None: - model_id, model_version = get_model_id_version_from_training_job( + model_id, model_version, config_name = get_model_info_from_training_job( training_job_name=training_job_name, sagemaker_session=sagemaker_session ) @@ -749,6 +749,7 @@ def attach( tolerate_deprecated_model=True, # model is already trained, so tolerate if deprecated tolerate_vulnerable_model=True, # model is already trained, so tolerate if vulnerable sagemaker_session=sagemaker_session, + config_name=config_name, ) # eula was already accepted if the model was successfully trained @@ -1102,7 +1103,7 @@ def deploy( tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, sagemaker_session=self.sagemaker_session, - # config_name=self.config_name, + config_name=self.config_name, ) # If a predictor class was passed, do not mutate predictor diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 926f313b68..2d5b29b52f 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -478,7 +478,10 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima if kwargs.sagemaker_session.settings.include_jumpstart_tags: kwargs.tags = add_jumpstart_model_id_version_tags( - kwargs.tags, kwargs.model_id, full_model_version + kwargs.tags, + kwargs.model_id, + full_model_version, + config_name=kwargs.config_name, ) return kwargs diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 25a1d63215..b4f6d70583 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -496,7 +496,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: if kwargs.sagemaker_session.settings.include_jumpstart_tags: kwargs.tags = add_jumpstart_model_id_version_tags( - kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type + kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type, kwargs.config_name ) return kwargs diff --git a/src/sagemaker/jumpstart/session_utils.py b/src/sagemaker/jumpstart/session_utils.py index e511a052d1..0fa7722f91 100644 --- a/src/sagemaker/jumpstart/session_utils.py +++ b/src/sagemaker/jumpstart/session_utils.py @@ -22,12 +22,12 @@ from sagemaker.utils import aws_partition -def get_model_id_version_from_endpoint( +def get_model_info_from_endpoint( endpoint_name: str, inference_component_name: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Tuple[str, str, Optional[str]]: - """Given an endpoint and optionally inference component names, return the 
model ID and version.
+) -> Tuple[str, str, Optional[str], Optional[str]]:
+    """Return the model ID, version, and config name inferred from an endpoint.
 
     Infers the model ID and version based on the resource tags. Returns a tuple of the model ID
     and version. A third string element is included in the tuple for any inferred inference
@@ -46,7 +46,8 @@
         (
             model_id,
             model_version,
-        ) = _get_model_id_version_from_inference_component_endpoint_with_inference_component_name(  # noqa E501 # pylint: disable=c0301
+            config_name,
+        ) = _get_model_info_from_inference_component_endpoint_with_inference_component_name(  # noqa E501 # pylint: disable=c0301
             inference_component_name, sagemaker_session
         )
 
@@ -54,22 +55,23 @@
         (
             model_id,
             model_version,
+            config_name,
             inference_component_name,
-        ) = _get_model_id_version_from_inference_component_endpoint_without_inference_component_name(  # noqa E501 # pylint: disable=c0301
+        ) = _get_model_info_from_inference_component_endpoint_without_inference_component_name(  # noqa E501 # pylint: disable=c0301
             endpoint_name, sagemaker_session
         )
     else:
-        model_id, model_version = _get_model_id_version_from_model_based_endpoint(
+        model_id, model_version, config_name = _get_model_info_from_model_based_endpoint(
             endpoint_name, inference_component_name, sagemaker_session
         )
 
-    return model_id, model_version, inference_component_name
+    return model_id, model_version, inference_component_name, config_name
 
 
-def _get_model_id_version_from_inference_component_endpoint_without_inference_component_name(
+def _get_model_info_from_inference_component_endpoint_without_inference_component_name(
     endpoint_name: str, sagemaker_session: Session
-) -> Tuple[str, str, str]:
-    """Given an endpoint name, derives the model ID, version, and inferred inference component name.
+) -> Tuple[str, str, str, str]:
+    """Derives the model ID, version, config name, and inferred inference component name.
 
     This function assumes the endpoint corresponds to an inference-component-based endpoint.
     An endpoint is inference-component-based if and only if the associated endpoint config
@@ -98,14 +100,14 @@
     )
     inference_component_name = inference_component_names[0]
     return (
-        *_get_model_id_version_from_inference_component_endpoint_with_inference_component_name(
+        *_get_model_info_from_inference_component_endpoint_with_inference_component_name(
             inference_component_name, sagemaker_session
         ),
         inference_component_name,
     )
 
 
-def _get_model_id_version_from_inference_component_endpoint_with_inference_component_name(
+def _get_model_info_from_inference_component_endpoint_with_inference_component_name(
     inference_component_name: str, sagemaker_session: Session
 ):
     """Returns the model ID and version inferred from a SageMaker inference component.
@@ -123,7 +125,7 @@
         f"inference-component/{inference_component_name}"
     )
 
-    model_id, model_version = get_jumpstart_model_id_version_from_resource_arn(
+    model_id, model_version, config_name = get_jumpstart_model_id_version_from_resource_arn(
         inference_component_arn, sagemaker_session
     )
 
         "when retrieving default predictor for this inference component."
    )
 
-    return model_id, model_version
+    return model_id, model_version, config_name
 
 
-def _get_model_id_version_from_model_based_endpoint(
+def _get_model_info_from_model_based_endpoint(
     endpoint_name: str,
     inference_component_name: Optional[str],
     sagemaker_session: Session,
-) -> Tuple[str, str]:
-    """Returns the model ID and version inferred from a model-based endpoint.
+) -> Tuple[str, str, Optional[str]]:
+    """Returns the model ID, version, and config name inferred from a model-based endpoint.
 
     Raises:
         ValueError: If an inference component name is supplied, or if the endpoint does
@@ -161,7 +163,7 @@
 
     endpoint_arn = f"arn:{partition}:sagemaker:{region}:{account_id}:endpoint/{endpoint_name}"
 
-    model_id, model_version = get_jumpstart_model_id_version_from_resource_arn(
+    model_id, model_version, config_name = get_jumpstart_model_id_version_from_resource_arn(
         endpoint_arn, sagemaker_session
     )
 
@@ -172,14 +174,14 @@
         "predictor for this endpoint."
     )
 
-    return model_id, model_version
+    return model_id, model_version, config_name
 
 
-def get_model_id_version_from_training_job(
+def get_model_info_from_training_job(
     training_job_name: str,
     sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
-) -> Tuple[str, str]:
-    """Returns the model ID and version inferred from a training job.
+) -> Tuple[str, str, Optional[str]]:
+    """Returns the model ID, version, and config name inferred from a training job.
 
     Raises:
         ValueError: If the training job does not have tags from which the model ID
@@ -194,9 +196,11 @@
         f"arn:{partition}:sagemaker:{region}:{account_id}:training-job/{training_job_name}"
     )
 
-    model_id, inferred_model_version = get_jumpstart_model_id_version_from_resource_arn(
-        training_job_arn, sagemaker_session
-    )
+    (
+        model_id,
+        inferred_model_version,
+        config_name,
+    ) = get_jumpstart_model_id_version_from_resource_arn(training_job_arn, sagemaker_session)
 
     model_version = inferred_model_version or None
 
@@ -207,4 +211,4 @@
         "for this training job."
     )
 
-    return model_id, model_version
+    return model_id, model_version, config_name
diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py
index 07bd769054..bf0a84319b 100644
--- a/src/sagemaker/jumpstart/types.py
+++ b/src/sagemaker/jumpstart/types.py
@@ -1064,9 +1064,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
             Dictionary representation of the config component.
         """
         for field in json_obj.keys():
-            if field not in self.__slots__:
-                raise ValueError(f"Invalid component field: {field}")
-            setattr(self, field, json_obj[field])
+            if field in self.__slots__:
+                setattr(self, field, json_obj[field])
 
 
 class JumpStartMetadataConfig(JumpStartDataHolderType):
@@ -1164,6 +1163,8 @@ def get_top_config_from_ranking(
     ) -> Optional[JumpStartMetadataConfig]:
         """Gets the best config based on config ranking.
 
+        Falls back to the ordering of config names if
+        a ranking is not available.
         Args:
             ranking_name (str):
                 The ranking name that config priority is based on.
@@ -1171,13 +1172,8 @@
                 The instance type which the config selection is based on.
 
         Raises:
-            ValueError: If the config exists but missing config ranking.
             NotImplementedError: If the scope is unrecognized.
""" - if self.configs and ( - not self.config_rankings or not self.config_rankings.get(ranking_name) - ): - raise ValueError(f"Config exists but missing config ranking {ranking_name}.") if self.scope == JumpStartScriptScope.INFERENCE: instance_type_attribute = "supported_inference_instance_types" @@ -1186,8 +1182,14 @@ def get_top_config_from_ranking( else: raise NotImplementedError(f"Unknown script scope {self.scope}") - rankings = self.config_rankings.get(ranking_name) - for config_name in rankings.rankings: + if self.configs and ( + not self.config_rankings or not self.config_rankings.get(ranking_name) + ): + ranked_config_names = sorted(list(self.configs.keys())) + else: + rankings = self.config_rankings.get(ranking_name) + ranked_config_names = rankings.rankings + for config_name in ranked_config_names: resolved_config = self.configs[config_name].resolved_config if instance_type and instance_type not in getattr( resolved_config, instance_type_attribute diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 905f2a18d5..59bf11b415 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -318,6 +318,7 @@ def add_single_jumpstart_tag( tag_key_in_array(enums.JumpStartTag.MODEL_ID, curr_tags) or tag_key_in_array(enums.JumpStartTag.MODEL_VERSION, curr_tags) or tag_key_in_array(enums.JumpStartTag.MODEL_TYPE, curr_tags) + or tag_key_in_array(enums.JumpStartTag.MODEL_CONFIG_NAME, curr_tags) ) if is_uri else False @@ -353,6 +354,7 @@ def add_jumpstart_model_id_version_tags( model_id: str, model_version: str, model_type: Optional[enums.JumpStartModelType] = None, + config_name: Optional[str] = None, ) -> List[TagsDict]: """Add custom model ID and version tags to JumpStart related resources.""" if model_id is None or model_version is None: @@ -376,6 +378,13 @@ def add_jumpstart_model_id_version_tags( tags, is_uri=False, ) + if config_name: + tags = add_single_jumpstart_tag( + config_name, + enums.JumpStartTag.MODEL_CONFIG_NAME, + tags, + is_uri=False, + ) return tags @@ -800,52 +809,72 @@ def validate_model_id_and_get_type( return None +def _extract_value_from_list_of_tags( + tag_keys: List[str], + list_tags_result: List[str], + resource_name: str, + resource_arn: str, +): + """Extracts value from list of tags with check of duplicate tags. + + Returns None if no value is found. + """ + resolved_value = None + for tag_key in tag_keys: + try: + value_from_tag = get_tag_value(tag_key, list_tags_result) + except KeyError: + continue + if value_from_tag is not None: + if resolved_value is not None and value_from_tag != resolved_value: + constants.JUMPSTART_LOGGER.warning( + "Found multiple %s tags on the following resource: %s", + resource_name, + resource_arn, + ) + resolved_value = None + break + resolved_value = value_from_tag + return resolved_value + + def get_jumpstart_model_id_version_from_resource_arn( resource_arn: str, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Tuple[Optional[str], Optional[str]]: - """Returns the JumpStart model ID and version if in resource tags. +) -> Tuple[Optional[str], Optional[str], Optional[str]]: + """Returns the JumpStart model ID, version and config name if in resource tags. - Returns 'None' if model ID or version cannot be inferred from tags. + Returns 'None' if model ID or version or config name cannot be inferred from tags. 
""" list_tags_result = sagemaker_session.list_tags(resource_arn) - model_id: Optional[str] = None - model_version: Optional[str] = None - model_id_keys = [enums.JumpStartTag.MODEL_ID, *constants.EXTRA_MODEL_ID_TAGS] model_version_keys = [enums.JumpStartTag.MODEL_VERSION, *constants.EXTRA_MODEL_VERSION_TAGS] + model_config_name_keys = [enums.JumpStartTag.MODEL_CONFIG_NAME] - for model_id_key in model_id_keys: - try: - model_id_from_tag = get_tag_value(model_id_key, list_tags_result) - except KeyError: - continue - if model_id_from_tag is not None: - if model_id is not None and model_id_from_tag != model_id: - constants.JUMPSTART_LOGGER.warning( - "Found multiple model ID tags on the following resource: %s", resource_arn - ) - model_id = None - break - model_id = model_id_from_tag + model_id: Optional[str] = _extract_value_from_list_of_tags( + tag_keys=model_id_keys, + list_tags_result=list_tags_result, + resource_name="model ID", + resource_arn=resource_arn, + ) - for model_version_key in model_version_keys: - try: - model_version_from_tag = get_tag_value(model_version_key, list_tags_result) - except KeyError: - continue - if model_version_from_tag is not None: - if model_version is not None and model_version_from_tag != model_version: - constants.JUMPSTART_LOGGER.warning( - "Found multiple model version tags on the following resource: %s", resource_arn - ) - model_version = None - break - model_version = model_version_from_tag + model_version: Optional[str] = _extract_value_from_list_of_tags( + tag_keys=model_version_keys, + list_tags_result=list_tags_result, + resource_name="model version", + resource_arn=resource_arn, + ) + + config_name: Optional[str] = _extract_value_from_list_of_tags( + tag_keys=model_config_name_keys, + list_tags_result=list_tags_result, + resource_name="model config name", + resource_arn=resource_arn, + ) - return model_id, model_version + return model_id, model_version, config_name def get_region_fallback( diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 6f846bba65..14e2ae278b 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -18,7 +18,7 @@ from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.factory.model import get_default_predictor -from sagemaker.jumpstart.session_utils import get_model_id_version_from_endpoint +from sagemaker.jumpstart.session_utils import get_model_info_from_endpoint from sagemaker.session import Session @@ -78,9 +78,8 @@ def retrieve_default( inferred_model_id, inferred_model_version, inferred_inference_component_name, - ) = get_model_id_version_from_endpoint( - endpoint_name, inference_component_name, sagemaker_session - ) + inferred_config_name, + ) = get_model_info_from_endpoint(endpoint_name, inference_component_name, sagemaker_session) if not inferred_model_id: raise ValueError( @@ -92,8 +91,10 @@ def retrieve_default( model_id = inferred_model_id model_version = model_version or inferred_model_version or "*" inference_component_name = inference_component_name or inferred_inference_component_name + config_name = inferred_config_name or None else: model_version = model_version or "*" + config_name = None predictor = Predictor( endpoint_name=endpoint_name, @@ -110,4 +111,5 @@ def retrieve_default( tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py 
b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index f07bb44ba1..be0e06472c 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1018,7 +1018,7 @@ def test_jumpstart_estimator_attach_eula_model( ) @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") - @mock.patch("sagemaker.jumpstart.estimator.get_model_id_version_from_training_job") + @mock.patch("sagemaker.jumpstart.estimator.get_model_info_from_training_job") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1026,15 +1026,16 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, - get_model_id_version_from_training_job: mock.Mock, + get_model_info_from_training_job: mock.Mock, mock_attach: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS - get_model_id_version_from_training_job.return_value = ( + get_model_info_from_training_job.return_value = ( "js-trainable-model-prepacked", "1.0.0", + None, ) mock_get_model_specs.side_effect = get_special_model_spec @@ -1045,7 +1046,7 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( training_job_name="some-training-job-name", sagemaker_session=mock_session ) - get_model_id_version_from_training_job.assert_called_once_with( + get_model_info_from_training_job.assert_called_once_with( training_job_name="some-training-job-name", sagemaker_session=mock_session, ) @@ -1063,7 +1064,7 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( ) @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") - @mock.patch("sagemaker.jumpstart.estimator.get_model_id_version_from_training_job") + @mock.patch("sagemaker.jumpstart.estimator.get_model_info_from_training_job") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1071,13 +1072,13 @@ def test_jumpstart_estimator_attach_no_model_id_sad_case( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, - get_model_id_version_from_training_job: mock.Mock, + get_model_info_from_training_job: mock.Mock, mock_attach: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS - get_model_id_version_from_training_job.side_effect = ValueError() + get_model_info_from_training_job.side_effect = ValueError() mock_get_model_specs.side_effect = get_special_model_spec @@ -1088,7 +1089,7 @@ def test_jumpstart_estimator_attach_no_model_id_sad_case( training_job_name="some-training-job-name", sagemaker_session=mock_session ) - get_model_id_version_from_training_job.assert_called_once_with( + get_model_info_from_training_job.assert_called_once_with( training_job_name="some-training-job-name", sagemaker_session=mock_session, ) @@ -1216,6 +1217,7 @@ def test_no_predictor_returns_default_predictor( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=estimator.sagemaker_session, + config_name=None, ) self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, 
default_predictor_with_presets) @@ -1898,6 +1900,7 @@ def test_estimator_initialization_with_config_name( tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-training"}, ], enable_network_isolation=False, ) diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 2df904dce2..476002457b 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -1559,6 +1559,7 @@ def test_model_initialization_with_config_name( tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-inference"}, ], wait=True, endpoint_logging=False, @@ -1618,6 +1619,7 @@ def test_model_set_deployment_config( tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-inference"}, ], wait=True, endpoint_logging=False, @@ -1659,6 +1661,7 @@ def test_model_unset_deployment_config( tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-inference"}, ], wait=True, endpoint_logging=False, diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index abce2bf687..1cc8f292f0 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -18,7 +18,7 @@ from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec, get_spec_from_base_spec -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_predictor_support( @@ -52,7 +52,7 @@ def test_jumpstart_predictor_support( assert js_predictor.accept == MIMEType.JSON -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_proprietary_predictor_support( @@ -91,7 +91,7 @@ def test_proprietary_predictor_support( @patch("sagemaker.predictor.Predictor") @patch("sagemaker.predictor.get_default_predictor") -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( @@ -109,6 +109,7 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( "predictor-specs-model", "1.2.3", None, + None, ) mock_session = Mock() @@ -128,11 +129,12 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( tolerate_vulnerable_model=False, sagemaker_session=mock_session, model_type=JumpStartModelType.OPEN_WEIGHTS, + config_name=None, ) 
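For context on the predictor tests below: once an endpoint carries the JumpStart tags, retrieve_default can infer the model ID, version, inference component, and config name without any of them being supplied. A hedged usage sketch; the endpoint name is a placeholder, and the endpoint is assumed to have been created by JumpStart with tagging enabled:

    from sagemaker import predictor

    # Model metadata is resolved from the endpoint's resource tags.
    js_predictor = predictor.retrieve_default(endpoint_name="my-jumpstart-endpoint")
    print(js_predictor.content_type)
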
@patch("sagemaker.predictor.get_default_predictor") -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_predictor_support_no_model_id_supplied_sad_case( @@ -159,10 +161,8 @@ def test_jumpstart_predictor_support_no_model_id_supplied_sad_case( patched_get_default_predictor.assert_not_called() -@mock.patch( - "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} -) -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}) +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.payload_utils.JumpStartS3PayloadAccessor.get_object_cached") @patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @@ -172,7 +172,7 @@ def test_jumpstart_serializable_payload_with_predictor( patched_verify_model_region_and_return_specs, patched_validate_model_id_and_get_type, patched_get_object_cached, - patched_get_model_id_version_from_endpoint, + patched_get_model_info_from_endpoint, patched_get_jumpstart_configs, ): @@ -183,7 +183,7 @@ def test_jumpstart_serializable_payload_with_predictor( patched_get_model_specs.side_effect = get_special_model_spec model_id, model_version = "default_payloads", "*" - patched_get_model_id_version_from_endpoint.return_value = model_id, model_version, None + patched_get_model_info_from_endpoint.return_value = model_id, model_version, None js_predictor = predictor.retrieve_default( endpoint_name="blah", model_id=model_id, model_version=model_version diff --git a/tests/unit/sagemaker/jumpstart/test_session_utils.py b/tests/unit/sagemaker/jumpstart/test_session_utils.py index 76ad50f31c..9dc8acb32a 100644 --- a/tests/unit/sagemaker/jumpstart/test_session_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_session_utils.py @@ -4,16 +4,16 @@ import pytest from sagemaker.jumpstart.session_utils import ( - _get_model_id_version_from_inference_component_endpoint_with_inference_component_name, - _get_model_id_version_from_inference_component_endpoint_without_inference_component_name, - _get_model_id_version_from_model_based_endpoint, - get_model_id_version_from_endpoint, - get_model_id_version_from_training_job, + _get_model_info_from_inference_component_endpoint_with_inference_component_name, + _get_model_info_from_inference_component_endpoint_without_inference_component_name, + _get_model_info_from_model_based_endpoint, + get_model_info_from_endpoint, + get_model_info_from_training_job, ) @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_training_job_happy_case( +def test_get_model_info_from_training_job_happy_case( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -23,11 +23,35 @@ def test_get_model_id_version_from_training_job_happy_case( mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( "model_id", "model_version", + None, + ) + + retval = get_model_info_from_training_job("bLaH", sagemaker_session=mock_sm_session) + + assert retval == ("model_id", "model_version", None) + + mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( + 
"arn:aws:sagemaker:us-west-2:123456789012:training-job/bLaH", mock_sm_session + ) + + +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") +def test_get_model_info_from_training_job_config_name( + mock_get_jumpstart_model_id_version_from_resource_arn, +): + mock_sm_session = Mock() + mock_sm_session.boto_region_name = "us-west-2" + mock_sm_session.account_id = Mock(return_value="123456789012") + + mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + "model_id", + "model_version", + "config_name", ) - retval = get_model_id_version_from_training_job("bLaH", sagemaker_session=mock_sm_session) + retval = get_model_info_from_training_job("bLaH", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version") + assert retval == ("model_id", "model_version", "config_name") mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( "arn:aws:sagemaker:us-west-2:123456789012:training-job/bLaH", mock_sm_session @@ -35,7 +59,7 @@ def test_get_model_id_version_from_training_job_happy_case( @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_training_job_no_model_id_inferred( +def test_get_model_info_from_training_job_no_model_id_inferred( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -48,11 +72,11 @@ def test_get_model_id_version_from_training_job_no_model_id_inferred( ) with pytest.raises(ValueError): - get_model_id_version_from_training_job("blah", sagemaker_session=mock_sm_session) + get_model_info_from_training_job("blah", sagemaker_session=mock_sm_session) @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_model_based_endpoint_happy_case( +def test_get_model_info_from_model_based_endpoint_happy_case( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -62,13 +86,14 @@ def test_get_model_id_version_from_model_based_endpoint_happy_case( mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( "model_id", "model_version", + None, ) - retval = _get_model_id_version_from_model_based_endpoint( + retval = _get_model_info_from_model_based_endpoint( "bLaH", inference_component_name=None, sagemaker_session=mock_sm_session ) - assert retval == ("model_id", "model_version") + assert retval == ("model_id", "model_version", None) mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( "arn:aws:sagemaker:us-west-2:123456789012:endpoint/blah", mock_sm_session @@ -76,7 +101,7 @@ def test_get_model_id_version_from_model_based_endpoint_happy_case( @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_model_based_endpoint_inference_component_supplied( +def test_get_model_info_from_model_based_endpoint_inference_component_supplied( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -86,16 +111,17 @@ def test_get_model_id_version_from_model_based_endpoint_inference_component_supp mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( "model_id", "model_version", + None, ) with pytest.raises(ValueError): - _get_model_id_version_from_model_based_endpoint( + _get_model_info_from_model_based_endpoint( "blah", inference_component_name="some-name", sagemaker_session=mock_sm_session ) 
@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_model_based_endpoint_no_model_id_inferred( +def test_get_model_info_from_model_based_endpoint_no_model_id_inferred( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -108,13 +134,13 @@ def test_get_model_id_version_from_model_based_endpoint_no_model_id_inferred( ) with pytest.raises(ValueError): - _get_model_id_version_from_model_based_endpoint( + _get_model_info_from_model_based_endpoint( "blah", inference_component_name="some-name", sagemaker_session=mock_sm_session ) @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_inference_component_endpoint_with_inference_component_name_happy_case( +def test_get_model_info_from_inference_component_endpoint_with_inference_component_name_happy_case( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -124,13 +150,14 @@ def test_get_model_id_version_from_inference_component_endpoint_with_inference_c mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( "model_id", "model_version", + None, ) - retval = _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( + retval = _get_model_info_from_inference_component_endpoint_with_inference_component_name( "bLaH", sagemaker_session=mock_sm_session ) - assert retval == ("model_id", "model_version") + assert retval == ("model_id", "model_version", None) mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( "arn:aws:sagemaker:us-west-2:123456789012:inference-component/bLaH", mock_sm_session @@ -138,7 +165,7 @@ def test_get_model_id_version_from_inference_component_endpoint_with_inference_c @patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_inference_component_endpoint_with_inference_component_name_no_model_id_inferred( +def test_get_model_info_from_inference_component_endpoint_with_inference_component_name_no_model_id_inferred( mock_get_jumpstart_model_id_version_from_resource_arn, ): mock_sm_session = Mock() @@ -148,23 +175,24 @@ def test_get_model_id_version_from_inference_component_endpoint_with_inference_c mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( None, None, + None, ) with pytest.raises(ValueError): - _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( + _get_model_info_from_inference_component_endpoint_with_inference_component_name( "blah", sagemaker_session=mock_sm_session ) @patch( - "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_" + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_" "component_endpoint_with_inference_component_name" ) -def test_get_model_id_version_from_inference_component_endpoint_without_inference_component_name_happy_case( - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name, +def test_get_model_info_from_inference_component_endpoint_without_inference_component_name_happy_case( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", 
"model_version", ) @@ -172,10 +200,8 @@ def test_get_model_id_version_from_inference_component_endpoint_without_inferenc return_value=["icname"] ) - retval = ( - _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( - "blahblah", mock_sm_session - ) + retval = _get_model_info_from_inference_component_endpoint_without_inference_component_name( + "blahblah", mock_sm_session ) assert retval == ("model_id", "model_version", "icname") @@ -185,14 +211,14 @@ def test_get_model_id_version_from_inference_component_endpoint_without_inferenc @patch( - "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_" + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_" "component_endpoint_with_inference_component_name" ) -def test_get_model_id_version_from_inference_component_endpoint_without_ic_name_no_ic_for_endpoint( - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name, +def test_get_model_info_from_inference_component_endpoint_without_ic_name_no_ic_for_endpoint( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", ) @@ -200,7 +226,7 @@ def test_get_model_id_version_from_inference_component_endpoint_without_ic_name_ return_value=[] ) with pytest.raises(ValueError): - _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( + _get_model_info_from_inference_component_endpoint_without_inference_component_name( "blahblah", mock_sm_session ) @@ -210,14 +236,14 @@ def test_get_model_id_version_from_inference_component_endpoint_without_ic_name_ @patch( - "sagemaker.jumpstart.session_utils._get_model_id" - "_version_from_inference_component_endpoint_with_inference_component_name" + "sagemaker.jumpstart.session_utils._get_model" + "_info_from_inference_component_endpoint_with_inference_component_name" ) def test_get_model_id_version_from_ic_endpoint_without_inference_component_name_multiple_ics_for_endpoint( - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name, + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", ) @@ -227,7 +253,7 @@ def test_get_model_id_version_from_ic_endpoint_without_inference_component_name_ ) with pytest.raises(ValueError): - _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( + _get_model_info_from_inference_component_endpoint_without_inference_component_name( "blahblah", mock_sm_session ) @@ -236,67 +262,92 @@ def test_get_model_id_version_from_ic_endpoint_without_inference_component_name_ ) -@patch("sagemaker.jumpstart.session_utils._get_model_id_version_from_model_based_endpoint") -def test_get_model_id_version_from_endpoint_non_inference_component_endpoint( - mock_get_model_id_version_from_model_based_endpoint, +@patch("sagemaker.jumpstart.session_utils._get_model_info_from_model_based_endpoint") +def test_get_model_info_from_endpoint_non_inference_component_endpoint( + 
mock_get_model_info_from_model_based_endpoint, ): mock_sm_session = Mock() mock_sm_session.is_inference_component_based_endpoint.return_value = False - mock_get_model_id_version_from_model_based_endpoint.return_value = ( + mock_get_model_info_from_model_based_endpoint.return_value = ( "model_id", "model_version", + None, ) - retval = get_model_id_version_from_endpoint("blah", sagemaker_session=mock_sm_session) + retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version", None) - mock_get_model_id_version_from_model_based_endpoint.assert_called_once_with( + assert retval == ("model_id", "model_version", None, None) + mock_get_model_info_from_model_based_endpoint.assert_called_once_with( "blah", None, mock_sm_session ) mock_sm_session.is_inference_component_based_endpoint.assert_called_once_with("blah") @patch( - "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_" + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_" "component_endpoint_with_inference_component_name" ) -def test_get_model_id_version_from_endpoint_inference_component_endpoint_with_inference_component_name( - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name, +def test_get_model_info_from_endpoint_inference_component_endpoint_with_inference_component_name( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() mock_sm_session.is_inference_component_based_endpoint.return_value = True - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", + None, ) - retval = get_model_id_version_from_endpoint( + retval = get_model_info_from_endpoint( "blah", inference_component_name="icname", sagemaker_session=mock_sm_session ) - assert retval == ("model_id", "model_version", "icname") - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.assert_called_once_with( + assert retval == ("model_id", "model_version", "icname", None) + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.assert_called_once_with( "icname", mock_sm_session ) mock_sm_session.is_inference_component_based_endpoint.assert_not_called() @patch( - "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_component_" + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_component_" + "endpoint_without_inference_component_name" +) +def test_get_model_info_from_endpoint_inference_component_endpoint_without_inference_component_name( + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name, +): + mock_sm_session = Mock() + mock_sm_session.is_inference_component_based_endpoint.return_value = True + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.return_value = ( + "model_id", + "model_version", + None, + "inferred-icname", + ) + + retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) + + assert retval == ("model_id", "model_version", "inferred-icname", None) + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once() + + +@patch( + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_component_" "endpoint_without_inference_component_name" ) 
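The tag-utility tests further down exercise the duplicate-tag rule from _extract_value_from_list_of_tags: consistent alias values survive, conflicting ones resolve to None. A standalone sketch of that rule, reduced to plain values:

    def resolve(values):
        """Collapse candidate tag values: keep an agreed value, drop conflicts."""
        resolved = None
        for value in values:
            if value is None:
                continue
            if resolved is not None and value != resolved:
                return None  # conflicting tags: treat as not inferable
            resolved = value
        return resolved

    assert resolve(["1.0.0", "1.0.0"]) == "1.0.0"  # consistent aliases survive
    assert resolve(["1.0.0", "2.0.0"]) is None     # inconsistent aliases are dropped
    assert resolve([None, "model_id_1"]) == "model_id_1"
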
-def test_get_model_id_version_from_endpoint_inference_component_endpoint_without_inference_component_name( - mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name, +def test_get_model_info_from_endpoint_inference_component_endpoint_with_config_name( + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name, ): mock_sm_session = Mock() mock_sm_session.is_inference_component_based_endpoint.return_value = True - mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.return_value = ( "model_id", "model_version", + "config_name", "inferred-icname", ) - retval = get_model_id_version_from_endpoint("blah", sagemaker_session=mock_sm_session) + retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version", "inferred-icname") - mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name.assert_called_once() + assert retval == ("model_id", "model_version", "inferred-icname", "config_name") + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once() diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 85911a2854..83724e5e8a 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -1323,7 +1323,7 @@ def test_no_model_id_no_version_found(self): utils.get_jumpstart_model_id_version_from_resource_arn( "some-arn", mock_sagemaker_session ), - (None, None), + (None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1340,7 +1340,7 @@ def test_model_id_no_version_found(self): utils.get_jumpstart_model_id_version_from_resource_arn( "some-arn", mock_sagemaker_session ), - ("model_id", None), + ("model_id", None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1357,7 +1357,38 @@ def test_no_model_id_version_found(self): utils.get_jumpstart_model_id_version_from_resource_arn( "some-arn", mock_sagemaker_session ), - (None, "model_version"), + (None, "model_version", None), + ) + mock_list_tags.assert_called_once_with("some-arn") + + def test_no_config_name_found(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [{"Key": "blah", "Value": "blah1"}] + + self.assertEquals( + utils.get_jumpstart_model_id_version_from_resource_arn( + "some-arn", mock_sagemaker_session + ), + (None, None, None), + ) + mock_list_tags.assert_called_once_with("some-arn") + + def test_config_name_found(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [ + {"Key": "blah", "Value": "blah1"}, + {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "config_name"}, + ] + + self.assertEquals( + utils.get_jumpstart_model_id_version_from_resource_arn( + "some-arn", mock_sagemaker_session + ), + (None, None, "config_name"), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1375,7 +1406,7 @@ def test_model_id_version_found(self): utils.get_jumpstart_model_id_version_from_resource_arn( "some-arn", mock_sagemaker_session ), - ("model_id", "model_version"), + ("model_id", "model_version", None), ) mock_list_tags.assert_called_once_with("some-arn") @@ 
-1395,7 +1426,7 @@ def test_multiple_model_id_versions_found(self): utils.get_jumpstart_model_id_version_from_resource_arn( "some-arn", mock_sagemaker_session ), - (None, None), + (None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1415,7 +1446,7 @@ def test_multiple_model_id_versions_found_aliases_consistent(self): utils.get_jumpstart_model_id_version_from_resource_arn( "some-arn", mock_sagemaker_session ), - ("model_id_1", "model_version_1"), + ("model_id_1", "model_version_1", None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1435,7 +1466,27 @@ def test_multiple_model_id_versions_found_aliases_inconsistent(self): utils.get_jumpstart_model_id_version_from_resource_arn( "some-arn", mock_sagemaker_session ), - (None, None), + (None, None, None), + ) + mock_list_tags.assert_called_once_with("some-arn") + + def test_multiple_config_names_found_aliases_inconsistent(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [ + {"Key": "blah", "Value": "blah1"}, + {"Key": JumpStartTag.MODEL_ID, "Value": "model_id_1"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "model_version_1"}, + {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "config_name_1"}, + {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "config_name_2"}, + ] + + self.assertEquals( + utils.get_jumpstart_model_id_version_from_resource_arn( + "some-arn", mock_sagemaker_session + ), + ("model_id_1", "model_version_1", None), ) mock_list_tags.assert_called_once_with("some-arn") From a8d30e0940dd26aa04569051edc6ab7e77036c56 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga <54963715+makungaj1@users.noreply.github.com> Date: Thu, 25 Apr 2024 12:52:00 -0700 Subject: [PATCH 20/32] ModelBuilder: Add functionalities to get and set deployment config. (#4614) * Add funtionalities to get and set deployment config * Resolve PR comments * ModelBuilder-JS * Add Unit tests * Refactoring * Testing with Notebook * Test backward compatibility * Remove Accelerated column if all not enabled * Fix docstring * Resolved PR Review comments * Docstring * increase code coverage --------- Co-authored-by: Jonathan Makunga --- src/sagemaker/jumpstart/model.py | 34 +++- src/sagemaker/jumpstart/types.py | 18 +- src/sagemaker/jumpstart/utils.py | 28 ++- .../serve/builder/jumpstart_builder.py | 57 ++++-- tests/unit/sagemaker/jumpstart/constants.py | 52 +++--- .../sagemaker/jumpstart/model/test_model.py | 48 ++++++ tests/unit/sagemaker/jumpstart/test_utils.py | 10 +- tests/unit/sagemaker/jumpstart/utils.py | 8 + .../serve/builder/test_js_builder.py | 162 +++++++++++++++++- 9 files changed, 348 insertions(+), 69 deletions(-) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 2addb0a044..f939bc303b 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -15,7 +15,7 @@ from __future__ import absolute_import from functools import lru_cache -from typing import Dict, List, Optional, Union, Any +from typing import Dict, List, Optional, Any, Union import pandas as pd from botocore.exceptions import ClientError @@ -441,6 +441,15 @@ def set_deployment_config(self, config_name: Optional[str]) -> None: model_id=self.model_id, model_version=self.model_version, config_name=config_name ) + @property + def deployment_config(self) -> Optional[Dict[str, Any]]: + """The deployment config that will be applied to the model. 
+ + Returns: + Optional[Dict[str, Any]]: Deployment config that will be applied to the model. + """ + return self._retrieve_selected_deployment_config(self.config_name) + @property def benchmark_metrics(self) -> pd.DataFrame: """Benchmark Metrics for deployment configs @@ -448,7 +457,7 @@ def benchmark_metrics(self) -> pd.DataFrame: Returns: Metrics: Pandas DataFrame object. """ - return pd.DataFrame(self._get_benchmark_data(self.config_name)) + return pd.DataFrame(self._get_benchmarks_data(self.config_name)) def display_benchmark_metrics(self) -> None: """Display Benchmark Metrics for deployment configs.""" @@ -851,8 +860,8 @@ def register_deploy_wrapper(*args, **kwargs): return model_package @lru_cache - def _get_benchmark_data(self, config_name: str) -> Dict[str, List[str]]: - """Constructs deployment configs benchmark data. + def _get_benchmarks_data(self, config_name: str) -> Dict[str, List[str]]: + """Deployment configs benchmark metrics. Args: config_name (str): The name of the selected deployment config. @@ -864,6 +873,23 @@ def _get_benchmark_data(self, config_name: str) -> Dict[str, List[str]]: config_name, ) + @lru_cache + def _retrieve_selected_deployment_config(self, config_name: str) -> Optional[Dict[str, Any]]: + """Retrieve the deployment config to apply to the model. + + Args: + config_name (str): The name of the deployment config to retrieve. + Returns: + Optional[Dict[str, Any]]: The retrieved deployment config. + """ + if config_name is None: + return None + + for deployment_config in self._deployment_configs: + if deployment_config.get("DeploymentConfigName") == config_name: + return deployment_config + return None + def _convert_to_deployment_config_metadata( self, config_name: str, metadata_config: JumpStartMetadataConfig ) -> Dict[str, Any]: diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index bf0a84319b..65b6b32739 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -2251,17 +2251,17 @@ def to_json(self) -> Dict[str, Any]: return json_obj -class DeploymentConfig(BaseDeploymentConfigDataHolder): +class DeploymentArgs(BaseDeploymentConfigDataHolder): """Dataclass representing a Deployment Config.""" __slots__ = [ - "model_data_download_timeout", - "container_startup_health_check_timeout", "image_uri", "model_data", - "instance_type", "environment", + "instance_type", "compute_resource_requirements", + "model_data_download_timeout", + "container_startup_health_check_timeout", ] def __init__( @@ -2288,9 +2288,10 @@ class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder): """Dataclass representing a Deployment Config Metadata""" __slots__ = [ - "config_name", + "deployment_config_name", + "deployment_args", + "acceleration_configs", "benchmark_metrics", - "deployment_config", ] def __init__( @@ -2301,6 +2302,7 @@ def __init__( deploy_kwargs: JumpStartModelDeployKwargs, ): """Instantiates DeploymentConfigMetadata object.""" - self.config_name = config_name + self.deployment_config_name = config_name + self.deployment_args = DeploymentArgs(init_kwargs, deploy_kwargs) + self.acceleration_configs = None self.benchmark_metrics = benchmark_metrics - self.deployment_config = DeploymentConfig(init_kwargs, deploy_kwargs) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 59bf11b415..3fce6dd105 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -1040,24 +1040,40 @@ def extract_metrics_from_deployment_configs( config_name (str): The 
name of the deployment config use by the model. """ - data = {"Config Name": [], "Instance Type": [], "Selected": []} + data = {"Config Name": [], "Instance Type": [], "Selected": [], "Accelerated": []} for index, deployment_config in enumerate(deployment_configs): - if deployment_config.get("DeploymentConfig") is None: + if deployment_config.get("DeploymentArgs") is None: continue benchmark_metrics = deployment_config.get("BenchmarkMetrics") if benchmark_metrics is not None: - data["Config Name"].append(deployment_config.get("ConfigName")) + data["Config Name"].append(deployment_config.get("DeploymentConfigName")) data["Instance Type"].append( - deployment_config.get("DeploymentConfig").get("InstanceType") + deployment_config.get("DeploymentArgs").get("InstanceType") ) data["Selected"].append( "Yes" - if (config_name is not None and config_name == deployment_config.get("ConfigName")) + if ( + config_name is not None + and config_name == deployment_config.get("DeploymentConfigName") + ) else "No" ) + accelerated_configs = deployment_config.get("AccelerationConfigs") + if accelerated_configs is None: + data["Accelerated"].append("No") + else: + data["Accelerated"].append( + "Yes" + if ( + len(accelerated_configs) > 0 + and accelerated_configs[0].get("Enabled", False) + ) + else "No" + ) + if index == 0: for benchmark_metric in benchmark_metrics: column_name = f"{benchmark_metric.get('name')} ({benchmark_metric.get('unit')})" @@ -1068,4 +1084,6 @@ def extract_metrics_from_deployment_configs( if column_name in data.keys(): data[column_name].append(benchmark_metric.get("value")) + if "Yes" not in data["Accelerated"]: + del data["Accelerated"] return data diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index c1760311e7..d3c2581885 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -16,7 +16,7 @@ import copy from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import Type, Any, List, Dict +from typing import Type, Any, List, Dict, Optional import logging from sagemaker.model import Model @@ -431,8 +431,35 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800): sharded_supported=sharded_supported, max_tuning_duration=max_tuning_duration ) + def set_deployment_config(self, config_name: Optional[str]) -> None: + """Sets the deployment config to apply to the model. + + Args: + config_name (Optional[str]): + The name of the deployment config. Set to None to unset + any existing config that is applied to the model. + """ + if not hasattr(self, "pysdk_model") or self.pysdk_model is None: + raise Exception("Cannot set deployment config to an uninitialized model.") + + self.pysdk_model.set_deployment_config(config_name) + + def get_deployment_config(self) -> Optional[Dict[str, Any]]: + """Gets the deployment config to apply to the model. + + Returns: + Optional[Dict[str, Any]]: Deployment config to apply to this model. 
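+
+        Example (an illustrative sketch; the model id and schema builder are
+        placeholders):
+            builder = ModelBuilder(model="<jumpstart-model-id>",
+                                   schema_builder=schema_builder)
+            config = builder.get_deployment_config()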
+ """ + if not hasattr(self, "pysdk_model") or self.pysdk_model is None: + self.pysdk_model = self._create_pre_trained_js_model() + + return self.pysdk_model.deployment_config + def display_benchmark_metrics(self): """Display Markdown Benchmark Metrics for deployment configs.""" + if not hasattr(self, "pysdk_model") or self.pysdk_model is None: + self.pysdk_model = self._create_pre_trained_js_model() + self.pysdk_model.display_benchmark_metrics() def list_deployment_configs(self) -> List[Dict[str, Any]]: @@ -441,6 +468,9 @@ def list_deployment_configs(self) -> List[Dict[str, Any]]: Returns: List[Dict[str, Any]]: A list of deployment configs. """ + if not hasattr(self, "pysdk_model") or self.pysdk_model is None: + self.pysdk_model = self._create_pre_trained_js_model() + return self.pysdk_model.list_deployment_configs() def _build_for_jumpstart(self): @@ -449,32 +479,29 @@ def _build_for_jumpstart(self): self.secret_key = None self.jumpstart = True - pysdk_model = self._create_pre_trained_js_model() - - image_uri = pysdk_model.image_uri + if not hasattr(self, "pysdk_model") or self.pysdk_model is None: + self.pysdk_model = self._create_pre_trained_js_model() - logger.info("JumpStart ID %s is packaged with Image URI: %s", self.model, image_uri) + logger.info( + "JumpStart ID %s is packaged with Image URI: %s", self.model, self.pysdk_model.image_uri + ) - if self._is_gated_model(pysdk_model) and self.mode != Mode.SAGEMAKER_ENDPOINT: + if self._is_gated_model() and self.mode != Mode.SAGEMAKER_ENDPOINT: raise ValueError( "JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode." ) - if "djl-inference" in image_uri: + if "djl-inference" in self.pysdk_model.image_uri: logger.info("Building for DJL JumpStart Model ID...") self.model_server = ModelServer.DJL_SERVING - - self.pysdk_model = pysdk_model self.image_uri = self.pysdk_model.image_uri self._build_for_djl_jumpstart() self.pysdk_model.tune = self.tune_for_djl_jumpstart - elif "tgi-inference" in image_uri: + elif "tgi-inference" in self.pysdk_model.image_uri: logger.info("Building for TGI JumpStart Model ID...") self.model_server = ModelServer.TGI - - self.pysdk_model = pysdk_model self.image_uri = self.pysdk_model.image_uri self._build_for_tgi_jumpstart() @@ -487,15 +514,13 @@ def _build_for_jumpstart(self): return self.pysdk_model - def _is_gated_model(self, model) -> bool: + def _is_gated_model(self) -> bool: """Determine if ``this`` Model is Gated - Args: - model (Model): Jumpstart Model Returns: bool: ``True`` if ``this`` Model is Gated """ - s3_uri = model.model_data + s3_uri = self.pysdk_model.model_data if isinstance(s3_uri, dict): s3_uri = s3_uri.get("S3DataSource").get("S3Uri") diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index b83f85ffde..90f037daea 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -7911,11 +7911,8 @@ DEPLOYMENT_CONFIGS = [ { - "ConfigName": "neuron-inference", - "BenchmarkMetrics": [{"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs"}], - "DeploymentConfig": { - "ModelDataDownloadTimeout": None, - "ContainerStartupHealthCheckTimeout": None, + "DeploymentConfigName": "neuron-inference", + "DeploymentArgs": { "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" ".0-gpu-py310-cu121-ubuntu20.04", "ModelData": { @@ -7926,7 +7923,6 @@ "CompressionType": "None", } }, - "InstanceType": "ml.p2.xlarge", "Environment": { 
"SAGEMAKER_PROGRAM": "inference.py", "ENDPOINT_SERVER_TIMEOUT": "3600", @@ -7938,15 +7934,17 @@ "MAX_TOTAL_TOKENS": "2048", "SAGEMAKER_MODEL_SERVER_WORKERS": "1", }, + "InstanceType": "ml.p2.xlarge", "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, }, + "AccelerationConfigs": None, + "BenchmarkMetrics": [{"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs"}], }, { - "ConfigName": "neuron-inference-budget", - "BenchmarkMetrics": [{"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs"}], - "DeploymentConfig": { - "ModelDataDownloadTimeout": None, - "ContainerStartupHealthCheckTimeout": None, + "DeploymentConfigName": "neuron-inference-budget", + "DeploymentArgs": { "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" ".0-gpu-py310-cu121-ubuntu20.04", "ModelData": { @@ -7957,7 +7955,6 @@ "CompressionType": "None", } }, - "InstanceType": "ml.p2.xlarge", "Environment": { "SAGEMAKER_PROGRAM": "inference.py", "ENDPOINT_SERVER_TIMEOUT": "3600", @@ -7969,15 +7966,17 @@ "MAX_TOTAL_TOKENS": "2048", "SAGEMAKER_MODEL_SERVER_WORKERS": "1", }, + "InstanceType": "ml.p2.xlarge", "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, }, + "AccelerationConfigs": None, + "BenchmarkMetrics": [{"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs"}], }, { - "ConfigName": "gpu-inference-budget", - "BenchmarkMetrics": [{"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs"}], - "DeploymentConfig": { - "ModelDataDownloadTimeout": None, - "ContainerStartupHealthCheckTimeout": None, + "DeploymentConfigName": "gpu-inference-budget", + "DeploymentArgs": { "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" ".0-gpu-py310-cu121-ubuntu20.04", "ModelData": { @@ -7988,7 +7987,6 @@ "CompressionType": "None", } }, - "InstanceType": "ml.p2.xlarge", "Environment": { "SAGEMAKER_PROGRAM": "inference.py", "ENDPOINT_SERVER_TIMEOUT": "3600", @@ -8000,15 +7998,17 @@ "MAX_TOTAL_TOKENS": "2048", "SAGEMAKER_MODEL_SERVER_WORKERS": "1", }, + "InstanceType": "ml.p2.xlarge", "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, }, + "AccelerationConfigs": None, + "BenchmarkMetrics": [{"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs"}], }, { - "ConfigName": "gpu-inference", - "BenchmarkMetrics": [{"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs"}], - "DeploymentConfig": { - "ModelDataDownloadTimeout": None, - "ContainerStartupHealthCheckTimeout": None, + "DeploymentConfigName": "gpu-inference", + "DeploymentArgs": { "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" ".0-gpu-py310-cu121-ubuntu20.04", "ModelData": { @@ -8019,7 +8019,6 @@ "CompressionType": "None", } }, - "InstanceType": "ml.p2.xlarge", "Environment": { "SAGEMAKER_PROGRAM": "inference.py", "ENDPOINT_SERVER_TIMEOUT": "3600", @@ -8031,8 +8030,13 @@ "MAX_TOTAL_TOKENS": "2048", "SAGEMAKER_MODEL_SERVER_WORKERS": "1", }, + "InstanceType": "ml.p2.xlarge", "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, }, + "AccelerationConfigs": None, + "BenchmarkMetrics": 
[{"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs"}], }, ] diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 476002457b..5d8d048501 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -1768,6 +1768,54 @@ def test_model_list_deployment_configs_empty( self.assertTrue(len(configs) == 0) + @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") + @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") + @mock.patch("sagemaker.jumpstart.model.get_instance_rate_per_hour") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_retrieve_deployment_config( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_instance_rate_per_hour: mock.Mock, + mock_verify_model_region_and_return_specs: mock.Mock, + mock_get_init_kwargs: mock.Mock, + ): + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id) + mock_verify_model_region_and_return_specs.side_effect = ( + lambda *args, **kwargs: get_base_spec_with_prototype_configs() + ) + mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: { + "name": "Instance Rate", + "unit": "USD/Hrs", + "value": "0.0083000000", + } + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + expected = get_base_deployment_configs()[0] + model.set_deployment_config(expected.get("DeploymentConfigName")) + + self.assertEqual(model.deployment_config, expected) + + # Unset + model.set_deployment_config(None) + self.assertIsNone(model.deployment_config) + @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @mock.patch("sagemaker.jumpstart.model.get_instance_rate_per_hour") diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 83724e5e8a..210bd8e074 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -50,6 +50,7 @@ get_special_model_spec, get_prototype_manifest, get_base_deployment_configs, + get_base_deployment_configs_with_acceleration_configs, ) from mock import MagicMock @@ -1763,10 +1764,11 @@ def test_get_jumpstart_benchmark_stats_training( @pytest.mark.parametrize( - "config_name, expected", + "config_name, configs, expected", [ ( None, + get_base_deployment_configs(), { "Config Name": [ "neuron-inference", @@ -1786,6 +1788,7 @@ def test_get_jumpstart_benchmark_stats_training( ), ( "neuron-inference", + get_base_deployment_configs_with_acceleration_configs(), { "Config Name": [ "neuron-inference", @@ -1795,6 +1798,7 @@ def test_get_jumpstart_benchmark_stats_training( ], "Instance Type": 
["ml.p2.xlarge", "ml.p2.xlarge", "ml.p2.xlarge", "ml.p2.xlarge"], "Selected": ["Yes", "No", "No", "No"], + "Accelerated": ["Yes", "No", "No", "No"], "Instance Rate (USD/Hrs)": [ "0.0083000000", "0.0083000000", @@ -1805,7 +1809,7 @@ def test_get_jumpstart_benchmark_stats_training( ), ], ) -def test_extract_metrics_from_deployment_configs(config_name, expected): - data = utils.extract_metrics_from_deployment_configs(get_base_deployment_configs(), config_name) +def test_extract_metrics_from_deployment_configs(config_name, configs, expected): + data = utils.extract_metrics_from_deployment_configs(configs, config_name) assert data == expected diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index e0d6f645a8..96662837b4 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -307,6 +307,14 @@ def get_base_deployment_configs() -> List[Dict[str, Any]]: return DEPLOYMENT_CONFIGS +def get_base_deployment_configs_with_acceleration_configs() -> List[Dict[str, Any]]: + configs = copy.deepcopy(DEPLOYMENT_CONFIGS) + configs[0]["AccelerationConfigs"] = [ + {"Type": "Speculative-Decoding", "Enabled": True, "Spec": {"Version": "0.1"}} + ] + return configs + + def get_mock_init_kwargs(model_id) -> JumpStartModelInitKwargs: return JumpStartModelInitKwargs( model_id=model_id, diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py index 3d5148772e..b83b113209 100644 --- a/tests/unit/sagemaker/serve/builder/test_js_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py @@ -676,13 +676,122 @@ def test_list_deployment_configs( lambda: DEPLOYMENT_CONFIGS ) - model = builder.build() - builder.serve_settings.telemetry_opt_out = True - - configs = model.list_deployment_configs() + configs = builder.list_deployment_configs() self.assertEqual(configs, DEPLOYMENT_CONFIGS) + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test_get_deployment_config( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/galactica-mock-model-id", + schema_builder=mock_schema_builder, + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + + expected = DEPLOYMENT_CONFIGS[0] + mock_pre_trained_model.return_value.deployment_config = expected + + self.assertEqual(builder.get_deployment_config(), expected) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + 
"sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test_set_deployment_config( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/galactica-mock-model-id", + schema_builder=mock_schema_builder, + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + + builder.build() + builder.set_deployment_config("config-1") + + mock_pre_trained_model.return_value.set_deployment_config.assert_called_with("config-1") + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test_set_deployment_config_ex( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + + self.assertRaisesRegex( + Exception, + "Cannot set deployment config to an uninitialized model.", + lambda: ModelBuilder( + model="facebook/galactica-mock-model-id", schema_builder=mock_schema_builder + ).set_deployment_config("config-2"), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) @patch( "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", @@ -715,11 +824,46 @@ def test_display_benchmark_metrics( ) mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri - mock_pre_trained_model.return_value.display_benchmark_metrics.side_effect = ( - lambda *args, **kwargs: "metric data" + mock_pre_trained_model.return_value.list_deployment_configs.side_effect = ( + lambda: DEPLOYMENT_CONFIGS ) - model = builder.build() - builder.serve_settings.telemetry_opt_out = True + builder.list_deployment_configs() + + builder.display_benchmark_metrics() + + mock_pre_trained_model.return_value.display_benchmark_metrics.assert_called_once() + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test_display_benchmark_metrics_initial( + 
self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/galactica-mock-model-id", + schema_builder=mock_schema_builder, + ) + builder.display_benchmark_metrics() - model.display_benchmark_metrics() + mock_pre_trained_model.return_value.display_benchmark_metrics.assert_called_once() From 265cfc8af0668a3b853bd5c38c7659407a2fd217 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga <54963715+makungaj1@users.noreply.github.com> Date: Thu, 25 Apr 2024 16:16:09 -0700 Subject: [PATCH 21/32] Benchmark feature v2 (#4618) * Add funtionalities to get and set deployment config * Resolve PR comments * ModelBuilder-JS * Add Unit tests * Refactoring * Testing with Notebook * Test backward compatibility * Remove Accelerated column if all not enabled * Fix docstring * Resolved PR Review comments * Docstring * increase code coverage * Testing fix with Notebook * Only fetch instance rate metrics if not present * Increase code coverage --------- Co-authored-by: Jonathan Makunga --- src/sagemaker/jumpstart/model.py | 29 ++++++++++++------- src/sagemaker/jumpstart/utils.py | 2 +- .../sagemaker/jumpstart/model/test_model.py | 5 ++-- tests/unit/sagemaker/jumpstart/test_utils.py | 2 +- tests/unit/sagemaker/jumpstart/utils.py | 20 +++++++++++++ 5 files changed, 44 insertions(+), 14 deletions(-) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index f939bc303b..d98b9e7dd6 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -48,7 +48,7 @@ validate_model_id_and_get_type, verify_model_region_and_return_specs, get_jumpstart_configs, - extract_metrics_from_deployment_configs, + get_metrics_from_deployment_configs, ) from sagemaker.jumpstart.constants import JUMPSTART_LOGGER from sagemaker.jumpstart.enums import JumpStartModelType @@ -868,7 +868,7 @@ def _get_benchmarks_data(self, config_name: str) -> Dict[str, List[str]]: Returns: Dict[str, List[str]]: Deployment config benchmark data. 
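                The dict maps column names (e.g. "Config Name",
                "Instance Type", "Selected") to per-config value lists.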
""" - return extract_metrics_from_deployment_configs( + return get_metrics_from_deployment_configs( self._deployment_configs, config_name, ) @@ -905,20 +905,29 @@ def _convert_to_deployment_config_metadata( "default_inference_instance_type" ) - instance_rate = get_instance_rate_per_hour( - instance_type=default_inference_instance_type, region=self.region - ) - benchmark_metrics = ( metadata_config.benchmark_metrics.get(default_inference_instance_type) if metadata_config.benchmark_metrics is not None else None ) - if instance_rate is not None: - if benchmark_metrics is not None: - benchmark_metrics.append(JumpStartBenchmarkStat(instance_rate)) + + should_fetch_instance_rate_metric = True + if benchmark_metrics is not None: + for benchmark_metric in benchmark_metrics: + if benchmark_metric.name.lower() == "instance rate": + should_fetch_instance_rate_metric = False + break + + if should_fetch_instance_rate_metric: + instance_rate = get_instance_rate_per_hour( + instance_type=default_inference_instance_type, region=self.region + ) + instance_rate_metric = JumpStartBenchmarkStat(instance_rate) + + if benchmark_metrics is None: + benchmark_metrics = [instance_rate_metric] else: - benchmark_metrics = [JumpStartBenchmarkStat(instance_rate)] + benchmark_metrics.append(instance_rate_metric) init_kwargs = get_init_kwargs( model_id=self.model_id, diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 3fce6dd105..357bdb6eb7 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -1030,7 +1030,7 @@ def get_jumpstart_configs( ) -def extract_metrics_from_deployment_configs( +def get_metrics_from_deployment_configs( deployment_configs: List[Dict[str, Any]], config_name: str ) -> Dict[str, List[str]]: """Extracts metrics from deployment configs. 
diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 5d8d048501..3f07df8bfa 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -51,6 +51,7 @@ get_base_spec_with_prototype_configs, get_mock_init_kwargs, get_base_deployment_configs, + get_base_spec_with_prototype_configs_with_missing_benchmarks, ) import boto3 @@ -1790,7 +1791,7 @@ def test_model_retrieve_deployment_config( mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id) mock_verify_model_region_and_return_specs.side_effect = ( - lambda *args, **kwargs: get_base_spec_with_prototype_configs() + lambda *args, **kwargs: get_base_spec_with_prototype_configs_with_missing_benchmarks() ) mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: { "name": "Instance Rate", @@ -1838,7 +1839,7 @@ def test_model_display_benchmark_metrics( mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id) mock_verify_model_region_and_return_specs.side_effect = ( - lambda *args, **kwargs: get_base_spec_with_prototype_configs() + lambda *args, **kwargs: get_base_spec_with_prototype_configs_with_missing_benchmarks() ) mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: { "name": "Instance Rate", diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 210bd8e074..f576e36185 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -1810,6 +1810,6 @@ def test_get_jumpstart_benchmark_stats_training( ], ) def test_extract_metrics_from_deployment_configs(config_name, configs, expected): - data = utils.extract_metrics_from_deployment_configs(configs, config_name) + data = utils.get_metrics_from_deployment_configs(configs, config_name) assert data == expected diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 96662837b4..77913ea73e 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -226,6 +226,26 @@ def get_base_spec_with_prototype_configs( return JumpStartModelSpecs(spec) +def get_base_spec_with_prototype_configs_with_missing_benchmarks( + region: str = None, + model_id: str = None, + version: str = None, + s3_client: boto3.client = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, +) -> JumpStartModelSpecs: + spec = copy.deepcopy(BASE_SPEC) + copy_inference_configs = copy.deepcopy(INFERENCE_CONFIGS) + copy_inference_configs["inference_configs"]["neuron-inference"]["benchmark_metrics"] = None + + inference_configs = {**INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS} + training_configs = {**TRAINING_CONFIGS, **TRAINING_CONFIG_RANKINGS} + + spec.update(inference_configs) + spec.update(training_configs) + + return JumpStartModelSpecs(spec) + + def get_prototype_spec_with_configs( region: str = None, model_id: str = None, From 76263f5ed7109398be87a50c2dee24a61f504294 Mon Sep 17 00:00:00 2001 From: Haotian An <33510317+Captainia@users.noreply.github.com> Date: Fri, 26 Apr 2024 11:32:07 -0400 Subject: [PATCH 22/32] Merge Master --- src/sagemaker/jumpstart/estimator.py | 7 ++++- src/sagemaker/jumpstart/factory/model.py | 26 +++++++++++++++++++ src/sagemaker/jumpstart/model.py | 2 +- src/sagemaker/jumpstart/types.py | 5 ++++ src/sagemaker/predictor.py | 6 +++-- .../jumpstart/estimator/test_estimator.py | 2 
++ .../sagemaker/jumpstart/model/test_model.py | 21 ++++++++++++--- tests/unit/sagemaker/jumpstart/utils.py | 9 ++++--- 8 files changed, 68 insertions(+), 10 deletions(-) diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 33bb73a83c..b20a5c1ead 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -739,7 +739,12 @@ def attach( model_version = model_version or "*" - additional_kwargs = {"model_id": model_id, "model_version": model_version} + additional_kwargs = { + "model_id": model_id, + "model_version": model_version, + "tolerate_vulnerable_model": True, # model is already trained + "tolerate_deprecated_model": True, # model is already trained + } model_specs = verify_model_region_and_return_specs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index b4f6d70583..68fbf2e861 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -543,6 +543,31 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel return kwargs +def _add_config_name_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: + """Sets default config name to the kwargs. Returns full kwargs.""" + + specs = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + scope=JumpStartScriptScope.INFERENCE, + region=kwargs.region, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, + config_name=kwargs.config_name, + ) + if ( + specs.inference_configs + and specs.inference_configs.get_top_config_from_ranking().config_name + ): + kwargs.config_name = ( + kwargs.config_name or specs.inference_configs.get_top_config_from_ranking().config_name + ) + + return kwargs + + def get_deploy_kwargs( model_id: str, model_version: Optional[str] = None, @@ -808,5 +833,6 @@ def get_init_kwargs( model_init_kwargs = _add_model_package_arn_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs) + model_init_kwargs = _add_config_name_to_kwargs(kwargs=model_init_kwargs) return model_init_kwargs diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index d98b9e7dd6..f3e18c306d 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -351,7 +351,7 @@ def _validate_model_id_and_type(): self.tolerate_deprecated_model = model_init_kwargs.tolerate_deprecated_model self.region = model_init_kwargs.region self.sagemaker_session = model_init_kwargs.sagemaker_session - self.config_name = config_name + self.config_name = model_init_kwargs.config_name if self.model_type == JumpStartModelType.PROPRIETARY: self.log_subscription_warning() diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 65b6b32739..cd74a03e5a 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1076,10 +1076,12 @@ class JumpStartMetadataConfig(JumpStartDataHolderType): "benchmark_metrics", "config_components", "resolved_metadata_config", + "config_name", ] def __init__( self, + config_name: str, base_fields: Dict[str, Any], config_components: Dict[str, JumpStartConfigComponent], benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]], @@ -1098,6 +1100,7 @@ def __init__( 
self.config_components: Dict[str, JumpStartConfigComponent] = config_components self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = benchmark_metrics self.resolved_metadata_config: Optional[Dict[str, Any]] = None + self.config_name: Optional[str] = config_name def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartMetadataConfig object.""" @@ -1251,6 +1254,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = ( { alias: JumpStartMetadataConfig( + alias, json_obj, ( { @@ -1303,6 +1307,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = ( { alias: JumpStartMetadataConfig( + alias, json_obj, ( { diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 14e2ae278b..c5d08eb8f4 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -43,6 +43,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> Predictor: """Retrieves the default predictor for the model matching the given arguments. @@ -65,6 +66,8 @@ def retrieve_default( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + config_name (Optional[str]): The name of the configuration to use for the + predictor. (Default: None) Returns: Predictor: The default predictor to use for the model. @@ -91,10 +94,9 @@ def retrieve_default( model_id = inferred_model_id model_version = model_version or inferred_model_version or "*" inference_component_name = inference_component_name or inferred_inference_component_name - config_name = inferred_config_name or None + config_name = config_name or inferred_config_name or None else: model_version = model_version or "*" - config_name = None predictor = Predictor( endpoint_name=endpoint_name, diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index be0e06472c..6a9bd9ff10 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1011,6 +1011,8 @@ def test_jumpstart_estimator_attach_eula_model( additional_kwargs={ "model_id": "gemma-model", "model_version": "*", + "tolerate_vulnerable_model": True, + "tolerate_deprecated_model": True, "environment": {"accept_eula": "true"}, "tolerate_vulnerable_model": True, "tolerate_deprecated_model": True, diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 3f07df8bfa..6d859aecdb 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -1552,6 +1552,8 @@ def test_model_initialization_with_config_name( model = JumpStartModel(model_id=model_id, config_name="neuron-inference") + assert model.config_name == "neuron-inference" + model.deploy() mock_model_deploy.assert_called_once_with( @@ -1594,6 +1596,8 @@ def test_model_set_deployment_config( model = JumpStartModel(model_id=model_id) + assert model.config_name is None + model.deploy() mock_model_deploy.assert_called_once_with( @@ -1612,6 +1616,8 @@ def test_model_set_deployment_config( mock_get_model_specs.side_effect = 
get_prototype_spec_with_configs model.set_deployment_config("neuron-inference") + assert model.config_name == "neuron-inference" + model.deploy() mock_model_deploy.assert_called_once_with( @@ -1654,6 +1660,8 @@ def test_model_unset_deployment_config( model = JumpStartModel(model_id=model_id, config_name="neuron-inference") + assert model.config_name == "neuron-inference" + model.deploy() mock_model_deploy.assert_called_once_with( @@ -1789,7 +1797,6 @@ def test_model_retrieve_deployment_config( ): model_id, _ = "pytorch-eqa-bert-base-cased", "*" - mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id) mock_verify_model_region_and_return_specs.side_effect = ( lambda *args, **kwargs: get_base_spec_with_prototype_configs_with_missing_benchmarks() ) @@ -1804,15 +1811,23 @@ def test_model_retrieve_deployment_config( ) mock_model_deploy.return_value = default_predictor + expected = get_base_deployment_configs()[0] + config_name = expected.get("DeploymentConfigName") + mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs( + model_id, config_name + ) + mock_session.return_value = sagemaker_session model = JumpStartModel(model_id=model_id) - expected = get_base_deployment_configs()[0] - model.set_deployment_config(expected.get("DeploymentConfigName")) + model.set_deployment_config(config_name) self.assertEqual(model.deployment_config, expected) + mock_get_init_kwargs.reset_mock() + mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id) + # Unset model.set_deployment_config(None) self.assertIsNone(model.deployment_config) diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 77913ea73e..8b814c3d71 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. 
from __future__ import absolute_import import copy -from typing import List, Dict, Any +from typing import List, Dict, Any, Optional import boto3 from sagemaker.compute_resource_requirements import ResourceRequirements @@ -237,7 +237,7 @@ def get_base_spec_with_prototype_configs_with_missing_benchmarks( copy_inference_configs = copy.deepcopy(INFERENCE_CONFIGS) copy_inference_configs["inference_configs"]["neuron-inference"]["benchmark_metrics"] = None - inference_configs = {**INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS} + inference_configs = {**copy_inference_configs, **INFERENCE_CONFIG_RANKINGS} training_configs = {**TRAINING_CONFIGS, **TRAINING_CONFIG_RANKINGS} spec.update(inference_configs) @@ -335,7 +335,9 @@ def get_base_deployment_configs_with_acceleration_configs() -> List[Dict[str, An return configs -def get_mock_init_kwargs(model_id) -> JumpStartModelInitKwargs: +def get_mock_init_kwargs( + model_id: str, config_name: Optional[str] = None +) -> JumpStartModelInitKwargs: return JumpStartModelInitKwargs( model_id=model_id, model_type=JumpStartModelType.OPEN_WEIGHTS, @@ -344,4 +346,5 @@ def get_mock_init_kwargs(model_id) -> JumpStartModelInitKwargs: instance_type=INIT_KWARGS.get("instance_type"), env=INIT_KWARGS.get("env"), resources=ResourceRequirements(), + config_name=config_name, ) From c91cda52f212e03273b9ecc4a497345a203d0da1 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga <54963715+makungaj1@users.noreply.github.com> Date: Fri, 26 Apr 2024 12:25:58 -0700 Subject: [PATCH 23/32] Fix fetch instance rate bug (#4624) Co-authored-by: Jonathan Makunga --- src/sagemaker/jumpstart/model.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index f3e18c306d..78c36cb954 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -922,12 +922,13 @@ def _convert_to_deployment_config_metadata( instance_rate = get_instance_rate_per_hour( instance_type=default_inference_instance_type, region=self.region ) - instance_rate_metric = JumpStartBenchmarkStat(instance_rate) + if instance_rate is not None: + instance_rate_metric = JumpStartBenchmarkStat(instance_rate) - if benchmark_metrics is None: - benchmark_metrics = [instance_rate_metric] - else: - benchmark_metrics.append(instance_rate_metric) + if benchmark_metrics is None: + benchmark_metrics = [instance_rate_metric] + else: + benchmark_metrics.append(instance_rate_metric) init_kwargs = get_init_kwargs( model_id=self.model_id, From 2b188b6ebe75e007f0a2632e2028db792d7260a0 Mon Sep 17 00:00:00 2001 From: Haotian An <33510317+Captainia@users.noreply.github.com> Date: Fri, 26 Apr 2024 17:21:50 -0400 Subject: [PATCH 24/32] chore: require config name and instance type in set_deployment_config (#4625) * require config_name and instance_type in set config * docstring * add supported instance types check * add more tests * format * fix tests --- src/sagemaker/jumpstart/factory/model.py | 22 ++++++- src/sagemaker/jumpstart/model.py | 16 +++-- .../sagemaker/jumpstart/model/test_model.py | 65 ++++++++++++------- 3 files changed, 72 insertions(+), 31 deletions(-) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 68fbf2e861..53508b56f3 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -544,7 +544,11 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel def _add_config_name_to_kwargs(kwargs: 
JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: - """Sets default config name to the kwargs. Returns full kwargs.""" + """Sets default config name to the kwargs. Returns full kwargs. + + Raises: + ValueError: If the instance_type is not supported with the current config. + """ specs = verify_model_region_and_return_specs( model_id=kwargs.model_id, @@ -565,6 +569,22 @@ def _add_config_name_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod kwargs.config_name or specs.inference_configs.get_top_config_from_ranking().config_name ) + if not kwargs.config_name: + return kwargs + + if kwargs.config_name not in set(specs.inference_configs.configs.keys()): + raise ValueError( + f"Config {kwargs.config_name} is not supported for model {kwargs.model_id}." + ) + + resolved_config = specs.inference_configs.configs[kwargs.config_name].resolved_config + supported_instance_types = resolved_config.get("supported_inference_instance_types", []) + if kwargs.instance_type not in supported_instance_types: + raise ValueError( + f"Instance type {kwargs.instance_type} " + f"is not supported for config {kwargs.config_name}." + ) + return kwargs diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 78c36cb954..8b1badb94b 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -429,16 +429,22 @@ def retrieve_example_payload(self) -> JumpStartSerializablePayload: sagemaker_session=self.sagemaker_session, ) - def set_deployment_config(self, config_name: Optional[str]) -> None: + def set_deployment_config(self, config_name: str, instance_type: str) -> None: """Sets the deployment config to apply to the model. Args: - config_name (Optional[str]): - The name of the deployment config. Set to None to unset - any existing config that is applied to the model. + config_name (str): + The name of the deployment config to apply to the model. + Call list_deployment_configs to see the list of config names. + instance_type (str): + The instance_type that the model will use after setting + the config. 
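+
+            Example (an illustrative sketch; assumes the model exposes a
+            "neuron-inference" config that supports ml.inf2.xlarge):
+                model.set_deployment_config("neuron-inference", "ml.inf2.xlarge")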
""" self.__init__( - model_id=self.model_id, model_version=self.model_version, config_name=config_name + model_id=self.model_id, + model_version=self.model_version, + instance_type=instance_type, + config_name=config_name, ) @property diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 6d859aecdb..5bbc31a5b1 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -1614,7 +1614,25 @@ def test_model_set_deployment_config( mock_get_model_specs.reset_mock() mock_model_deploy.reset_mock() mock_get_model_specs.side_effect = get_prototype_spec_with_configs - model.set_deployment_config("neuron-inference") + model.set_deployment_config("neuron-inference", "ml.inf2.2xlarge") + + assert model.config_name == "neuron-inference" + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.inf2.2xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-inference"}, + ], + wait=True, + endpoint_logging=False, + ) + mock_model_deploy.reset_mock() + model.set_deployment_config("neuron-inference", "ml.inf2.xlarge") assert model.config_name == "neuron-inference" @@ -1640,7 +1658,7 @@ def test_model_set_deployment_config( @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) - def test_model_unset_deployment_config( + def test_model_set_deployment_config_incompatible_instance_type_or_name( self, mock_model_deploy: mock.Mock, mock_get_model_specs: mock.Mock, @@ -1648,7 +1666,7 @@ def test_model_unset_deployment_config( mock_get_manifest: mock.Mock, mock_get_jumpstart_configs: mock.Mock, ): - mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_model_specs.side_effect = get_prototype_model_spec mock_get_manifest.side_effect = ( lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) @@ -1658,19 +1676,18 @@ def test_model_unset_deployment_config( mock_session.return_value = sagemaker_session - model = JumpStartModel(model_id=model_id, config_name="neuron-inference") + model = JumpStartModel(model_id=model_id) - assert model.config_name == "neuron-inference" + assert model.config_name is None model.deploy() mock_model_deploy.assert_called_once_with( initial_instance_count=1, - instance_type="ml.inf2.xlarge", + instance_type="ml.p2.xlarge", tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, - {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-inference"}, ], wait=True, endpoint_logging=False, @@ -1678,20 +1695,21 @@ def test_model_unset_deployment_config( mock_get_model_specs.reset_mock() mock_model_deploy.reset_mock() - mock_get_model_specs.side_effect = get_prototype_model_spec - model.set_deployment_config(None) - - model.deploy() + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + with pytest.raises(ValueError) as error: + model.set_deployment_config("neuron-inference", "ml.inf2.32xlarge") + assert ( + "Instance type ml.inf2.32xlarge is not supported for config neuron-inference." 
+ in str(error) + ) - mock_model_deploy.assert_called_once_with( - initial_instance_count=1, - instance_type="ml.p2.xlarge", - tags=[ - {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, - {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, - ], - wait=True, - endpoint_logging=False, + with pytest.raises(ValueError) as error: + model.set_deployment_config("neuron-inference-unknown-name", "ml.inf2.32xlarge") + assert ( + "Cannot find Jumpstart config name neuron-inference-unknown-name. " + "List of config names that is supported by the model: " + "['neuron-inference', 'neuron-inference-budget', 'gpu-inference-budget', 'gpu-inference']" + in str(error) ) @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") @@ -1813,6 +1831,7 @@ def test_model_retrieve_deployment_config( expected = get_base_deployment_configs()[0] config_name = expected.get("DeploymentConfigName") + instance_type = expected.get("InstanceType") mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs( model_id, config_name ) @@ -1821,17 +1840,13 @@ def test_model_retrieve_deployment_config( model = JumpStartModel(model_id=model_id) - model.set_deployment_config(config_name) + model.set_deployment_config(config_name, instance_type) self.assertEqual(model.deployment_config, expected) mock_get_init_kwargs.reset_mock() mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id) - # Unset - model.set_deployment_config(None) - self.assertIsNone(model.deployment_config) - @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @mock.patch("sagemaker.jumpstart.model.get_instance_rate_per_hour") From db6967208a0f0b8e2048a6dfe17296a87395f516 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga <54963715+makungaj1@users.noreply.github.com> Date: Sun, 28 Apr 2024 14:14:29 -0700 Subject: [PATCH 25/32] Deployment Configs - Follow-ups (#4626) * Init Deployment configs outside Model init. 
* Testing with NB * Testing with NB-V2 * Refactoring, NB testing * NB Testing and Refactoring * Testing * Refactoring * Testing with NB * Debug * Debug display API * Debug with NB * Testing with NB * Refactoring * Refactoring * Refactoring and NB testing * Testing with NB * Refactoring * Prefix instance type with ml * Fix unit tests --------- Co-authored-by: Jonathan Makunga --- src/sagemaker/jumpstart/model.py | 160 ++++++++++-------- src/sagemaker/jumpstart/types.py | 72 +++++--- src/sagemaker/jumpstart/utils.py | 142 +++++++++++----- .../serve/builder/jumpstart_builder.py | 13 +- src/sagemaker/utils.py | 56 +++--- .../sagemaker/jumpstart/model/test_model.py | 76 +++------ tests/unit/sagemaker/jumpstart/test_types.py | 39 +++++ tests/unit/sagemaker/jumpstart/test_utils.py | 143 +++++++++++----- tests/unit/sagemaker/jumpstart/utils.py | 64 ++++++- .../serve/builder/test_js_builder.py | 8 +- tests/unit/test_utils.py | 116 ++++++++++--- 11 files changed, 582 insertions(+), 307 deletions(-) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 8b1badb94b..619af2f7a9 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -41,18 +41,17 @@ from sagemaker.jumpstart.types import ( JumpStartSerializablePayload, DeploymentConfigMetadata, - JumpStartBenchmarkStat, - JumpStartMetadataConfig, ) from sagemaker.jumpstart.utils import ( validate_model_id_and_get_type, verify_model_region_and_return_specs, get_jumpstart_configs, get_metrics_from_deployment_configs, + add_instance_rate_stats_to_benchmark_metrics, ) from sagemaker.jumpstart.constants import JUMPSTART_LOGGER from sagemaker.jumpstart.enums import JumpStartModelType -from sagemaker.utils import stringify_object, format_tags, Tags, get_instance_rate_per_hour +from sagemaker.utils import stringify_object, format_tags, Tags from sagemaker.model import ( Model, ModelPackage, @@ -361,17 +360,13 @@ def _validate_model_id_and_type(): self.model_package_arn = model_init_kwargs.model_package_arn self.init_kwargs = model_init_kwargs.to_kwargs_dict(False) - metadata_configs = get_jumpstart_configs( + self._metadata_configs = get_jumpstart_configs( region=self.region, model_id=self.model_id, model_version=self.model_version, sagemaker_session=self.sagemaker_session, model_type=self.model_type, ) - self._deployment_configs = [ - self._convert_to_deployment_config_metadata(config_name, config) - for config_name, config in metadata_configs.items() - ] def log_subscription_warning(self) -> None: """Log message prompting the customer to subscribe to the proprietary model.""" @@ -449,25 +444,33 @@ def set_deployment_config(self, config_name: str, instance_type: str) -> None: @property def deployment_config(self) -> Optional[Dict[str, Any]]: - """The deployment config that will be applied to the model. + """The deployment config that will be applied to ``This`` model. Returns: - Optional[Dict[str, Any]]: Deployment config that will be applied to the model. + Optional[Dict[str, Any]]: Deployment config. """ - return self._retrieve_selected_deployment_config(self.config_name) + deployment_config = self._retrieve_selected_deployment_config( + self.config_name, self.instance_type + ) + return deployment_config.to_json() if deployment_config is not None else None @property def benchmark_metrics(self) -> pd.DataFrame: - """Benchmark Metrics for deployment configs + """Benchmark Metrics for deployment configs. Returns: - Metrics: Pandas DataFrame object. + Benchmark Metrics: Pandas DataFrame object. 
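+            The frame is sorted by its first two columns; use
+            ``display_benchmark_metrics()`` to print it in markdown form.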
""" - return pd.DataFrame(self._get_benchmarks_data(self.config_name)) + benchmark_metrics_data = self._get_deployment_configs_benchmarks_data( + self.config_name, self.instance_type + ) + keys = list(benchmark_metrics_data.keys()) + df = pd.DataFrame(benchmark_metrics_data).sort_values(by=[keys[0], keys[1]]) + return df def display_benchmark_metrics(self) -> None: - """Display Benchmark Metrics for deployment configs.""" - print(self.benchmark_metrics.to_markdown()) + """Display deployment configs benchmark metrics.""" + print(self.benchmark_metrics.to_markdown(index=False)) def list_deployment_configs(self) -> List[Dict[str, Any]]: """List deployment configs for ``This`` model. @@ -475,7 +478,12 @@ def list_deployment_configs(self) -> List[Dict[str, Any]]: Returns: List[Dict[str, Any]]: A list of deployment configs. """ - return self._deployment_configs + return [ + deployment_config.to_json() + for deployment_config in self._get_deployment_configs( + self.config_name, self.instance_type + ) + ] def _create_sagemaker_model( self, @@ -866,92 +874,94 @@ def register_deploy_wrapper(*args, **kwargs): return model_package @lru_cache - def _get_benchmarks_data(self, config_name: str) -> Dict[str, List[str]]: + def _get_deployment_configs_benchmarks_data( + self, config_name: str, instance_type: str + ) -> Dict[str, Any]: """Deployment configs benchmark metrics. Args: - config_name (str): The name of the selected deployment config. + config_name (str): Name of selected deployment config. + instance_type (str): The selected Instance type. Returns: Dict[str, List[str]]: Deployment config benchmark data. """ return get_metrics_from_deployment_configs( - self._deployment_configs, - config_name, + self._get_deployment_configs(config_name, instance_type) ) @lru_cache - def _retrieve_selected_deployment_config(self, config_name: str) -> Optional[Dict[str, Any]]: - """Retrieve the deployment config to apply to the model. + def _retrieve_selected_deployment_config( + self, config_name: str, instance_type: str + ) -> Optional[DeploymentConfigMetadata]: + """Retrieve the deployment config to apply to `This` model. Args: config_name (str): The name of the deployment config to retrieve. + instance_type (str): The instance type of the deployment config to retrieve. Returns: Optional[Dict[str, Any]]: The retrieved deployment config. """ if config_name is None: return None - for deployment_config in self._deployment_configs: - if deployment_config.get("DeploymentConfigName") == config_name: + for deployment_config in self._get_deployment_configs(config_name, instance_type): + if deployment_config.deployment_config_name == config_name: return deployment_config return None - def _convert_to_deployment_config_metadata( - self, config_name: str, metadata_config: JumpStartMetadataConfig - ) -> Dict[str, Any]: - """Retrieve deployment config for config name. + @lru_cache + def _get_deployment_configs( + self, selected_config_name: str, selected_instance_type: str + ) -> List[DeploymentConfigMetadata]: + """Retrieve deployment configs metadata. Args: - config_name (str): Name of deployment config. - metadata_config (JumpStartMetadataConfig): Metadata config for deployment config. - Returns: - A deployment metadata config for config name (dict[str, Any]). + selected_config_name (str): The name of the selected deployment config. + selected_instance_type (str): The selected instance type. 
""" - default_inference_instance_type = metadata_config.resolved_config.get( - "default_inference_instance_type" - ) - - benchmark_metrics = ( - metadata_config.benchmark_metrics.get(default_inference_instance_type) - if metadata_config.benchmark_metrics is not None - else None - ) - - should_fetch_instance_rate_metric = True - if benchmark_metrics is not None: - for benchmark_metric in benchmark_metrics: - if benchmark_metric.name.lower() == "instance rate": - should_fetch_instance_rate_metric = False - break - - if should_fetch_instance_rate_metric: - instance_rate = get_instance_rate_per_hour( - instance_type=default_inference_instance_type, region=self.region + deployment_configs = [] + if self._metadata_configs is None: + return deployment_configs + + err = None + for config_name, metadata_config in self._metadata_configs.items(): + if err is None or "is not authorized to perform: pricing:GetProducts" not in err: + err, metadata_config.benchmark_metrics = ( + add_instance_rate_stats_to_benchmark_metrics( + self.region, metadata_config.benchmark_metrics + ) + ) + + resolved_config = metadata_config.resolved_config + if selected_config_name == config_name: + instance_type_to_use = selected_instance_type + else: + instance_type_to_use = resolved_config.get("default_inference_instance_type") + + init_kwargs = get_init_kwargs( + model_id=self.model_id, + instance_type=instance_type_to_use, + sagemaker_session=self.sagemaker_session, ) - if instance_rate is not None: - instance_rate_metric = JumpStartBenchmarkStat(instance_rate) - - if benchmark_metrics is None: - benchmark_metrics = [instance_rate_metric] - else: - benchmark_metrics.append(instance_rate_metric) - - init_kwargs = get_init_kwargs( - model_id=self.model_id, - instance_type=default_inference_instance_type, - sagemaker_session=self.sagemaker_session, - ) - deploy_kwargs = get_deploy_kwargs( - model_id=self.model_id, - instance_type=default_inference_instance_type, - sagemaker_session=self.sagemaker_session, - ) + deploy_kwargs = get_deploy_kwargs( + model_id=self.model_id, + instance_type=instance_type_to_use, + sagemaker_session=self.sagemaker_session, + ) + deployment_config_metadata = DeploymentConfigMetadata( + config_name, + metadata_config.benchmark_metrics, + resolved_config, + init_kwargs, + deploy_kwargs, + ) + deployment_configs.append(deployment_config_metadata) - deployment_config_metadata = DeploymentConfigMetadata( - config_name, benchmark_metrics, init_kwargs, deploy_kwargs - ) + if err is not None and "is not authorized to perform: pricing:GetProducts" in err: + error_message = "Instance rate metrics will be omitted. 
Reason: %s" + JUMPSTART_LOGGER.warning(error_message, err) - return deployment_config_metadata.to_json() + return deployment_configs def __str__(self) -> str: """Overriding str(*) method to make more human-readable.""" diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index cd74a03e5a..e0a0f9bea7 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -2235,29 +2235,37 @@ def to_json(self) -> Dict[str, Any]: if hasattr(self, att): cur_val = getattr(self, att) att = self._convert_to_pascal_case(att) - if issubclass(type(cur_val), JumpStartDataHolderType): - json_obj[att] = cur_val.to_json() - elif isinstance(cur_val, list): - json_obj[att] = [] - for obj in cur_val: - if issubclass(type(obj), JumpStartDataHolderType): - json_obj[att].append(obj.to_json()) - else: - json_obj[att].append(obj) - elif isinstance(cur_val, dict): - json_obj[att] = {} - for key, val in cur_val.items(): - if issubclass(type(val), JumpStartDataHolderType): - json_obj[att][self._convert_to_pascal_case(key)] = val.to_json() - else: - json_obj[att][key] = val - else: - json_obj[att] = cur_val + json_obj[att] = self._val_to_json(cur_val) return json_obj + def _val_to_json(self, val: Any) -> Any: + """Converts the given value to JSON. + + Args: + val (Any): The value to convert. + Returns: + Any: The converted json value. + """ + if issubclass(type(val), JumpStartDataHolderType): + return val.to_json() + if isinstance(val, list): + list_obj = [] + for obj in val: + list_obj.append(self._val_to_json(obj)) + return list_obj + if isinstance(val, dict): + dict_obj = {} + for k, v in val.items(): + if isinstance(v, JumpStartDataHolderType): + dict_obj[self._convert_to_pascal_case(k)] = self._val_to_json(v) + else: + dict_obj[k] = self._val_to_json(v) + return dict_obj + return val + class DeploymentArgs(BaseDeploymentConfigDataHolder): - """Dataclass representing a Deployment Config.""" + """Dataclass representing a Deployment Args.""" __slots__ = [ "image_uri", @@ -2270,9 +2278,12 @@ class DeploymentArgs(BaseDeploymentConfigDataHolder): ] def __init__( - self, init_kwargs: JumpStartModelInitKwargs, deploy_kwargs: JumpStartModelDeployKwargs + self, + init_kwargs: Optional[JumpStartModelInitKwargs] = None, + deploy_kwargs: Optional[JumpStartModelDeployKwargs] = None, + resolved_config: Optional[Dict[str, Any]] = None, ): - """Instantiates DeploymentConfig object.""" + """Instantiates DeploymentArgs object.""" if init_kwargs is not None: self.image_uri = init_kwargs.image_uri self.model_data = init_kwargs.model_data @@ -2287,6 +2298,11 @@ def __init__( self.container_startup_health_check_timeout = ( deploy_kwargs.container_startup_health_check_timeout ) + if resolved_config is not None: + self.default_instance_type = resolved_config.get("default_inference_instance_type") + self.supported_instance_types = resolved_config.get( + "supported_inference_instance_types" + ) class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder): @@ -2301,13 +2317,15 @@ class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder): def __init__( self, - config_name: str, - benchmark_metrics: List[JumpStartBenchmarkStat], - init_kwargs: JumpStartModelInitKwargs, - deploy_kwargs: JumpStartModelDeployKwargs, + config_name: Optional[str] = None, + benchmark_metrics: Optional[Dict[str, List[JumpStartBenchmarkStat]]] = None, + resolved_config: Optional[Dict[str, Any]] = None, + init_kwargs: Optional[JumpStartModelInitKwargs] = None, + deploy_kwargs: Optional[JumpStartModelDeployKwargs] = 
None, ): """Instantiates DeploymentConfigMetadata object.""" self.deployment_config_name = config_name - self.deployment_args = DeploymentArgs(init_kwargs, deploy_kwargs) - self.acceleration_configs = None + self.deployment_args = DeploymentArgs(init_kwargs, deploy_kwargs, resolved_config) self.benchmark_metrics = benchmark_metrics + if resolved_config is not None: + self.acceleration_configs = resolved_config.get("acceleration_configs") diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 357bdb6eb7..a8c4bd7c21 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -17,6 +17,7 @@ from typing import Any, Dict, List, Set, Optional, Tuple, Union from urllib.parse import urlparse import boto3 +from botocore.exceptions import ClientError from packaging.version import Version import sagemaker from sagemaker.config.config_schema import ( @@ -41,10 +42,11 @@ JumpStartModelHeader, JumpStartModelSpecs, JumpStartVersionedModelId, + DeploymentConfigMetadata, ) from sagemaker.session import Session from sagemaker.config import load_sagemaker_config -from sagemaker.utils import resolve_value_from_config, TagsDict +from sagemaker.utils import resolve_value_from_config, TagsDict, get_instance_rate_per_hour from sagemaker.workflow import is_pipeline_variable @@ -1030,60 +1032,110 @@ def get_jumpstart_configs( ) -def get_metrics_from_deployment_configs( - deployment_configs: List[Dict[str, Any]], config_name: str -) -> Dict[str, List[str]]: - """Extracts metrics from deployment configs. +def add_instance_rate_stats_to_benchmark_metrics( + region: str, + benchmark_metrics: Optional[Dict[str, List[JumpStartBenchmarkStat]]], +) -> Optional[Tuple[str, Dict[str, List[JumpStartBenchmarkStat]]]]: + """Adds instance rate stats to the given benchmark_metrics dict. Args: - deployment_configs (list[dict[str, Any]]): List of deployment configs. - config_name (str): The name of the deployment config use by the model. + region (str): AWS region. + benchmark_metrics (Dict[str, List[JumpStartBenchmarkStat]]): Benchmark metrics, keyed by instance type. + Returns: + Tuple[str, Dict[str, List[JumpStartBenchmarkStat]]]: + A tuple of the error message (None on success) and the updated metrics dict.
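+ Example (an illustrative success-case return value; stat contents vary by model): + (None, {"ml.g5.2xlarge": [latency_stat, instance_rate_stat]}) + On a pricing ClientError, the error message is returned in place of None and the metrics are passed through unchanged.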
""" - data = {"Config Name": [], "Instance Type": [], "Selected": [], "Accelerated": []} + if benchmark_metrics is None: + return None + + final_benchmark_metrics = {} - for index, deployment_config in enumerate(deployment_configs): - if deployment_config.get("DeploymentArgs") is None: - continue + err_message = None + for instance_type, benchmark_metric_stats in benchmark_metrics.items(): + instance_type = instance_type if instance_type.startswith("ml.") else f"ml.{instance_type}" - benchmark_metrics = deployment_config.get("BenchmarkMetrics") - if benchmark_metrics is not None: - data["Config Name"].append(deployment_config.get("DeploymentConfigName")) - data["Instance Type"].append( - deployment_config.get("DeploymentArgs").get("InstanceType") - ) - data["Selected"].append( - "Yes" - if ( - config_name is not None - and config_name == deployment_config.get("DeploymentConfigName") + if not has_instance_rate_stat(benchmark_metric_stats): + try: + instance_type_rate = get_instance_rate_per_hour( + instance_type=instance_type, region=region ) - else "No" - ) - accelerated_configs = deployment_config.get("AccelerationConfigs") - if accelerated_configs is None: - data["Accelerated"].append("No") - else: - data["Accelerated"].append( - "Yes" - if ( - len(accelerated_configs) > 0 - and accelerated_configs[0].get("Enabled", False) - ) - else "No" + benchmark_metric_stats.append(JumpStartBenchmarkStat(instance_type_rate)) + final_benchmark_metrics[instance_type] = benchmark_metric_stats + + except ClientError as e: + final_benchmark_metrics[instance_type] = benchmark_metric_stats + err_message = e.response["Error"]["Message"] + except Exception: # pylint: disable=W0703 + final_benchmark_metrics[instance_type] = benchmark_metric_stats + err_message = ( + f"Unable to get instance rate per hour for instance type: {instance_type}." ) - if index == 0: - for benchmark_metric in benchmark_metrics: - column_name = f"{benchmark_metric.get('name')} ({benchmark_metric.get('unit')})" - data[column_name] = [] + return err_message, final_benchmark_metrics + + +def has_instance_rate_stat(benchmark_metric_stats: Optional[List[JumpStartBenchmarkStat]]) -> bool: + """Determines whether a benchmark metric stats contains instance rate metric stat. + + Args: + benchmark_metric_stats (Optional[List[JumpStartBenchmarkStat]]): + List of benchmark metric stats. + Returns: + bool: Whether the benchmark metric stats contains instance rate metric stat. + """ + if benchmark_metric_stats is None: + return False + + for benchmark_metric_stat in benchmark_metric_stats: + if benchmark_metric_stat.name.lower() == "instance rate": + return True + + return False + - for benchmark_metric in benchmark_metrics: - column_name = f"{benchmark_metric.get('name')} ({benchmark_metric.get('unit')})" - if column_name in data.keys(): - data[column_name].append(benchmark_metric.get("value")) +def get_metrics_from_deployment_configs( + deployment_configs: List[DeploymentConfigMetadata], +) -> Dict[str, List[str]]: + """Extracts benchmark metrics from deployment configs metadata. - if "Yes" not in data["Accelerated"]: - del data["Accelerated"] + Args: + deployment_configs (List[DeploymentConfigMetadata]): List of deployment configs metadata. 
+ """ + data = {"Config Name": [], "Instance Type": []} + + for outer_index, deployment_config in enumerate(deployment_configs): + if deployment_config.deployment_args is None: + continue + + benchmark_metrics = deployment_config.benchmark_metrics + if benchmark_metrics is None: + continue + + for inner_index, current_instance_type in enumerate(benchmark_metrics): + current_instance_type_metrics = benchmark_metrics[current_instance_type] + + data["Config Name"].append(deployment_config.deployment_config_name) + instance_type_to_display = ( + f"{current_instance_type} (Default)" + if current_instance_type == deployment_config.deployment_args.default_instance_type + else current_instance_type + ) + data["Instance Type"].append(instance_type_to_display) + + if outer_index == 0 and inner_index == 0: + temp_data = {} + for metric in current_instance_type_metrics: + column_name = f"{metric.name.replace('_', ' ').title()} ({metric.unit})" + if metric.name.lower() == "instance rate": + data[column_name] = [] + else: + temp_data[column_name] = [] + data = {**data, **temp_data} + + for metric in current_instance_type_metrics: + column_name = f"{metric.name.replace('_', ' ').title()} ({metric.unit})" + if column_name in data: + data[column_name].append(metric.value) return data diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index d3c2581885..ec987dd9fe 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -431,18 +431,21 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800): sharded_supported=sharded_supported, max_tuning_duration=max_tuning_duration ) - def set_deployment_config(self, config_name: Optional[str]) -> None: + def set_deployment_config(self, config_name: str, instance_type: str) -> None: """Sets the deployment config to apply to the model. Args: - config_name (Optional[str]): - The name of the deployment config. Set to None to unset - any existing config that is applied to the model. + config_name (str): + The name of the deployment config to apply to the model. + Call list_deployment_configs to see the list of config names. + instance_type (str): + The instance_type that the model will use after setting + the config. """ if not hasattr(self, "pysdk_model") or self.pysdk_model is None: raise Exception("Cannot set deployment config to an uninitialized model.") - self.pysdk_model.set_deployment_config(config_name) + self.pysdk_model.set_deployment_config(config_name, instance_type) def get_deployment_config(self) -> Optional[Dict[str, Any]]: """Gets the deployment config to apply to the model. diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 35f60b37e1..6c9e1b4b16 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -1664,17 +1664,21 @@ def deep_override_dict( def get_instance_rate_per_hour( instance_type: str, region: str, -) -> Union[Dict[str, str], None]: +) -> Optional[Dict[str, str]]: """Gets instance rate per hour for the given instance type. Args: instance_type (str): The instance type. region (str): The region. Returns: - Union[Dict[str, str], None]: Instance rate per hour. - Example: {'name': 'Instance Rate', 'unit': 'USD/Hrs', 'value': '1.1250000000'}}. - """ + Optional[Dict[str, str]]: Instance rate per hour. + Example: {'name': 'Instance Rate', 'unit': 'USD/Hrs', 'value': '1.125'}. + Raises: + Exception: An exception is raised if + the IAM role is not authorized to perform pricing:GetProducts. 
+ An exception is also raised if an unexpected error occurs. + """ region_name = "us-east-1" if region.startswith("eu") or region.startswith("af"): region_name = "eu-central-1" @@ -1682,35 +1686,34 @@ region_name = "ap-south-1" pricing_client: boto3.client = boto3.client("pricing", region_name=region_name) - try: - res = pricing_client.get_products( - ServiceCode="AmazonSageMaker", - Filters=[ - {"Type": "TERM_MATCH", "Field": "instanceName", "Value": instance_type}, - {"Type": "TERM_MATCH", "Field": "locationType", "Value": "AWS Region"}, - {"Type": "TERM_MATCH", "Field": "regionCode", "Value": region}, - ], - ) + res = pricing_client.get_products( + ServiceCode="AmazonSageMaker", + Filters=[ + {"Type": "TERM_MATCH", "Field": "instanceName", "Value": instance_type}, + {"Type": "TERM_MATCH", "Field": "locationType", "Value": "AWS Region"}, + {"Type": "TERM_MATCH", "Field": "regionCode", "Value": region}, + ], + ) - price_list = res.get("PriceList", []) - if len(price_list) > 0: - price_data = price_list[0] - if isinstance(price_data, str): - price_data = json.loads(price_data) + price_list = res.get("PriceList", []) + if len(price_list) > 0: + price_data = price_list[0] + if isinstance(price_data, str): + price_data = json.loads(price_data) - return extract_instance_rate_per_hour(price_data) - except Exception as e: # pylint: disable=W0703 - logging.exception("Error getting instance rate: %s", e) - return None + instance_rate_per_hour = extract_instance_rate_per_hour(price_data) + if instance_rate_per_hour is not None: + return instance_rate_per_hour + raise Exception(f"Unable to get instance rate per hour for instance type: {instance_type}.") -def extract_instance_rate_per_hour(price_data: Dict[str, Any]) -> Union[Dict[str, str], None]: +def extract_instance_rate_per_hour(price_data: Dict[str, Any]) -> Optional[Dict[str, str]]: """Extract instance rate per hour for the given Price JSON data. Args: price_data (Dict[str, Any]): The Price JSON data. Returns: - Union[Dict[str, str], None]: Instance rate per hour. + Optional[Dict[str, str]]: Instance rate per hour.
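+ Example: {'name': 'Instance Rate', 'unit': 'USD/Hrs', 'value': '1.125'}.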
""" if price_data is not None: @@ -1718,9 +1721,12 @@ def extract_instance_rate_per_hour(price_data: Dict[str, Any]) -> Union[Dict[str for dimension in price_dimensions: for price in dimension.get("priceDimensions", {}).values(): for currency in price.get("pricePerUnit", {}).keys(): + value = price.get("pricePerUnit", {}).get(currency) + if value is not None: + value = str(round(float(value), 3)) return { "unit": f"{currency}/{price.get('unit', 'Hrs')}", - "value": price.get("pricePerUnit", {}).get(currency), + "value": value, "name": "Instance Rate", } return None diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 5bbc31a5b1..cd11d950d5 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -15,6 +15,7 @@ from typing import Optional, Set from unittest import mock import unittest + import pandas as pd from mock import MagicMock, Mock import pytest @@ -52,6 +53,7 @@ get_mock_init_kwargs, get_base_deployment_configs, get_base_spec_with_prototype_configs_with_missing_benchmarks, + append_instance_stat_metrics, ) import boto3 @@ -66,7 +68,6 @@ class ModelTest(unittest.TestCase): - mock_session_empty_config = MagicMock(sagemaker_config={}) @mock.patch( @@ -1714,19 +1715,17 @@ def test_model_set_deployment_config_incompatible_instance_type_or_name( @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") - @mock.patch("sagemaker.jumpstart.model.get_instance_rate_per_hour") + @mock.patch("sagemaker.jumpstart.model.add_instance_rate_stats_to_benchmark_metrics") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") - @mock.patch("sagemaker.jumpstart.model.Model.deploy") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_model_list_deployment_configs( self, - mock_model_deploy: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_get_manifest: mock.Mock, - mock_get_instance_rate_per_hour: mock.Mock, + mock_add_instance_rate_stats_to_benchmark_metrics: mock.Mock, mock_verify_model_region_and_return_specs: mock.Mock, mock_get_init_kwargs: mock.Mock, ): @@ -1736,16 +1735,14 @@ def test_model_list_deployment_configs( mock_verify_model_region_and_return_specs.side_effect = ( lambda *args, **kwargs: get_base_spec_with_prototype_configs() ) - mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: { - "name": "Instance Rate", - "unit": "USD/Hrs", - "value": "0.0083000000", - } + mock_add_instance_rate_stats_to_benchmark_metrics.side_effect = lambda region, metrics: ( + None, + append_instance_stat_metrics(metrics), + ) mock_get_model_specs.side_effect = get_prototype_spec_with_configs mock_get_manifest.side_effect = ( lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) - mock_model_deploy.return_value = default_predictor mock_session.return_value = sagemaker_session @@ -1756,19 +1753,15 @@ def test_model_list_deployment_configs( self.assertEqual(configs, get_base_deployment_configs()) @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") - @mock.patch("sagemaker.jumpstart.model.get_instance_rate_per_hour") 
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") - @mock.patch("sagemaker.jumpstart.model.Model.deploy") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_model_list_deployment_configs_empty( self, - mock_model_deploy: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_get_manifest: mock.Mock, - mock_get_instance_rate_per_hour: mock.Mock, mock_verify_model_region_and_return_specs: mock.Mock, ): model_id, _ = "pytorch-eqa-bert-base-cased", "*" @@ -1776,16 +1769,10 @@ def test_model_list_deployment_configs_empty( mock_verify_model_region_and_return_specs.side_effect = ( lambda *args, **kwargs: get_special_model_spec(model_id="gemma-model") ) - mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: { - "name": "Instance Rate", - "unit": "USD/Hrs", - "value": "0.0083000000", - } mock_get_model_specs.side_effect = get_prototype_spec_with_configs mock_get_manifest.side_effect = ( lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) - mock_model_deploy.return_value = default_predictor mock_session.return_value = sagemaker_session @@ -1797,7 +1784,7 @@ def test_model_list_deployment_configs_empty( @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") - @mock.patch("sagemaker.jumpstart.model.get_instance_rate_per_hour") + @mock.patch("sagemaker.jumpstart.model.add_instance_rate_stats_to_benchmark_metrics") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1809,7 +1796,7 @@ def test_model_retrieve_deployment_config( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_get_manifest: mock.Mock, - mock_get_instance_rate_per_hour: mock.Mock, + mock_add_instance_rate_stats_to_benchmark_metrics: mock.Mock, mock_verify_model_region_and_return_specs: mock.Mock, mock_get_init_kwargs: mock.Mock, ): @@ -1818,18 +1805,17 @@ def test_model_retrieve_deployment_config( mock_verify_model_region_and_return_specs.side_effect = ( lambda *args, **kwargs: get_base_spec_with_prototype_configs_with_missing_benchmarks() ) - mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: { - "name": "Instance Rate", - "unit": "USD/Hrs", - "value": "0.0083000000", - } + mock_add_instance_rate_stats_to_benchmark_metrics.side_effect = lambda region, metrics: ( + None, + append_instance_stat_metrics(metrics), + ) mock_get_model_specs.side_effect = get_prototype_spec_with_configs mock_get_manifest.side_effect = ( lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) mock_model_deploy.return_value = default_predictor - expected = get_base_deployment_configs()[0] + expected = get_base_deployment_configs(True)[0] config_name = expected.get("DeploymentConfigName") instance_type = expected.get("InstanceType") mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs( @@ -1849,19 +1835,17 @@ def test_model_retrieve_deployment_config( @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") - 
@mock.patch("sagemaker.jumpstart.model.get_instance_rate_per_hour") + @mock.patch("sagemaker.jumpstart.model.add_instance_rate_stats_to_benchmark_metrics") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") - @mock.patch("sagemaker.jumpstart.model.Model.deploy") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_model_display_benchmark_metrics( self, - mock_model_deploy: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_get_manifest: mock.Mock, - mock_get_instance_rate_per_hour: mock.Mock, + mock_add_instance_rate_stats_to_benchmark_metrics: mock.Mock, mock_verify_model_region_and_return_specs: mock.Mock, mock_get_init_kwargs: mock.Mock, ): @@ -1871,16 +1855,14 @@ def test_model_display_benchmark_metrics( mock_verify_model_region_and_return_specs.side_effect = ( lambda *args, **kwargs: get_base_spec_with_prototype_configs_with_missing_benchmarks() ) - mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: { - "name": "Instance Rate", - "unit": "USD/Hrs", - "value": "0.0083000000", - } + mock_add_instance_rate_stats_to_benchmark_metrics.side_effect = lambda region, metrics: ( + None, + append_instance_stat_metrics(metrics), + ) mock_get_model_specs.side_effect = get_prototype_spec_with_configs mock_get_manifest.side_effect = ( lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) - mock_model_deploy.return_value = default_predictor mock_session.return_value = sagemaker_session @@ -1890,19 +1872,17 @@ def test_model_display_benchmark_metrics( @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") - @mock.patch("sagemaker.jumpstart.model.get_instance_rate_per_hour") + @mock.patch("sagemaker.jumpstart.model.add_instance_rate_stats_to_benchmark_metrics") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") - @mock.patch("sagemaker.jumpstart.model.Model.deploy") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_model_benchmark_metrics( self, - mock_model_deploy: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_get_manifest: mock.Mock, - mock_get_instance_rate_per_hour: mock.Mock, + mock_add_instance_rate_stats_to_benchmark_metrics: mock.Mock, mock_verify_model_region_and_return_specs: mock.Mock, mock_get_init_kwargs: mock.Mock, ): @@ -1912,16 +1892,14 @@ def test_model_benchmark_metrics( mock_verify_model_region_and_return_specs.side_effect = ( lambda *args, **kwargs: get_base_spec_with_prototype_configs() ) - mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: { - "name": "Instance Rate", - "unit": "USD/Hrs", - "value": "0.0083000000", - } + mock_add_instance_rate_stats_to_benchmark_metrics.side_effect = lambda region, metrics: ( + None, + append_instance_stat_metrics(metrics), + ) mock_get_model_specs.side_effect = get_prototype_spec_with_configs mock_get_manifest.side_effect = ( lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) - mock_model_deploy.return_value = default_predictor mock_session.return_value = sagemaker_session diff 
--git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 5ca01c3c52..c52bf76f4e 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -22,6 +22,8 @@ JumpStartModelSpecs, JumpStartModelHeader, JumpStartConfigComponent, + DeploymentConfigMetadata, + JumpStartModelInitKwargs, ) from tests.unit.sagemaker.jumpstart.constants import ( BASE_SPEC, @@ -29,6 +31,7 @@ INFERENCE_CONFIGS, TRAINING_CONFIG_RANKINGS, TRAINING_CONFIGS, + INIT_KWARGS, ) INSTANCE_TYPE_VARIANT = JumpStartInstanceTypeVariants( @@ -1248,3 +1251,39 @@ def test_set_training_config(): with pytest.raises(ValueError) as error: specs1.set_config("invalid_name", scope="unknown scope") + + +def test_deployment_config_metadata(): + spec = {**BASE_SPEC, **INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS} + specs = JumpStartModelSpecs(spec) + jumpstart_config = specs.inference_configs.get_top_config_from_ranking() + + deployment_config_metadata = DeploymentConfigMetadata( + jumpstart_config.config_name, + jumpstart_config.benchmark_metrics, + jumpstart_config.resolved_config, + JumpStartModelInitKwargs( + model_id=specs.model_id, + model_data=INIT_KWARGS.get("model_data"), + image_uri=INIT_KWARGS.get("image_uri"), + instance_type=INIT_KWARGS.get("instance_type"), + env=INIT_KWARGS.get("env"), + config_name=jumpstart_config.config_name, + ), + ) + + json_obj = deployment_config_metadata.to_json() + + assert isinstance(json_obj, dict) + assert json_obj["DeploymentConfigName"] == jumpstart_config.config_name + for key in json_obj["BenchmarkMetrics"]: + assert len(json_obj["BenchmarkMetrics"][key]) == len( + jumpstart_config.benchmark_metrics.get(key) + ) + assert json_obj["AccelerationConfigs"] == jumpstart_config.resolved_config.get( + "acceleration_configs" + ) + assert json_obj["DeploymentArgs"]["ImageUri"] == INIT_KWARGS.get("image_uri") + assert json_obj["DeploymentArgs"]["ModelData"] == INIT_KWARGS.get("model_data") + assert json_obj["DeploymentArgs"]["Environment"] == INIT_KWARGS.get("env") + assert json_obj["DeploymentArgs"]["InstanceType"] == INIT_KWARGS.get("instance_type") diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index f576e36185..f7458a29e9 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -13,6 +13,8 @@ from __future__ import absolute_import import os from unittest import TestCase + +from botocore.exceptions import ClientError from mock.mock import Mock, patch import pytest import boto3 @@ -49,8 +51,7 @@ get_spec_from_base_spec, get_special_model_spec, get_prototype_manifest, - get_base_deployment_configs, - get_base_deployment_configs_with_acceleration_configs, + get_base_deployment_configs_metadata, ) from mock import MagicMock @@ -1763,53 +1764,103 @@ def test_get_jumpstart_benchmark_stats_training( } +def test_extract_metrics_from_deployment_configs(): + configs = get_base_deployment_configs_metadata() + configs[0].benchmark_metrics = None + configs[2].deployment_args = None + + data = utils.get_metrics_from_deployment_configs(configs) + + for key in data: + assert len(data[key]) == (len(configs) - 2) + + +@patch("sagemaker.jumpstart.utils.get_instance_rate_per_hour") +def test_add_instance_rate_stats_to_benchmark_metrics( + mock_get_instance_rate_per_hour, +): + mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: { + "name": "Instance Rate", + "unit": "USD/Hrs", + "value": 
"3.76", + } + + err, out = utils.add_instance_rate_stats_to_benchmark_metrics( + "us-west-2", + { + "ml.p2.xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ], + "ml.gd4.xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ], + }, + ) + + assert err is None + for key in out: + assert len(out[key]) == 2 + for metric in out[key]: + if metric.name == "Instance Rate": + assert metric.to_json() == { + "name": "Instance Rate", + "unit": "USD/Hrs", + "value": "3.76", + } + + +@patch("sagemaker.jumpstart.utils.get_instance_rate_per_hour") +def test_add_instance_rate_stats_to_benchmark_metrics_client_ex( + mock_get_instance_rate_per_hour, +): + mock_get_instance_rate_per_hour.side_effect = ClientError( + {"Error": {"Message": "is not authorized to perform: pricing:GetProducts"}}, "GetProducts" + ) + + err, out = utils.add_instance_rate_stats_to_benchmark_metrics( + "us-west-2", + { + "ml.p2.xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ], + }, + ) + + assert err == "is not authorized to perform: pricing:GetProducts" + for key in out: + assert len(out[key]) == 1 + + +@patch("sagemaker.jumpstart.utils.get_instance_rate_per_hour") +def test_add_instance_rate_stats_to_benchmark_metrics_ex( + mock_get_instance_rate_per_hour, +): + mock_get_instance_rate_per_hour.side_effect = Exception() + + err, out = utils.add_instance_rate_stats_to_benchmark_metrics( + "us-west-2", + { + "ml.p2.xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ], + }, + ) + + assert err == "Unable to get instance rate per hour for instance type: ml.p2.xlarge." + for key in out: + assert len(out[key]) == 1 + + @pytest.mark.parametrize( - "config_name, configs, expected", + "stats, expected", [ + (None, False), ( - None, - get_base_deployment_configs(), - { - "Config Name": [ - "neuron-inference", - "neuron-inference-budget", - "gpu-inference-budget", - "gpu-inference", - ], - "Instance Type": ["ml.p2.xlarge", "ml.p2.xlarge", "ml.p2.xlarge", "ml.p2.xlarge"], - "Selected": ["No", "No", "No", "No"], - "Instance Rate (USD/Hrs)": [ - "0.0083000000", - "0.0083000000", - "0.0083000000", - "0.0083000000", - ], - }, - ), - ( - "neuron-inference", - get_base_deployment_configs_with_acceleration_configs(), - { - "Config Name": [ - "neuron-inference", - "neuron-inference-budget", - "gpu-inference-budget", - "gpu-inference", - ], - "Instance Type": ["ml.p2.xlarge", "ml.p2.xlarge", "ml.p2.xlarge", "ml.p2.xlarge"], - "Selected": ["Yes", "No", "No", "No"], - "Accelerated": ["Yes", "No", "No", "No"], - "Instance Rate (USD/Hrs)": [ - "0.0083000000", - "0.0083000000", - "0.0083000000", - "0.0083000000", - ], - }, + [JumpStartBenchmarkStat({"name": "Instance Rate", "unit": "USD/Hrs", "value": "3.76"})], + True, ), + ([JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"})], False), ], ) -def test_extract_metrics_from_deployment_configs(config_name, configs, expected): - data = utils.get_metrics_from_deployment_configs(configs, config_name) - - assert data == expected +def test_has_instance_rate_stat(stats, expected): + assert utils.has_instance_rate_stat(stats) is expected diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 8b814c3d71..e8a93dff6c 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -29,6 +29,9 @@ JumpStartS3FileType, JumpStartModelHeader, 
JumpStartModelInitKwargs, + DeploymentConfigMetadata, + JumpStartModelDeployKwargs, + JumpStartBenchmarkStat, ) from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.utils import get_formatted_manifest @@ -323,10 +326,6 @@ def overwrite_dictionary( return base_dictionary -def get_base_deployment_configs() -> List[Dict[str, Any]]: - return DEPLOYMENT_CONFIGS - - def get_base_deployment_configs_with_acceleration_configs() -> List[Dict[str, Any]]: configs = copy.deepcopy(DEPLOYMENT_CONFIGS) configs[0]["AccelerationConfigs"] = [ @@ -348,3 +347,60 @@ def get_mock_init_kwargs( resources=ResourceRequirements(), config_name=config_name, ) + + +def get_base_deployment_configs_metadata( + omit_benchmark_metrics: bool = False, +) -> List[DeploymentConfigMetadata]: + specs = ( + get_base_spec_with_prototype_configs_with_missing_benchmarks() + if omit_benchmark_metrics + else get_base_spec_with_prototype_configs() + ) + configs = [] + for config_name, jumpstart_config in specs.inference_configs.configs.items(): + benchmark_metrics = jumpstart_config.benchmark_metrics + + if benchmark_metrics: + for instance_type in benchmark_metrics: + benchmark_metrics[instance_type].append( + JumpStartBenchmarkStat( + {"name": "Instance Rate", "unit": "USD/Hrs", "value": "3.76"} + ) + ) + + configs.append( + DeploymentConfigMetadata( + config_name=config_name, + benchmark_metrics=jumpstart_config.benchmark_metrics, + resolved_config=jumpstart_config.resolved_config, + init_kwargs=get_mock_init_kwargs( + get_base_spec_with_prototype_configs().model_id, config_name + ), + deploy_kwargs=JumpStartModelDeployKwargs( + model_id=get_base_spec_with_prototype_configs().model_id, + ), + ) + ) + return configs + + +def get_base_deployment_configs( + omit_benchmark_metrics: bool = False, +) -> List[Dict[str, Any]]: + return [ + config.to_json() for config in get_base_deployment_configs_metadata(omit_benchmark_metrics) + ] + + +def append_instance_stat_metrics( + metrics: Dict[str, List[JumpStartBenchmarkStat]] +) -> Dict[str, List[JumpStartBenchmarkStat]]: + if metrics is not None: + for key in metrics: + metrics[key].append( + JumpStartBenchmarkStat( + {"name": "Instance Rate", "value": "3.76", "unit": "USD/Hrs"} + ) + ) + return metrics diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py index b83b113209..56b01cd9e3 100644 --- a/tests/unit/sagemaker/serve/builder/test_js_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py @@ -752,9 +752,11 @@ def test_set_deployment_config( mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri builder.build() - builder.set_deployment_config("config-1") + builder.set_deployment_config("config-1", "ml.g5.24xlarge") - mock_pre_trained_model.return_value.set_deployment_config.assert_called_with("config-1") + mock_pre_trained_model.return_value.set_deployment_config.assert_called_with( + "config-1", "ml.g5.24xlarge" + ) @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) @patch( @@ -789,7 +791,7 @@ def test_set_deployment_config_ex( "Cannot set deployment config to an uninitialized model.", lambda: ModelBuilder( model="facebook/galactica-mock-model-id", schema_builder=mock_schema_builder - ).set_deployment_config("config-2"), + ).set_deployment_config("config-2", "ml.g5.24xlarge"), ) @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) diff --git a/tests/unit/test_utils.py 
b/tests/unit/test_utils.py index bf6a7cb09f..e94f3087ad 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1871,43 +1871,103 @@ def test_deep_override_skip_keys(self): @pytest.mark.parametrize( - "instance, region", + "instance, region, amazon_sagemaker_price_result, expected", [ - ("t4g.nano", "us-west-2"), - ("t4g.nano", "eu-central-1"), - ("t4g.nano", "af-south-1"), - ("t4g.nano", "ap-northeast-2"), - ("t4g.nano", "cn-north-1"), + ( + "ml.t4g.nano", + "us-west-2", + { + "PriceList": [ + { + "terms": { + "OnDemand": { + "3WK7G7WSYVS3K492.JRTCKXETXF": { + "priceDimensions": { + "3WK7G7WSYVS3K492.JRTCKXETXF.6YS6EN2CT7": { + "unit": "Hrs", + "endRange": "Inf", + "description": "$0.9 per Unused Reservation Linux p2.xlarge Instance Hour", + "appliesTo": [], + "rateCode": "3WK7G7WSYVS3K492.JRTCKXETXF.6YS6EN2CT7", + "beginRange": "0", + "pricePerUnit": {"USD": "0.9000000000"}, + } + } + } + } + }, + } + ] + }, + {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.9"}, + ), + ( + "ml.t4g.nano", + "eu-central-1", + { + "PriceList": [ + '{"terms": {"OnDemand": {"22VNQ3N6GZGZMXYM.JRTCKXETXF": {"priceDimensions":{' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7": {"unit": "Hrs", "endRange": "Inf", "description": ' + '"$0.0083 per' + "On" + 'Demand Ubuntu Pro t4g.nano Instance Hour", "appliesTo": [], "rateCode": ' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7", "beginRange": "0", "pricePerUnit":{"USD": ' + '"0.0083000000"}}},' + '"sku": "22VNQ3N6GZGZMXYM", "effectiveDate": "2024-04-01T00:00:00Z", "offerTermCode": "JRTCKXETXF",' + '"termAttributes": {}}}}}' + ] + }, + {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.008"}, + ), + ( + "ml.t4g.nano", + "af-south-1", + { + "PriceList": [ + '{"terms": {"OnDemand": {"22VNQ3N6GZGZMXYM.JRTCKXETXF": {"priceDimensions":{' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7": {"unit": "Hrs", "endRange": "Inf", "description": ' + '"$0.0083 per' + "On" + 'Demand Ubuntu Pro t4g.nano Instance Hour", "appliesTo": [], "rateCode": ' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7", "beginRange": "0", "pricePerUnit":{"USD": ' + '"0.0083000000"}}},' + '"sku": "22VNQ3N6GZGZMXYM", "effectiveDate": "2024-04-01T00:00:00Z", "offerTermCode": "JRTCKXETXF",' + '"termAttributes": {}}}}}' + ] + }, + {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.008"}, + ), + ( + "ml.t4g.nano", + "ap-northeast-2", + { + "PriceList": [ + '{"terms": {"OnDemand": {"22VNQ3N6GZGZMXYM.JRTCKXETXF": {"priceDimensions":{' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7": {"unit": "Hrs", "endRange": "Inf", "description": ' + '"$0.0083 per' + "On" + 'Demand Ubuntu Pro t4g.nano Instance Hour", "appliesTo": [], "rateCode": ' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7", "beginRange": "0", "pricePerUnit":{"USD": ' + '"0.0083000000"}}},' + '"sku": "22VNQ3N6GZGZMXYM", "effectiveDate": "2024-04-01T00:00:00Z", "offerTermCode": "JRTCKXETXF",' + '"termAttributes": {}}}}}' + ] + }, + {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.008"}, + ), ], ) @patch("boto3.client") -def test_get_instance_rate_per_hour(mock_client, instance, region): - amazon_sagemaker_price_result = { - "PriceList": [ - '{"terms": {"OnDemand": {"22VNQ3N6GZGZMXYM.JRTCKXETXF": {"priceDimensions":{' - '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7": {"unit": "Hrs", "endRange": "Inf", "description": "$0.0083 per ' - "On" - 'Demand Ubuntu Pro t4g.nano Instance Hour", "appliesTo": [], "rateCode": ' - '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7", "beginRange": "0", "pricePerUnit":{"USD": "0.0083000000"}}}, ' - '"sku": 
"22VNQ3N6GZGZMXYM", "effectiveDate": "2024-04-01T00:00:00Z", "offerTermCode": "JRTCKXETXF", ' - '"termAttributes": {}}}}}' - ] - } +def test_get_instance_rate_per_hour( + mock_client, instance, region, amazon_sagemaker_price_result, expected +): mock_client.return_value.get_products.side_effect = ( lambda *args, **kwargs: amazon_sagemaker_price_result ) instance_rate = get_instance_rate_per_hour(instance_type=instance, region=region) - assert instance_rate == {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.0083000000"} - - -@patch("boto3.client") -def test_get_instance_rate_per_hour_ex(mock_client): - mock_client.return_value.get_products.side_effect = lambda *args, **kwargs: Exception() - instance_rate = get_instance_rate_per_hour(instance_type="ml.t4g.nano", region="us-west-2") - - assert instance_rate is None + assert instance_rate == expected @pytest.mark.parametrize( @@ -1934,7 +1994,7 @@ def test_get_instance_rate_per_hour_ex(mock_client): } } }, - {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.9000000000"}, + {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.9"}, ), ], ) From 7fa391b21cfa046d2d90ca90c2551bfac06811f1 Mon Sep 17 00:00:00 2001 From: Haotian An <33510317+Captainia@users.noreply.github.com> Date: Mon, 29 Apr 2024 16:09:15 -0400 Subject: [PATCH 26/32] fix: use different separator to flatten dict (#4629) --- src/sagemaker/utils.py | 4 ++-- tests/unit/test_utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 6c9e1b4b16..20701edb01 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -1605,7 +1605,7 @@ def can_model_package_source_uri_autopopulate(source_uri: str): ) -def flatten_dict(source_dict: Dict[str, Any], sep: str = ".") -> Dict[str, Any]: +def flatten_dict(source_dict: Dict[str, Any], sep: str = "^") -> Dict[str, Any]: """Flatten a nested dictionary. Args: @@ -1620,7 +1620,7 @@ def flatten_dict(source_dict: Dict[str, Any], sep: str = ".") -> Dict[str, Any]: return {} -def unflatten_dict(source_dict: Dict[str, Any], sep: str = ".") -> Dict[str, Any]: +def unflatten_dict(source_dict: Dict[str, Any], sep: str = "^") -> Dict[str, Any]: """Unflatten a flattened dictionary back into a nested dictionary. 
Args: diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index e94f3087ad..083e2dd09a 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1819,7 +1819,7 @@ def test_can_model_package_source_uri_autopopulate(): class TestDeepMergeDict(TestCase): def test_flatten_dict_basic(self): nested_dict = {"a": 1, "b": {"x": 2, "y": {"p": 3, "q": 4}}, "c": 5} - flattened_dict = {"a": 1, "b.x": 2, "b.y.p": 3, "b.y.q": 4, "c": 5} + flattened_dict = {"a": 1, "b^x": 2, "b^y^p": 3, "b^y^q": 4, "c": 5} self.assertDictEqual(flatten_dict(nested_dict), flattened_dict) self.assertDictEqual(unflatten_dict(flattened_dict), nested_dict) @@ -1837,7 +1837,7 @@ def test_flatten_dict_no_nested(self): def test_flatten_dict_with_various_types(self): nested_dict = {"a": [1, 2, 3], "b": {"x": None, "y": {"p": [], "q": ""}}, "c": 9} - flattened_dict = {"a": [1, 2, 3], "b.x": None, "b.y.p": [], "b.y.q": "", "c": 9} + flattened_dict = {"a": [1, 2, 3], "b^x": None, "b^y^p": [], "b^y^q": "", "c": 9} self.assertDictEqual(flatten_dict(nested_dict), flattened_dict) self.assertDictEqual(unflatten_dict(flattened_dict), nested_dict) From 5af9b6b7c6837f7d81437b35e286213e7cf4e9a6 Mon Sep 17 00:00:00 2001 From: Haotian An <33510317+Captainia@users.noreply.github.com> Date: Tue, 30 Apr 2024 12:42:32 -0400 Subject: [PATCH 27/32] Use separate tags for inference and training configs (#4635) * Use separate tags for inference and training * format * format * format * format --- src/sagemaker/jumpstart/enums.py | 4 +- src/sagemaker/jumpstart/estimator.py | 6 +- src/sagemaker/jumpstart/factory/estimator.py | 5 +- src/sagemaker/jumpstart/factory/model.py | 11 +- src/sagemaker/jumpstart/session_utils.py | 58 ++++-- src/sagemaker/jumpstart/utils.py | 39 ++-- src/sagemaker/predictor.py | 1 + src/sagemaker/utils.py | 100 ++++++---- .../jumpstart/estimator/test_estimator.py | 55 +++++- .../sagemaker/jumpstart/model/test_model.py | 6 +- .../sagemaker/jumpstart/test_predictor.py | 9 +- .../sagemaker/jumpstart/test_session_utils.py | 112 ++++++++---- tests/unit/sagemaker/jumpstart/test_utils.py | 173 ++++++++++++------ tests/unit/test_utils.py | 18 +- 14 files changed, 417 insertions(+), 180 deletions(-) diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index 0c192084ec..9666ce828f 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -92,7 +92,9 @@ class JumpStartTag(str, Enum): MODEL_ID = "sagemaker-sdk:jumpstart-model-id" MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version" MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type" - MODEL_CONFIG_NAME = "sagemaker-sdk:jumpstart-model-config-name" + + INFERENCE_CONFIG_NAME = "sagemaker-sdk:jumpstart-inference-config-name" + TRAINING_CONFIG_NAME = "sagemaker-sdk:jumpstart-training-config-name" class SerializerType(str, Enum): diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index b20a5c1ead..4939be4041 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -733,7 +733,7 @@ def attach( config_name = None if model_id is None: - model_id, model_version, config_name = get_model_info_from_training_job( + model_id, model_version, _, config_name = get_model_info_from_training_job( training_job_name=training_job_name, sagemaker_session=sagemaker_session ) @@ -1139,7 +1139,9 @@ def set_training_config(self, config_name: str) -> None: Args: config_name (str): The name of the config. 
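Example (illustrative; valid config names are model-specific): estimator.set_training_config("neuron-training")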
""" - self.__init__(**self.init_kwargs, config_name=config_name) + self.__init__( + model_id=self.model_id, model_version=self.model_version, config_name=config_name + ) def __str__(self) -> str: """Overriding str(*) method to make more human-readable.""" diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 2d5b29b52f..604b20bc81 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -61,7 +61,7 @@ JumpStartModelInitKwargs, ) from sagemaker.jumpstart.utils import ( - add_jumpstart_model_id_version_tags, + add_jumpstart_model_info_tags, get_eula_message, update_dict_if_key_not_present, resolve_estimator_sagemaker_config_field, @@ -477,11 +477,12 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima ).version if kwargs.sagemaker_session.settings.include_jumpstart_tags: - kwargs.tags = add_jumpstart_model_id_version_tags( + kwargs.tags = add_jumpstart_model_info_tags( kwargs.tags, kwargs.model_id, full_model_version, config_name=kwargs.config_name, + scope=JumpStartScriptScope.TRAINING, ) return kwargs diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 53508b56f3..54301973e8 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -44,7 +44,7 @@ JumpStartModelRegisterKwargs, ) from sagemaker.jumpstart.utils import ( - add_jumpstart_model_id_version_tags, + add_jumpstart_model_info_tags, update_dict_if_key_not_present, resolve_model_sagemaker_config_field, verify_model_region_and_return_specs, @@ -495,8 +495,13 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: ).version if kwargs.sagemaker_session.settings.include_jumpstart_tags: - kwargs.tags = add_jumpstart_model_id_version_tags( - kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type, kwargs.config_name + kwargs.tags = add_jumpstart_model_info_tags( + kwargs.tags, + kwargs.model_id, + full_model_version, + kwargs.model_type, + config_name=kwargs.config_name, + scope=JumpStartScriptScope.INFERENCE, ) return kwargs diff --git a/src/sagemaker/jumpstart/session_utils.py b/src/sagemaker/jumpstart/session_utils.py index 0fa7722f91..7953b67913 100644 --- a/src/sagemaker/jumpstart/session_utils.py +++ b/src/sagemaker/jumpstart/session_utils.py @@ -17,7 +17,7 @@ from typing import Optional, Tuple from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.utils import get_jumpstart_model_id_version_from_resource_arn +from sagemaker.jumpstart.utils import get_jumpstart_model_info_from_resource_arn from sagemaker.session import Session from sagemaker.utils import aws_partition @@ -26,7 +26,7 @@ def get_model_info_from_endpoint( endpoint_name: str, inference_component_name: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Tuple[str, str, Optional[str], Optional[str]]: +) -> Tuple[str, str, Optional[str], Optional[str], Optional[str]]: """Optionally inference component names, return the model ID, version and config name. Infers the model ID and version based on the resource tags. 
Returns a tuple of the model ID @@ -46,7 +46,8 @@ def get_model_info_from_endpoint( ( model_id, model_version, - config_name, + inference_config_name, + training_config_name, ) = _get_model_info_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301 inference_component_name, sagemaker_session ) @@ -55,17 +56,29 @@ def get_model_info_from_endpoint( ( model_id, model_version, - config_name, + inference_config_name, + training_config_name, inference_component_name, ) = _get_model_info_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301 endpoint_name, sagemaker_session ) else: - model_id, model_version, config_name = _get_model_info_from_model_based_endpoint( + ( + model_id, + model_version, + inference_config_name, + training_config_name, + ) = _get_model_info_from_model_based_endpoint( endpoint_name, inference_component_name, sagemaker_session ) - return model_id, model_version, inference_component_name, config_name + return ( + model_id, + model_version, + inference_component_name, + inference_config_name, + training_config_name, + ) def _get_model_info_from_inference_component_endpoint_without_inference_component_name( @@ -125,9 +138,12 @@ def _get_model_info_from_inference_component_endpoint_with_inference_component_n f"inference-component/{inference_component_name}" ) - model_id, model_version, config_name = get_jumpstart_model_id_version_from_resource_arn( - inference_component_arn, sagemaker_session - ) + ( + model_id, + model_version, + inference_config_name, + training_config_name, + ) = get_jumpstart_model_info_from_resource_arn(inference_component_arn, sagemaker_session) if not model_id: raise ValueError( @@ -136,14 +152,14 @@ def _get_model_info_from_inference_component_endpoint_with_inference_component_n "when retrieving default predictor for this inference component." ) - return model_id, model_version, config_name + return model_id, model_version, inference_config_name, training_config_name def _get_model_info_from_model_based_endpoint( endpoint_name: str, inference_component_name: Optional[str], sagemaker_session: Session, -) -> Tuple[str, str, Optional[str]]: +) -> Tuple[str, str, Optional[str], Optional[str]]: """Returns the model ID, version and config name inferred from a model-based endpoint. Raises: @@ -163,9 +179,12 @@ def _get_model_info_from_model_based_endpoint( endpoint_arn = f"arn:{partition}:sagemaker:{region}:{account_id}:endpoint/{endpoint_name}" - model_id, model_version, config_name = get_jumpstart_model_id_version_from_resource_arn( - endpoint_arn, sagemaker_session - ) + ( + model_id, + model_version, + inference_config_name, + training_config_name, + ) = get_jumpstart_model_info_from_resource_arn(endpoint_arn, sagemaker_session) if not model_id: raise ValueError( @@ -174,13 +193,13 @@ def _get_model_info_from_model_based_endpoint( "predictor for this endpoint." ) - return model_id, model_version, config_name + return model_id, model_version, inference_config_name, training_config_name def get_model_info_from_training_job( training_job_name: str, sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Tuple[str, str, Optional[str]]: +) -> Tuple[str, str, Optional[str], Optional[str]]: """Returns the model ID and version and config name inferred from a training job. 
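Example (an illustrative call with a hypothetical job name; mirrors how JumpStartEstimator.attach consumes the result): model_id, model_version, _, config_name = get_model_info_from_training_job("my-training-job")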
Raises: @@ -199,8 +218,9 @@ ( model_id, inferred_model_version, - config_name, - ) = get_jumpstart_model_id_version_from_resource_arn(training_job_arn, sagemaker_session) + inference_config_name, + training_config_name, + ) = get_jumpstart_model_info_from_resource_arn(training_job_arn, sagemaker_session) model_version = inferred_model_version or None @@ -211,4 +231,4 @@ "for this training job." ) - return model_id, model_version, config_name + return model_id, model_version, inference_config_name, training_config_name diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index a8c4bd7c21..d2a0a396b5 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -320,7 +320,8 @@ def add_single_jumpstart_tag( tag_key_in_array(enums.JumpStartTag.MODEL_ID, curr_tags) or tag_key_in_array(enums.JumpStartTag.MODEL_VERSION, curr_tags) or tag_key_in_array(enums.JumpStartTag.MODEL_TYPE, curr_tags) - or tag_key_in_array(enums.JumpStartTag.MODEL_CONFIG_NAME, curr_tags) + or tag_key_in_array(enums.JumpStartTag.INFERENCE_CONFIG_NAME, curr_tags) + or tag_key_in_array(enums.JumpStartTag.TRAINING_CONFIG_NAME, curr_tags) ) if is_uri else False @@ -351,12 +352,13 @@ def get_jumpstart_base_name_if_jumpstart_model( return None -def add_jumpstart_model_id_version_tags( +def add_jumpstart_model_info_tags( tags: Optional[List[TagsDict]], model_id: str, model_version: str, model_type: Optional[enums.JumpStartModelType] = None, config_name: Optional[str] = None, + scope: enums.JumpStartScriptScope = None, ) -> List[TagsDict]: """Add custom model ID and version tags to JumpStart related resources.""" if model_id is None or model_version is None: @@ -380,10 +382,17 @@ tags, is_uri=False, ) - if config_name: + if config_name and scope == enums.JumpStartScriptScope.INFERENCE: tags = add_single_jumpstart_tag( config_name, - enums.JumpStartTag.MODEL_CONFIG_NAME, + enums.JumpStartTag.INFERENCE_CONFIG_NAME, tags, is_uri=False, ) + if config_name and scope == enums.JumpStartScriptScope.TRAINING: + tags = add_single_jumpstart_tag( + config_name, + enums.JumpStartTag.TRAINING_CONFIG_NAME, + tags, + is_uri=False, + ) return tags @@ -840,10 +849,10 @@ def _extract_value_from_list_of_tags( return resolved_value -def get_jumpstart_model_id_version_from_resource_arn( +def get_jumpstart_model_info_from_resource_arn( resource_arn: str, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Tuple[Optional[str], Optional[str], Optional[str]]: +) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]: """Returns the JumpStart model ID, version and config names if in resource tags. Returns 'None' if the model ID, version, or config names cannot be inferred from tags.
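The returned tuple is ordered (model_id, model_version, inference_config_name, training_config_name).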
@@ -853,7 +862,8 @@ def get_jumpstart_model_id_version_from_resource_arn(
 
     model_id_keys = [enums.JumpStartTag.MODEL_ID, *constants.EXTRA_MODEL_ID_TAGS]
     model_version_keys = [enums.JumpStartTag.MODEL_VERSION, *constants.EXTRA_MODEL_VERSION_TAGS]
-    model_config_name_keys = [enums.JumpStartTag.MODEL_CONFIG_NAME]
+    inference_config_name_keys = [enums.JumpStartTag.INFERENCE_CONFIG_NAME]
+    training_config_name_keys = [enums.JumpStartTag.TRAINING_CONFIG_NAME]
 
     model_id: Optional[str] = _extract_value_from_list_of_tags(
         tag_keys=model_id_keys,
@@ -869,14 +879,21 @@ def get_jumpstart_model_id_version_from_resource_arn(
         resource_arn=resource_arn,
     )
 
-    config_name: Optional[str] = _extract_value_from_list_of_tags(
-        tag_keys=model_config_name_keys,
+    inference_config_name: Optional[str] = _extract_value_from_list_of_tags(
+        tag_keys=inference_config_name_keys,
+        list_tags_result=list_tags_result,
+        resource_name="inference config name",
+        resource_arn=resource_arn,
+    )
+
+    training_config_name: Optional[str] = _extract_value_from_list_of_tags(
+        tag_keys=training_config_name_keys,
         list_tags_result=list_tags_result,
-        resource_name="model config name",
+        resource_name="training config name",
         resource_arn=resource_arn,
     )
 
-    return model_id, model_version, config_name
+    return model_id, model_version, inference_config_name, training_config_name
 
 
 def get_region_fallback(
diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py
index c5d08eb8f4..780a1a56c8 100644
--- a/src/sagemaker/predictor.py
+++ b/src/sagemaker/predictor.py
@@ -82,6 +82,7 @@ def retrieve_default(
         inferred_model_version,
         inferred_inference_component_name,
         inferred_config_name,
+        _,
     ) = get_model_info_from_endpoint(endpoint_name, inference_component_name, sagemaker_session)
 
     if not inferred_model_id:
diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py
index 20701edb01..89db48ffd8 100644
--- a/src/sagemaker/utils.py
+++ b/src/sagemaker/utils.py
@@ -39,7 +39,7 @@
 import botocore
 from botocore.utils import merge_dicts
 from six.moves.urllib import parse
-import pandas as pd
+from six import viewitems
 
 from sagemaker import deprecations
 from sagemaker.config import validate_sagemaker_config
@@ -1605,44 +1605,80 @@ def can_model_package_source_uri_autopopulate(source_uri: str):
     )
 
 
-def flatten_dict(source_dict: Dict[str, Any], sep: str = "^") -> Dict[str, Any]:
-    """Flatten a nested dictionary.
+def flatten_dict(
+    d: Dict[str, Any],
+    max_flatten_depth: Optional[int] = None,
+) -> Dict[str, Any]:
+    """Flatten a dictionary object into a dict with tuple keys.
 
-    Args:
-        source_dict (dict): The dictionary to be flattened.
-        sep (str): The separator to be used in the flattened dictionary.
-    Returns:
-        transformed_dict: The flattened dictionary.
+    Args:
+        d (Dict[str, Any]):
+            The dict that will be flattened.
+        max_flatten_depth (Optional[int]):
+            Maximum depth to flatten.
     """
-    flat_dict_list = pd.json_normalize(source_dict, sep=sep).to_dict(orient="records")
-    if flat_dict_list:
-        return flat_dict_list[0]
-    return {}
 
+    def tuple_reducer(k1, k2):
+        if k1 is None:
+            return (k2,)
+        return k1 + (k2,)
 
-def unflatten_dict(source_dict: Dict[str, Any], sep: str = "^") -> Dict[str, Any]:
-    """Unflatten a flattened dictionary back into a nested dictionary.
+    # check max_flatten_depth
+    if max_flatten_depth is not None and max_flatten_depth < 1:
+        raise ValueError("max_flatten_depth should not be less than 1.")
 
-    Args:
-        source_dict (dict): The input flattened dictionary.
-        sep (str): The separator used in the flattened keys.
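The tuple-key behavior that replaces the pandas separator approach is easiest to see as a round trip; a sketch assuming the flatten_dict/unflatten_dict pair completed in the following hunk lines:

    from sagemaker.utils import flatten_dict, unflatten_dict

    nested = {"a": 1, "b": {"x": 2, "y": {"p": 3}}}
    flat = flatten_dict(nested)
    # {("a",): 1, ("b", "x"): 2, ("b", "y", "p"): 3}
    assert unflatten_dict(flat) == nested  # keys are tuples, so no separator collisions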
+    reducer = tuple_reducer
 
-    Returns:
-        transformed_dict: The reconstructed nested dictionary.
+    flat_dict = {}
+
+    def _flatten(_d, depth, parent=None):
+        key_value_iterable = viewitems(_d)
+        has_item = False
+        for key, value in key_value_iterable:
+            has_item = True
+            flat_key = reducer(parent, key)
+            if isinstance(value, dict) and (max_flatten_depth is None or depth < max_flatten_depth):
+                has_child = _flatten(value, depth=depth + 1, parent=flat_key)
+                if has_child:
+                    continue
+
+            if flat_key in flat_dict:
+                raise ValueError("duplicated key '{}'".format(flat_key))
+            flat_dict[flat_key] = value
+
+        return has_item
+
+    _flatten(d, depth=1)
+    return flat_dict
+
+
+def nested_set_dict(d: Dict[str, Any], keys: List[str], value: Any) -> None:
+    """Set a value to a sequence of nested keys."""
+
+    key = keys[0]
+
+    if len(keys) == 1:
+        d[key] = value
+        return
+
+    d = d.setdefault(key, {})
+    nested_set_dict(d, keys[1:], value)
+
+
+def unflatten_dict(d: Dict[str, Any]) -> Dict[str, Any]:
+    """Unflatten a dict with tuple keys back into a nested dict.
+
+    Args:
+        d (Dict[str, Any]):
+            The dict that will be unflattened.
     """
-    if not source_dict:
-        return {}
-
-    result = {}
-    for key, value in source_dict.items():
-        keys = key.split(sep)
-        current = result
-        for k in keys[:-1]:
-            if k not in current:
-                current[k] = {}
-            current = current[k] if current[k] is not None else current
-        current[keys[-1]] = value
-    return result
+
+    unflattened_dict = {}
+    for flat_key, value in viewitems(d):
+        key_tuple = flat_key
+        nested_set_dict(unflattened_dict, key_tuple, value)
+
+    return unflattened_dict
 
 
 def deep_override_dict(
diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py
index 6a9bd9ff10..2794a5319f 100644
--- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py
+++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py
@@ -1038,6 +1038,7 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case(
             "js-trainable-model-prepacked",
             "1.0.0",
             None,
+            None,
         )
 
         mock_get_model_specs.side_effect = get_special_model_spec
@@ -1902,7 +1903,59 @@ def test_estimator_initialization_with_config_name(
             tags=[
                 {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"},
                 {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
-                {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-training"},
+                {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "neuron-training"},
             ],
             enable_network_isolation=False,
         )
 
         estimator.fit()
 
         mock_estimator_fit.assert_called_once_with(wait=True)
 
+    @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
+    @mock.patch("sagemaker.jumpstart.factory.estimator.Session")
+    @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
+    @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit")
+    @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
+    @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region)
+    def test_estimator_set_config_name(
+        self,
+        mock_estimator_init: mock.Mock,
+        mock_estimator_fit: mock.Mock,
+        mock_get_model_specs: mock.Mock,
+        mock_session: mock.Mock,
+        mock_get_manifest: mock.Mock,
+    ):
+        mock_get_model_specs.side_effect = get_prototype_spec_with_configs
+        mock_get_manifest.side_effect = (
+            lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
+        )
+        mock_estimator_fit.return_value = default_predictor
+
+        model_id, _ = "pytorch-eqa-bert-base-cased",
"*" + + mock_session.return_value = sagemaker_session + + estimator = JumpStartEstimator(model_id=model_id) + + estimator.set_training_config(config_name="neuron-training") + + mock_estimator_init.assert_called_with( + instance_type="ml.p2.xlarge", + instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3", + model_uri="s3://jumpstart-cache-prod-us-west-2/artifacts/meta-textgeneration-llama-2-7b/" + "neuron-training/model/", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/" + "transfer_learning/eqa/v1.0.0/sourcedir.tar.gz", + entry_point="transfer_learning.py", + hyperparameters={"epochs": "3", "adam-learning-rate": "2e-05", "batch-size": "4"}, + role="fake role! do not use!", + sagemaker_session=sagemaker_session, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "neuron-training"}, ], enable_network_isolation=False, ) diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index cd11d950d5..f32687bd99 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -1563,7 +1563,7 @@ def test_model_initialization_with_config_name( tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, - {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-inference"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "neuron-inference"}, ], wait=True, endpoint_logging=False, @@ -1627,7 +1627,7 @@ def test_model_set_deployment_config( tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, - {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-inference"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "neuron-inference"}, ], wait=True, endpoint_logging=False, @@ -1645,7 +1645,7 @@ def test_model_set_deployment_config( tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, - {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "neuron-inference"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "neuron-inference"}, ], wait=True, endpoint_logging=False, diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index 1cc8f292f0..a3425a7b90 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -97,7 +97,7 @@ def test_proprietary_predictor_support( def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( patched_get_model_specs, patched_verify_model_region_and_return_specs, - patched_get_jumpstart_model_id_version_from_endpoint, + patched_get_model_info_from_endpoint, patched_get_default_predictor, patched_predictor, ): @@ -105,20 +105,19 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec - patched_get_jumpstart_model_id_version_from_endpoint.return_value = ( + patched_get_model_info_from_endpoint.return_value = ( "predictor-specs-model", "1.2.3", None, None, + None, ) mock_session = Mock() 
predictor.retrieve_default(endpoint_name="blah", sagemaker_session=mock_session) - patched_get_jumpstart_model_id_version_from_endpoint.assert_called_once_with( - "blah", None, mock_session - ) + patched_get_model_info_from_endpoint.assert_called_once_with("blah", None, mock_session) patched_get_default_predictor.assert_called_once_with( predictor=patched_predictor.return_value, diff --git a/tests/unit/sagemaker/jumpstart/test_session_utils.py b/tests/unit/sagemaker/jumpstart/test_session_utils.py index 9dc8acb32a..ce06a189bd 100644 --- a/tests/unit/sagemaker/jumpstart/test_session_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_session_utils.py @@ -12,61 +12,63 @@ ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") def test_get_model_info_from_training_job_happy_case( - mock_get_jumpstart_model_id_version_from_resource_arn, + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( "model_id", "model_version", None, + None, ) retval = get_model_info_from_training_job("bLaH", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version", None) + assert retval == ("model_id", "model_version", None, None) - mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( + mock_get_jumpstart_model_info_from_resource_arn.assert_called_once_with( "arn:aws:sagemaker:us-west-2:123456789012:training-job/bLaH", mock_sm_session ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") def test_get_model_info_from_training_job_config_name( - mock_get_jumpstart_model_id_version_from_resource_arn, + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( "model_id", "model_version", - "config_name", + None, + "training_config_name", ) retval = get_model_info_from_training_job("bLaH", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version", "config_name") + assert retval == ("model_id", "model_version", None, "training_config_name") - mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( + mock_get_jumpstart_model_info_from_resource_arn.assert_called_once_with( "arn:aws:sagemaker:us-west-2:123456789012:training-job/bLaH", mock_sm_session ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") def test_get_model_info_from_training_job_no_model_id_inferred( - mock_get_jumpstart_model_id_version_from_resource_arn, + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + 
mock_get_jumpstart_model_info_from_resource_arn.return_value = ( None, None, ) @@ -75,43 +77,45 @@ def test_get_model_info_from_training_job_no_model_id_inferred( get_model_info_from_training_job("blah", sagemaker_session=mock_sm_session) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") def test_get_model_info_from_model_based_endpoint_happy_case( - mock_get_jumpstart_model_id_version_from_resource_arn, + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( "model_id", "model_version", None, + None, ) retval = _get_model_info_from_model_based_endpoint( "bLaH", inference_component_name=None, sagemaker_session=mock_sm_session ) - assert retval == ("model_id", "model_version", None) + assert retval == ("model_id", "model_version", None, None) - mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( + mock_get_jumpstart_model_info_from_resource_arn.assert_called_once_with( "arn:aws:sagemaker:us-west-2:123456789012:endpoint/blah", mock_sm_session ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") def test_get_model_info_from_model_based_endpoint_inference_component_supplied( - mock_get_jumpstart_model_id_version_from_resource_arn, + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( "model_id", "model_version", None, + None, ) with pytest.raises(ValueError): @@ -120,15 +124,16 @@ def test_get_model_info_from_model_based_endpoint_inference_component_supplied( ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") def test_get_model_info_from_model_based_endpoint_no_model_id_inferred( - mock_get_jumpstart_model_id_version_from_resource_arn, + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( + None, None, None, ) @@ -139,40 +144,42 @@ def test_get_model_info_from_model_based_endpoint_no_model_id_inferred( ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") def test_get_model_info_from_inference_component_endpoint_with_inference_component_name_happy_case( - mock_get_jumpstart_model_id_version_from_resource_arn, + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - 
mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( "model_id", "model_version", None, + None, ) retval = _get_model_info_from_inference_component_endpoint_with_inference_component_name( "bLaH", sagemaker_session=mock_sm_session ) - assert retval == ("model_id", "model_version", None) + assert retval == ("model_id", "model_version", None, None) - mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( + mock_get_jumpstart_model_info_from_resource_arn.assert_called_once_with( "arn:aws:sagemaker:us-west-2:123456789012:inference-component/bLaH", mock_sm_session ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") def test_get_model_info_from_inference_component_endpoint_with_inference_component_name_no_model_id_inferred( - mock_get_jumpstart_model_id_version_from_resource_arn, + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( + None, None, None, None, @@ -272,11 +279,12 @@ def test_get_model_info_from_endpoint_non_inference_component_endpoint( "model_id", "model_version", None, + None, ) retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version", None, None) + assert retval == ("model_id", "model_version", None, None, None) mock_get_model_info_from_model_based_endpoint.assert_called_once_with( "blah", None, mock_sm_session ) @@ -296,13 +304,14 @@ def test_get_model_info_from_endpoint_inference_component_endpoint_with_inferenc "model_id", "model_version", None, + None, ) retval = get_model_info_from_endpoint( "blah", inference_component_name="icname", sagemaker_session=mock_sm_session ) - assert retval == ("model_id", "model_version", "icname", None) + assert retval == ("model_id", "model_version", "icname", None, None) mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.assert_called_once_with( "icname", mock_sm_session ) @@ -322,12 +331,13 @@ def test_get_model_info_from_endpoint_inference_component_endpoint_without_infer "model_id", "model_version", None, + None, "inferred-icname", ) retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version", "inferred-icname", None) + assert retval == ("model_id", "model_version", "inferred-icname", None, None) mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once() @@ -335,7 +345,7 @@ def test_get_model_info_from_endpoint_inference_component_endpoint_without_infer "sagemaker.jumpstart.session_utils._get_model_info_from_inference_component_" "endpoint_without_inference_component_name" ) -def test_get_model_info_from_endpoint_inference_component_endpoint_with_config_name( +def test_get_model_info_from_endpoint_inference_component_endpoint_with_inference_config_name( mock_get_model_info_from_inference_component_endpoint_without_inference_component_name, ): mock_sm_session = Mock() @@ -343,11 +353,35 @@ def test_get_model_info_from_endpoint_inference_component_endpoint_with_config_n 
mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.return_value = ( "model_id", "model_version", - "config_name", + "inference_config_name", + None, + "inferred-icname", + ) + + retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) + + assert retval == ("model_id", "model_version", "inferred-icname", "inference_config_name", None) + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once() + + +@patch( + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_component_" + "endpoint_without_inference_component_name" +) +def test_get_model_info_from_endpoint_inference_component_endpoint_with_training_config_name( + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name, +): + mock_sm_session = Mock() + mock_sm_session.is_inference_component_based_endpoint.return_value = True + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.return_value = ( + "model_id", + "model_version", + None, + "training_config_name", "inferred-icname", ) retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version", "inferred-icname", "config_name") + assert retval == ("model_id", "model_version", "inferred-icname", None, "training_config_name") mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once() diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index f7458a29e9..5b30a94dd6 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -210,16 +210,16 @@ def test_is_jumpstart_model_uri(): assert utils.is_jumpstart_model_uri(random_jumpstart_s3_uri("random_key")) -def test_add_jumpstart_model_id_version_tags(): +def test_add_jumpstart_model_info_tags(): tags = None model_id = "model_id" version = "version" + inference_config_name = "inference_config_name" + training_config_name = "training_config_name" assert [ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"}, {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, - ] == utils.add_jumpstart_model_id_version_tags( - tags=tags, model_id=model_id, model_version=version - ) + ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version) tags = [ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id_2"}, @@ -231,9 +231,7 @@ def test_add_jumpstart_model_id_version_tags(): assert [ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id_2"}, {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version_2"}, - ] == utils.add_jumpstart_model_id_version_tags( - tags=tags, model_id=model_id, model_version=version - ) + ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version) tags = [ {"Key": "random key", "Value": "random_value"}, @@ -244,9 +242,7 @@ def test_add_jumpstart_model_id_version_tags(): {"Key": "random key", "Value": "random_value"}, {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"}, {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, - ] == utils.add_jumpstart_model_id_version_tags( - tags=tags, model_id=model_id, model_version=version - ) + ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version) tags = [ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": 
"model_id_2"}, @@ -257,9 +253,7 @@ def test_add_jumpstart_model_id_version_tags(): assert [ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id_2"}, {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, - ] == utils.add_jumpstart_model_id_version_tags( - tags=tags, model_id=model_id, model_version=version - ) + ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version) tags = [ {"Key": "random key", "Value": "random_value"}, @@ -268,8 +262,58 @@ def test_add_jumpstart_model_id_version_tags(): version = None assert [ {"Key": "random key", "Value": "random_value"}, - ] == utils.add_jumpstart_model_id_version_tags( - tags=tags, model_id=model_id, model_version=version + ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version) + + tags = [ + {"Key": "random key", "Value": "random_value"}, + ] + model_id = "model_id" + version = "version" + assert [ + {"Key": "random key", "Value": "random_value"}, + {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"}, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, + {"Key": "sagemaker-sdk:jumpstart-inference-config-name", "Value": "inference_config_name"}, + ] == utils.add_jumpstart_model_info_tags( + tags=tags, + model_id=model_id, + model_version=version, + config_name=inference_config_name, + scope=JumpStartScriptScope.INFERENCE, + ) + + tags = [ + {"Key": "random key", "Value": "random_value"}, + ] + model_id = "model_id" + version = "version" + assert [ + {"Key": "random key", "Value": "random_value"}, + {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"}, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, + {"Key": "sagemaker-sdk:jumpstart-training-config-name", "Value": "training_config_name"}, + ] == utils.add_jumpstart_model_info_tags( + tags=tags, + model_id=model_id, + model_version=version, + config_name=training_config_name, + scope=JumpStartScriptScope.TRAINING, + ) + + tags = [ + {"Key": "random key", "Value": "random_value"}, + ] + model_id = "model_id" + version = "version" + assert [ + {"Key": "random key", "Value": "random_value"}, + {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"}, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, + ] == utils.add_jumpstart_model_info_tags( + tags=tags, + model_id=model_id, + model_version=version, + config_name=training_config_name, ) @@ -1322,10 +1366,8 @@ def test_no_model_id_no_version_found(self): mock_list_tags.return_value = [{"Key": "blah", "Value": "blah1"}] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - (None, None, None), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1339,10 +1381,8 @@ def test_model_id_no_version_found(self): ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - ("model_id", None, None), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + ("model_id", None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1356,10 +1396,8 @@ def test_no_model_id_version_found(self): ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - (None, "model_version", None), + 
utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, "model_version", None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1370,27 +1408,54 @@ def test_no_config_name_found(self): mock_list_tags.return_value = [{"Key": "blah", "Value": "blah1"}] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - (None, None, None), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") - def test_config_name_found(self): + def test_inference_config_name_found(self): mock_list_tags = Mock() mock_sagemaker_session = Mock() mock_sagemaker_session.list_tags = mock_list_tags mock_list_tags.return_value = [ {"Key": "blah", "Value": "blah1"}, - {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "config_name"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "config_name"}, ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - (None, None, "config_name"), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, "config_name", None), + ) + mock_list_tags.assert_called_once_with("some-arn") + + def test_training_config_name_found(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [ + {"Key": "blah", "Value": "blah1"}, + {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "config_name"}, + ] + + self.assertEquals( + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, None, "config_name"), + ) + mock_list_tags.assert_called_once_with("some-arn") + + def test_both_config_name_found(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [ + {"Key": "blah", "Value": "blah1"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "inference_config_name"}, + {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "training_config_name"}, + ] + + self.assertEquals( + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, "inference_config_name", "training_config_name"), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1405,10 +1470,8 @@ def test_model_id_version_found(self): ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - ("model_id", "model_version", None), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + ("model_id", "model_version", None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1425,10 +1488,8 @@ def test_multiple_model_id_versions_found(self): ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - (None, None, None), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1445,10 +1506,8 @@ def test_multiple_model_id_versions_found_aliases_consistent(self): ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - ("model_id_1", "model_version_1", None), + 
utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + ("model_id_1", "model_version_1", None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1465,10 +1524,8 @@ def test_multiple_model_id_versions_found_aliases_inconsistent(self): ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - (None, None, None), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1480,15 +1537,13 @@ def test_multiple_config_names_found_aliases_inconsistent(self): {"Key": "blah", "Value": "blah1"}, {"Key": JumpStartTag.MODEL_ID, "Value": "model_id_1"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "model_version_1"}, - {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "config_name_1"}, - {"Key": JumpStartTag.MODEL_CONFIG_NAME, "Value": "config_name_2"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "config_name_1"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "config_name_2"}, ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - ("model_id_1", "model_version_1", None), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + ("model_id_1", "model_version_1", None, None), ) mock_list_tags.assert_called_once_with("some-arn") diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 083e2dd09a..93abcfc7a8 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1819,7 +1819,13 @@ def test_can_model_package_source_uri_autopopulate(): class TestDeepMergeDict(TestCase): def test_flatten_dict_basic(self): nested_dict = {"a": 1, "b": {"x": 2, "y": {"p": 3, "q": 4}}, "c": 5} - flattened_dict = {"a": 1, "b^x": 2, "b^y^p": 3, "b^y^q": 4, "c": 5} + flattened_dict = { + ("a",): 1, + ("b", "x"): 2, + ("b", "y", "p"): 3, + ("b", "y", "q"): 4, + ("c",): 5, + } self.assertDictEqual(flatten_dict(nested_dict), flattened_dict) self.assertDictEqual(unflatten_dict(flattened_dict), nested_dict) @@ -1831,13 +1837,19 @@ def test_flatten_dict_empty(self): def test_flatten_dict_no_nested(self): nested_dict = {"a": 1, "b": 2, "c": 3} - flattened_dict = {"a": 1, "b": 2, "c": 3} + flattened_dict = {("a",): 1, ("b",): 2, ("c",): 3} self.assertDictEqual(flatten_dict(nested_dict), flattened_dict) self.assertDictEqual(unflatten_dict(flattened_dict), nested_dict) def test_flatten_dict_with_various_types(self): nested_dict = {"a": [1, 2, 3], "b": {"x": None, "y": {"p": [], "q": ""}}, "c": 9} - flattened_dict = {"a": [1, 2, 3], "b^x": None, "b^y^p": [], "b^y^q": "", "c": 9} + flattened_dict = { + ("a",): [1, 2, 3], + ("b", "x"): None, + ("b", "y", "p"): [], + ("b", "y", "q"): "", + ("c",): 9, + } self.assertDictEqual(flatten_dict(nested_dict), flattened_dict) self.assertDictEqual(unflatten_dict(flattened_dict), nested_dict) From b42f8dcb5832b554fbdcca929c1fe6b50a3a9e83 Mon Sep 17 00:00:00 2001 From: Haotian An <33510317+Captainia@users.noreply.github.com> Date: Thu, 2 May 2024 13:04:04 -0400 Subject: [PATCH 28/32] Merge master --- src/sagemaker/jumpstart/estimator.py | 21 +- src/sagemaker/jumpstart/factory/estimator.py | 35 +++- src/sagemaker/jumpstart/factory/model.py | 77 +++++++- src/sagemaker/jumpstart/types.py | 48 +++-- tests/unit/sagemaker/jumpstart/constants.py | 26 +++ .../jumpstart/estimator/test_estimator.py | 179 +++++++++++++++++- 
tests/unit/sagemaker/jumpstart/test_types.py | 7 + 7 files changed, 346 insertions(+), 47 deletions(-) diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 4939be4041..3132ea4d26 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -504,7 +504,7 @@ def __init__( enable_remote_debug (bool or PipelineVariable): Optional. Specifies whether RemoteDebug is enabled for the training job config_name (Optional[str]): - Name of the JumpStart Model config to apply. (Default: None). + Name of the training configuration to apply to the Estimator. (Default: None). Raises: ValueError: If the model ID is not recognized by JumpStart. @@ -686,6 +686,7 @@ def attach( model_version: Optional[str] = None, sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_channel_name: str = "model", + config_name: Optional[str] = None, ) -> "JumpStartEstimator": """Attach to an existing training job. @@ -721,6 +722,8 @@ def attach( model data will be downloaded (default: 'model'). If no channel with the same name exists in the training job, this option will be ignored. + config_name (str): Optional. Name of the training configuration to use + when attaching to the training job. (Default: None). Returns: Instance of the calling ``JumpStartEstimator`` Class with the attached @@ -732,7 +735,6 @@ def attach( """ config_name = None if model_id is None: - model_id, model_version, _, config_name = get_model_info_from_training_job( training_job_name=training_job_name, sagemaker_session=sagemaker_session ) @@ -746,6 +748,9 @@ def attach( "tolerate_deprecated_model": True, # model is already trained } + if config_name: + additional_kwargs.update({"config_name": config_name}) + model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, @@ -804,6 +809,7 @@ def deploy( dependencies: Optional[List[str]] = None, git_config: Optional[Dict[str, str]] = None, use_compiled_model: bool = False, + inference_config_name: Optional[str] = None, ) -> PredictorBase: """Creates endpoint from training job. @@ -1039,6 +1045,8 @@ def deploy( (Default: None). use_compiled_model (bool): Flag to select whether to use compiled (optimized) model. (Default: False). + inference_config_name (Optional[str]): Name of the inference configuration to + be used in the model. (Default: None). """ self.orig_predictor_cls = predictor_cls @@ -1091,7 +1099,8 @@ def deploy( git_config=git_config, use_compiled_model=use_compiled_model, training_instance_type=self.instance_type, - config_name=self.config_name, + training_config_name=self.config_name, + inference_config_name=inference_config_name, ) predictor = super(JumpStartEstimator, self).deploy( @@ -1108,7 +1117,7 @@ def deploy( tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, sagemaker_session=self.sagemaker_session, - config_name=self.config_name, + config_name=estimator_deploy_kwargs.config_name, ) # If a predictor class was passed, do not mutate predictor @@ -1140,7 +1149,9 @@ def set_training_config(self, config_name: str) -> None: config_name (str): The name of the config. 
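Taken together, the constructor's config_name and set_training_config let a caller pin or switch the training configuration; a sketch using the fixture model and config names from the tests in this series:

    from sagemaker.jumpstart.estimator import JumpStartEstimator

    estimator = JumpStartEstimator(
        model_id="pytorch-eqa-bert-base-cased",  # test-fixture model ID, for illustration
        config_name="gpu-training",
    )
    estimator.set_training_config(config_name="gpu-training-budget")  # re-initializes the estimator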
""" self.__init__( - model_id=self.model_id, model_version=self.model_version, config_name=config_name + model_id=self.model_id, + model_version=self.model_version, + config_name=config_name, ) def __str__(self) -> str: diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 604b20bc81..9177265d74 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -207,6 +207,7 @@ def get_init_kwargs( estimator_init_kwargs = _add_role_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_env_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_tags_to_kwargs(estimator_init_kwargs) + estimator_init_kwargs = _add_config_name_to_kwargs(estimator_init_kwargs) return estimator_init_kwargs @@ -291,7 +292,8 @@ def get_deploy_kwargs( use_compiled_model: Optional[bool] = None, model_name: Optional[str] = None, training_instance_type: Optional[str] = None, - config_name: Optional[str] = None, + training_config_name: Optional[str] = None, + inference_config_name: Optional[str] = None, ) -> JumpStartEstimatorDeployKwargs: """Returns kwargs required to call `deploy` on `sagemaker.estimator.Estimator` object.""" @@ -319,7 +321,8 @@ def get_deploy_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, - config_name=config_name, + training_config_name=training_config_name, + config_name=inference_config_name, ) model_init_kwargs: JumpStartModelInitKwargs = model.get_init_kwargs( @@ -348,7 +351,7 @@ def get_deploy_kwargs( tolerate_deprecated_model=tolerate_deprecated_model, training_instance_type=training_instance_type, disable_instance_type_logging=True, - config_name=config_name, + config_name=model_deploy_kwargs.config_name, ) estimator_deploy_kwargs: JumpStartEstimatorDeployKwargs = JumpStartEstimatorDeployKwargs( @@ -393,7 +396,7 @@ def get_deploy_kwargs( tolerate_vulnerable_model=model_init_kwargs.tolerate_vulnerable_model, tolerate_deprecated_model=model_init_kwargs.tolerate_deprecated_model, use_compiled_model=use_compiled_model, - config_name=config_name, + config_name=model_deploy_kwargs.config_name, ) return estimator_deploy_kwargs @@ -793,3 +796,27 @@ def _add_fit_extra_kwargs(kwargs: JumpStartEstimatorFitKwargs) -> JumpStartEstim setattr(kwargs, key, value) return kwargs + + +def _add_config_name_to_kwargs( + kwargs: JumpStartEstimatorInitKwargs, +) -> JumpStartEstimatorInitKwargs: + """Sets tags in kwargs based on default or override, returns full kwargs.""" + + specs = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + scope=JumpStartScriptScope.TRAINING, + region=kwargs.region, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, + ) + + if specs.training_configs and specs.training_configs.get_top_config_from_ranking(): + kwargs.config_name = ( + kwargs.config_name or specs.training_configs.get_top_config_from_ranking().config_name + ) + + return kwargs diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 54301973e8..79a7b18788 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -42,6 +42,7 @@ JumpStartModelDeployKwargs, JumpStartModelInitKwargs, JumpStartModelRegisterKwargs, + 
 )
 from sagemaker.jumpstart.utils import (
     add_jumpstart_model_info_tags,
@@ -548,7 +549,27 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
     return kwargs
 
 
-def _add_config_name_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
+def _select_inference_config_from_training_config(
+    specs: JumpStartModelSpecs, training_config_name: str
+) -> Optional[str]:
+    """Selects the inference config from the training config.
+
+    Args:
+        specs (JumpStartModelSpecs): The specs for the model.
+        training_config_name (str): The name of the training config.
+
+    Returns:
+        Optional[str]: The name of the inference config, or None if it cannot be resolved.
+    """
+    if specs.training_configs:
+        resolved_training_config = specs.training_configs.configs.get(training_config_name)
+        if resolved_training_config:
+            return resolved_training_config.default_inference_config
+
+    return None
+
+
+def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
     """Sets default config name to the kwargs. Returns full kwargs.
 
     Raises:
@@ -566,13 +587,9 @@ def _add_config_name_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod
         model_type=kwargs.model_type,
         config_name=kwargs.config_name,
     )
-    if (
-        specs.inference_configs
-        and specs.inference_configs.get_top_config_from_ranking().config_name
-    ):
-        kwargs.config_name = (
-            kwargs.config_name or specs.inference_configs.get_top_config_from_ranking().config_name
-        )
+    if specs.inference_configs:
+        default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name
+        kwargs.config_name = kwargs.config_name or default_config_name
 
     if not kwargs.config_name:
         return kwargs
@@ -593,6 +610,42 @@ def _add_config_name_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod
     return kwargs
 
 
+def _add_config_name_to_deploy_kwargs(
+    kwargs: JumpStartModelDeployKwargs, training_config_name: Optional[str] = None
+) -> JumpStartModelDeployKwargs:
+    """Sets the default config name in the kwargs. Returns full kwargs.
+
+    If a training_config_name is passed, the inference config is chosen
+    from the default inference config of that training config.
+
+    Raises:
+        ValueError: If the instance_type is not supported with the current config.
+    """
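The lookup performed by the helper above is small enough to restate; a minimal sketch under the assumption that specs came from verify_model_region_and_return_specs, with a hypothetical wrapper name and fixture config names for illustration:

    def select_inference_config(specs, training_config_name):
        # Mirrors _select_inference_config_from_training_config above (sketch only).
        training_configs = specs.training_configs.configs if specs.training_configs else {}
        resolved = training_configs.get(training_config_name)
        return resolved.default_inference_config if resolved else None

    # e.g. select_inference_config(specs, "gpu-training") -> "gpu-inference" in the fixtures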
+
+    specs = verify_model_region_and_return_specs(
+        model_id=kwargs.model_id,
+        version=kwargs.model_version,
+        scope=JumpStartScriptScope.INFERENCE,
+        region=kwargs.region,
+        tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
+        tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
+        sagemaker_session=kwargs.sagemaker_session,
+        model_type=kwargs.model_type,
+        config_name=kwargs.config_name,
+    )
+
+    if training_config_name:
+        kwargs.config_name = _select_inference_config_from_training_config(
+            specs=specs, training_config_name=training_config_name
+        )
+
+    if specs.inference_configs:
+        default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name
+        kwargs.config_name = kwargs.config_name or default_config_name
+
+    return kwargs
+
+
 def get_deploy_kwargs(
     model_id: str,
     model_version: Optional[str] = None,
@@ -623,6 +676,7 @@ def get_deploy_kwargs(
     resources: Optional[ResourceRequirements] = None,
     managed_instance_scaling: Optional[str] = None,
     endpoint_type: Optional[EndpointType] = None,
+    training_config_name: Optional[str] = None,
     config_name: Optional[str] = None,
 ) -> JumpStartModelDeployKwargs:
     """Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object."""
@@ -664,6 +718,10 @@ def get_deploy_kwargs(
 
     deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs)
 
+    deploy_kwargs = _add_config_name_to_deploy_kwargs(
+        kwargs=deploy_kwargs, training_config_name=training_config_name
+    )
+
     deploy_kwargs = _add_instance_type_to_kwargs(kwargs=deploy_kwargs)
 
     deploy_kwargs.initial_instance_count = initial_instance_count or 1
@@ -858,6 +916,7 @@ def get_init_kwargs(
 
     model_init_kwargs = _add_model_package_arn_to_kwargs(kwargs=model_init_kwargs)
 
     model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs)
-    model_init_kwargs = _add_config_name_to_kwargs(kwargs=model_init_kwargs)
+
+    model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs)
 
     return model_init_kwargs
diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py
index e0a0f9bea7..0a586f60aa 100644
--- a/src/sagemaker/jumpstart/types.py
+++ b/src/sagemaker/jumpstart/types.py
@@ -1077,30 +1077,52 @@ class JumpStartMetadataConfig(JumpStartDataHolderType):
         "config_components",
         "resolved_metadata_config",
         "config_name",
+        "default_inference_config",
+        "default_incremental_training_config",
+        "supported_inference_configs",
+        "supported_incremental_training_configs",
     ]
 
     def __init__(
         self,
         config_name: str,
+        config: Dict[str, Any],
         base_fields: Dict[str, Any],
         config_components: Dict[str, JumpStartConfigComponent],
-        benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]],
     ):
         """Initializes a JumpStartMetadataConfig object from its json representation.
 
         Args:
+            config_name (str): Name of the config.
+            config (Dict[str, Any]):
+                Dictionary representation of the config.
             base_fields (Dict[str, Any]):
                 The default base fields that are used to construct the final resolved config.
             config_components (Dict[str, JumpStartConfigComponent]):
                 The list of components that are used to construct the resolved config.
""" self.base_fields = base_fields self.config_components: Dict[str, JumpStartConfigComponent] = config_components - self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = benchmark_metrics + self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = ( + { + stat_name: [JumpStartBenchmarkStat(stat) for stat in stats] + for stat_name, stats in config.get("benchmark_metrics").items() + } + if config and config.get("benchmark_metrics") + else None + ) self.resolved_metadata_config: Optional[Dict[str, Any]] = None self.config_name: Optional[str] = config_name + self.default_inference_config: Optional[str] = config.get("default_inference_config") + self.default_incremental_trainig_config: Optional[str] = config.get( + "default_incremental_training_config" + ) + self.supported_inference_configs: Optional[List[str]] = config.get( + "supported_inference_configs" + ) + self.supported_incremental_training_configs: Optional[List[str]] = config.get( + "supported_incremental_training_configs" + ) def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartMetadataConfig object.""" @@ -1255,6 +1277,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: { alias: JumpStartMetadataConfig( alias, + config, json_obj, ( { @@ -1264,14 +1287,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if config and config.get("component_names") else None ), - ( - { - stat_name: [JumpStartBenchmarkStat(stat) for stat in stats] - for stat_name, stats in config.get("benchmark_metrics").items() - } - if config and config.get("benchmark_metrics") - else None - ), ) for alias, config in json_obj["inference_configs"].items() } @@ -1308,6 +1323,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: { alias: JumpStartMetadataConfig( alias, + config, json_obj, ( { @@ -1317,14 +1333,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if config and config.get("component_names") else None ), - ( - { - stat_name: [JumpStartBenchmarkStat(stat) for stat in stats] - for stat_name, stats in config.get("benchmark_metrics").items() - } - if config and config.get("benchmark_metrics") - else None - ), ) for alias, config in json_obj["training_configs"].items() } diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 90f037daea..3815bfc9ef 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -7752,6 +7752,10 @@ "ml.tr1n1.4xlarge": [{"name": "Latency", "value": "50", "unit": "Tokens/S"}], }, "component_names": ["neuron-training"], + "default_inference_config": "neuron-inference", + "default_incremental_training_config": "neuron-training", + "supported_inference_configs": ["neuron-inference", "neuron-inference-budget"], + "supported_incremental_training_configs": ["neuron-training", "neuron-training-budget"], }, "neuron-training-budget": { "benchmark_metrics": { @@ -7759,24 +7763,43 @@ "ml.tr1n1.4xlarge": [{"name": "Latency", "value": "50", "unit": "Tokens/S"}], }, "component_names": ["neuron-training-budget"], + "default_inference_config": "neuron-inference-budget", + "default_incremental_training_config": "neuron-training-budget", + "supported_inference_configs": ["neuron-inference", "neuron-inference-budget"], + "supported_incremental_training_configs": ["neuron-training", "neuron-training-budget"], }, "gpu-training": { "benchmark_metrics": { "ml.p3.2xlarge": [{"name": "Latency", "value": "200", "unit": "Tokens/S"}], }, "component_names": ["gpu-training"], + 
"default_inference_config": "gpu-inference", + "default_incremental_training_config": "gpu-training", + "supported_inference_configs": ["gpu-inference", "gpu-inference-budget"], + "supported_incremental_training_configs": ["gpu-training", "gpu-training-budget"], }, "gpu-training-budget": { "benchmark_metrics": { "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] }, "component_names": ["gpu-training-budget"], + "default_inference_config": "gpu-inference-budget", + "default_incremental_training_config": "gpu-training-budget", + "supported_inference_configs": ["gpu-inference", "gpu-inference-budget"], + "supported_incremental_training_configs": ["gpu-training", "gpu-training-budget"], }, }, "training_config_components": { "neuron-training": { + "default_training_instance_type": "ml.trn1.2xlarge", "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training/model/", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", + }, "training_instance_type_variants": { "regional_aliases": { "us-west-2": { @@ -7788,6 +7811,7 @@ }, }, "gpu-training": { + "default_training_instance_type": "ml.p2.xlarge", "supported_training_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-training/model/", "training_instance_type_variants": { @@ -7804,6 +7828,7 @@ }, }, "neuron-training-budget": { + "default_training_instance_type": "ml.trn1.2xlarge", "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training-budget/model/", "training_instance_type_variants": { @@ -7817,6 +7842,7 @@ }, }, "gpu-training-budget": { + "default_training_instance_type": "ml.p2.xlarge", "supported_training_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-training-budget/model/", "training_instance_type_variants": { diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 2794a5319f..2af470a13e 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -682,7 +682,6 @@ def test_estimator_use_kwargs(self): "input_mode": "File", "output_path": "Optional[Union[str, PipelineVariable]] = None", "output_kms_key": "Optional[Union[str, PipelineVariable]] = None", - "base_job_name": "Optional[str] = None", "sagemaker_session": DEFAULT_JUMPSTART_SAGEMAKER_SESSION, "hyperparameters": {"hyp1": "val1"}, "tags": [], @@ -1141,7 +1140,9 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): js_class_deploy = JumpStartEstimator.deploy js_class_deploy_args = set(signature(js_class_deploy).parameters.keys()) - assert js_class_deploy_args - parent_class_deploy_args == model_class_init_args - { + assert js_class_deploy_args - parent_class_deploy_args - { + "inference_config_name" + } == model_class_init_args - { "model_data", "self", "name", @@ -1886,14 +1887,17 @@ def test_estimator_initialization_with_config_name( mock_session.return_value = sagemaker_session - estimator = JumpStartEstimator(model_id=model_id, config_name="neuron-training") + estimator = JumpStartEstimator( + model_id=model_id, + config_name="gpu-training", + 
) mock_estimator_init.assert_called_once_with( instance_type="ml.p2.xlarge", instance_count=1, image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3", model_uri="s3://jumpstart-cache-prod-us-west-2/artifacts/meta-textgeneration-llama-2-7b/" - "neuron-training/model/", + "gpu-training/model/", source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/" "transfer_learning/eqa/v1.0.0/sourcedir.tar.gz", entry_point="transfer_learning.py", @@ -1903,7 +1907,7 @@ def test_estimator_initialization_with_config_name( tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, - {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "neuron-training"}, + {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "gpu-training"}, ], enable_network_isolation=False, ) @@ -1936,16 +1940,16 @@ def test_estimator_set_config_name( mock_session.return_value = sagemaker_session - estimator = JumpStartEstimator(model_id=model_id) + estimator = JumpStartEstimator(model_id=model_id, config_name="gpu-training") - estimator.set_training_config(config_name="neuron-training") + estimator.set_training_config(config_name="gpu-training-budget") mock_estimator_init.assert_called_with( instance_type="ml.p2.xlarge", instance_count=1, image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3", model_uri="s3://jumpstart-cache-prod-us-west-2/artifacts/meta-textgeneration-llama-2-7b/" - "neuron-training/model/", + "gpu-training-budget/model/", source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/" "transfer_learning/eqa/v1.0.0/sourcedir.tar.gz", entry_point="transfer_learning.py", @@ -1955,7 +1959,7 @@ def test_estimator_set_config_name( tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, - {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "neuron-training"}, + {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "gpu-training-budget"}, ], enable_network_isolation=False, ) @@ -1964,6 +1968,163 @@ def test_estimator_set_config_name( mock_estimator_fit.assert_called_once_with(wait=True) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_estimator_default_inference_config( + self, + mock_estimator_fit: mock.Mock, + mock_estimator_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + ): + mock_estimator_deploy.return_value = default_predictor + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_estimator_fit.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + estimator = JumpStartEstimator(model_id=model_id, config_name="gpu-training") + + assert estimator.config_name == "gpu-training" + + estimator.deploy() + + mock_estimator_deploy.assert_called_once_with( + 
instance_type="ml.p2.xlarge", + initial_instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.5.0-gpu-py3", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/" + "pytorch/inference/eqa/v1.0.0/sourcedir.tar.gz", + entry_point="inference.py", + predictor_cls=Predictor, + wait=True, + role="fake role! do not use!", + use_compiled_model=False, + enable_network_isolation=False, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "gpu-inference"}, + ], + ) + + @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") + @mock.patch("sagemaker.jumpstart.estimator.get_model_info_from_training_job") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_estimator_incremental_training_config( + self, + mock_estimator_init: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_model_info_from_training_job: mock.Mock, + mock_attach: mock.Mock, + ): + mock_get_model_info_from_training_job.return_value = ( + "pytorch-eqa-bert-base-cased", + "1.0.0", + None, + "gpu-training-budget", + ) + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_estimator_fit.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + estimator = JumpStartEstimator(model_id=model_id, config_name="gpu-training") + + assert estimator.config_name == "gpu-training" + + JumpStartEstimator.attach( + training_job_name="some-training-job-name", sagemaker_session=mock_session + ) + + mock_attach.assert_called_once_with( + training_job_name="some-training-job-name", + sagemaker_session=mock_session, + model_channel_name="model", + additional_kwargs={ + "model_id": "pytorch-eqa-bert-base-cased", + "model_version": "1.0.0", + "tolerate_vulnerable_model": True, + "tolerate_deprecated_model": True, + "config_name": "gpu-training-budget", + }, + ) + + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_estimator_deploy_with_config( + self, + mock_estimator_init: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_estimator_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + ): + mock_estimator_deploy.return_value = default_predictor + mock_get_model_specs.side_effect = 
get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_estimator_fit.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + estimator = JumpStartEstimator(model_id=model_id, config_name="gpu-training-budget") + + assert estimator.config_name == "gpu-training-budget" + + estimator.deploy() + + mock_estimator_deploy.assert_called_once_with( + instance_type="ml.p2.xlarge", + initial_instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.5.0-gpu-py3", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/" + "pytorch/inference/eqa/v1.0.0/sourcedir.tar.gz", + entry_point="inference.py", + predictor_cls=Predictor, + wait=True, + role="fake role! do not use!", + use_compiled_model=False, + enable_network_isolation=False, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "gpu-inference-budget"}, + ], + ) + def test_jumpstart_estimator_requires_model_id(): with pytest.raises(ValueError): diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index c52bf76f4e..b7e5f16ae1 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -1193,8 +1193,15 @@ def test_training_configs_parsing(): assert config.config_components["neuron-training"] == JumpStartConfigComponent( "neuron-training", { + "default_training_instance_type": "ml.trn1.2xlarge", "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training/model/", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", + }, "training_instance_type_variants": { "regional_aliases": { "us-west-2": { From eebd6100b5ea97f6fc40f329f9c750e906d7ccaf Mon Sep 17 00:00:00 2001 From: Jonathan Makunga <54963715+makungaj1@users.noreply.github.com> Date: Thu, 2 May 2024 11:24:30 -0700 Subject: [PATCH 29/32] Benchmark feature fixes (#4632) * Filter down Benchmark Metrics * Filter down Benchmark Metrics * Testing NB * Testing MB * Testing * Refactoring * Unit tests * Display instance type first, and instance rate last * Display unbalanced metrics * Testing with NB * Testing with NB * Debug * Debug * Testing with NB * Testing with NB * Testing with NB * Refactoring * Refactoring * Refactoring * Unit tests * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Refactoring * Debug * Config ranking * Debug * Debug * Debug * Debug * Debug * Ranking * Ranking-Debug * Ranking-Debug * Ranking-Debug * Ranking-Debug * Ranking-Debug * Ranking-Debug * Debug * Debug * Debug * Debug * Refactoring * Contact JumpStart team to fix flaky test. 
test_list_jumpstart_models_script_filter --------- Co-authored-by: Jonathan Makunga --- src/sagemaker/jumpstart/model.py | 95 ++++------ src/sagemaker/jumpstart/types.py | 2 + src/sagemaker/jumpstart/utils.py | 167 ++++++++++++++---- .../serve/builder/jumpstart_builder.py | 12 +- .../sagemaker/jumpstart/model/test_model.py | 8 +- tests/unit/sagemaker/jumpstart/test_utils.py | 41 +++-- tests/unit/sagemaker/jumpstart/utils.py | 17 +- .../serve/builder/test_js_builder.py | 6 + 8 files changed, 214 insertions(+), 134 deletions(-) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 619af2f7a9..6f263d9a7e 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -14,7 +14,6 @@ from __future__ import absolute_import -from functools import lru_cache from typing import Dict, List, Optional, Any, Union import pandas as pd from botocore.exceptions import ClientError @@ -48,6 +47,8 @@ get_jumpstart_configs, get_metrics_from_deployment_configs, add_instance_rate_stats_to_benchmark_metrics, + deployment_config_response_data, + _deployment_config_lru_cache, ) from sagemaker.jumpstart.constants import JUMPSTART_LOGGER from sagemaker.jumpstart.enums import JumpStartModelType @@ -449,10 +450,12 @@ def deployment_config(self) -> Optional[Dict[str, Any]]: Returns: Optional[Dict[str, Any]]: Deployment config. """ - deployment_config = self._retrieve_selected_deployment_config( - self.config_name, self.instance_type - ) - return deployment_config.to_json() if deployment_config is not None else None + if self.config_name is None: + return None + for config in self.list_deployment_configs(): + if config.get("DeploymentConfigName") == self.config_name: + return config + return None @property def benchmark_metrics(self) -> pd.DataFrame: @@ -461,16 +464,14 @@ def benchmark_metrics(self) -> pd.DataFrame: Returns: Benchmark Metrics: Pandas DataFrame object. """ - benchmark_metrics_data = self._get_deployment_configs_benchmarks_data( - self.config_name, self.instance_type - ) - keys = list(benchmark_metrics_data.keys()) - df = pd.DataFrame(benchmark_metrics_data).sort_values(by=[keys[0], keys[1]]) - return df + df = pd.DataFrame(self._get_deployment_configs_benchmarks_data()) + default_mask = df.apply(lambda row: any("Default" in str(val) for val in row), axis=1) + sorted_df = pd.concat([df[default_mask], df[~default_mask]]) + return sorted_df - def display_benchmark_metrics(self) -> None: + def display_benchmark_metrics(self, *args, **kwargs) -> None: """Display deployment configs benchmark metrics.""" - print(self.benchmark_metrics.to_markdown(index=False)) + print(self.benchmark_metrics.to_markdown(index=False), *args, **kwargs) def list_deployment_configs(self) -> List[Dict[str, Any]]: """List deployment configs for ``This`` model. @@ -478,12 +479,9 @@ def list_deployment_configs(self) -> List[Dict[str, Any]]: Returns: List[Dict[str, Any]]: A list of deployment configs. 
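             Each entry carries fields such as "DeploymentConfigName",
             "InstanceType", and "BenchmarkMetrics".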
""" - return [ - deployment_config.to_json() - for deployment_config in self._get_deployment_configs( - self.config_name, self.instance_type - ) - ] + return deployment_config_response_data( + self._get_deployment_configs(self.config_name, self.instance_type) + ) def _create_sagemaker_model( self, @@ -873,71 +871,46 @@ def register_deploy_wrapper(*args, **kwargs): return model_package - @lru_cache - def _get_deployment_configs_benchmarks_data( - self, config_name: str, instance_type: str - ) -> Dict[str, Any]: + @_deployment_config_lru_cache + def _get_deployment_configs_benchmarks_data(self) -> Dict[str, Any]: """Deployment configs benchmark metrics. - Args: - config_name (str): Name of selected deployment config. - instance_type (str): The selected Instance type. Returns: Dict[str, List[str]]: Deployment config benchmark data. """ return get_metrics_from_deployment_configs( - self._get_deployment_configs(config_name, instance_type) + self._get_deployment_configs(None, None), ) - @lru_cache - def _retrieve_selected_deployment_config( - self, config_name: str, instance_type: str - ) -> Optional[DeploymentConfigMetadata]: - """Retrieve the deployment config to apply to `This` model. - - Args: - config_name (str): The name of the deployment config to retrieve. - instance_type (str): The instance type of the deployment config to retrieve. - Returns: - Optional[Dict[str, Any]]: The retrieved deployment config. - """ - if config_name is None: - return None - - for deployment_config in self._get_deployment_configs(config_name, instance_type): - if deployment_config.deployment_config_name == config_name: - return deployment_config - return None - - @lru_cache + @_deployment_config_lru_cache def _get_deployment_configs( - self, selected_config_name: str, selected_instance_type: str + self, selected_config_name: Optional[str], selected_instance_type: Optional[str] ) -> List[DeploymentConfigMetadata]: """Retrieve deployment configs metadata. Args: - selected_config_name (str): The name of the selected deployment config. - selected_instance_type (str): The selected instance type. + selected_config_name (Optional[str]): The name of the selected deployment config. + selected_instance_type (Optional[str]): The selected instance type. """ deployment_configs = [] - if self._metadata_configs is None: + if not self._metadata_configs: return deployment_configs err = None for config_name, metadata_config in self._metadata_configs.items(): - if err is None or "is not authorized to perform: pricing:GetProducts" not in err: - err, metadata_config.benchmark_metrics = ( - add_instance_rate_stats_to_benchmark_metrics( - self.region, metadata_config.benchmark_metrics - ) - ) - resolved_config = metadata_config.resolved_config if selected_config_name == config_name: instance_type_to_use = selected_instance_type else: instance_type_to_use = resolved_config.get("default_inference_instance_type") + if metadata_config.benchmark_metrics: + err, metadata_config.benchmark_metrics = ( + add_instance_rate_stats_to_benchmark_metrics( + self.region, metadata_config.benchmark_metrics + ) + ) + init_kwargs = get_init_kwargs( model_id=self.model_id, instance_type=instance_type_to_use, @@ -957,9 +930,9 @@ def _get_deployment_configs( ) deployment_configs.append(deployment_config_metadata) - if err is not None and "is not authorized to perform: pricing:GetProducts" in err: + if err and err["Code"] == "AccessDeniedException": error_message = "Instance rate metrics will be omitted. 
Reason: %s" - JUMPSTART_LOGGER.warning(error_message, err) + JUMPSTART_LOGGER.warning(error_message, err["Message"]) return deployment_configs diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 0a586f60aa..f85f23c361 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -2255,6 +2255,8 @@ def _val_to_json(self, val: Any) -> Any: Any: The converted json value. """ if issubclass(type(val), JumpStartDataHolderType): + if isinstance(val, JumpStartBenchmarkStat): + val.name = val.name.replace("_", " ").title() return val.to_json() if isinstance(val, list): list_obj = [] diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index d2a0a396b5..44be0ea813 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -12,8 +12,10 @@ # language governing permissions and limitations under the License. """This module contains utilities related to SageMaker JumpStart.""" from __future__ import absolute_import + import logging import os +from functools import lru_cache, wraps from typing import Any, Dict, List, Set, Optional, Tuple, Union from urllib.parse import urlparse import boto3 @@ -1040,7 +1042,9 @@ def get_jumpstart_configs( raise ValueError(f"Unknown script scope: {scope}.") if not config_names: - config_names = metadata_configs.configs.keys() if metadata_configs else [] + config_names = ( + metadata_configs.config_rankings.get("overall").rankings if metadata_configs else [] + ) return ( {config_name: metadata_configs.configs[config_name] for config_name in config_names} @@ -1052,43 +1056,42 @@ def get_jumpstart_configs( def add_instance_rate_stats_to_benchmark_metrics( region: str, benchmark_metrics: Optional[Dict[str, List[JumpStartBenchmarkStat]]], -) -> Optional[Tuple[str, Dict[str, List[JumpStartBenchmarkStat]]]]: +) -> Optional[Tuple[Dict[str, str], Dict[str, List[JumpStartBenchmarkStat]]]]: """Adds instance types metric stats to the given benchmark_metrics dict. Args: region (str): AWS region. - benchmark_metrics (Dict[str, List[JumpStartBenchmarkStat]]): + benchmark_metrics (Optional[Dict[str, List[JumpStartBenchmarkStat]]]): Returns: - Tuple[str, Dict[str, List[JumpStartBenchmarkStat]]]: - Contains Error message and metrics dict. + Optional[Tuple[Dict[str, str], Dict[str, List[JumpStartBenchmarkStat]]]]: + Contains Error and metrics. 
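+            The Error is the "Error" payload of a failed instance-rate lookup
+            (None when all lookups succeed); the metrics dict maps each
+            instance type to its benchmark stats.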
""" - - if benchmark_metrics is None: + if not benchmark_metrics: return None - final_benchmark_metrics = {} - err_message = None + final_benchmark_metrics = {} for instance_type, benchmark_metric_stats in benchmark_metrics.items(): instance_type = instance_type if instance_type.startswith("ml.") else f"ml.{instance_type}" - if not has_instance_rate_stat(benchmark_metric_stats): + if not has_instance_rate_stat(benchmark_metric_stats) and not err_message: try: instance_type_rate = get_instance_rate_per_hour( instance_type=instance_type, region=region ) + if not benchmark_metric_stats: + benchmark_metric_stats = [] benchmark_metric_stats.append(JumpStartBenchmarkStat(instance_type_rate)) - final_benchmark_metrics[instance_type] = benchmark_metric_stats + final_benchmark_metrics[instance_type] = benchmark_metric_stats except ClientError as e: final_benchmark_metrics[instance_type] = benchmark_metric_stats - err_message = e.response["Error"]["Message"] + err_message = e.response["Error"] except Exception: # pylint: disable=W0703 final_benchmark_metrics[instance_type] = benchmark_metric_stats - err_message = ( - f"Unable to get instance rate per hour for instance type: {instance_type}." - ) + else: + final_benchmark_metrics[instance_type] = benchmark_metric_stats return err_message, final_benchmark_metrics @@ -1103,31 +1106,32 @@ def has_instance_rate_stat(benchmark_metric_stats: Optional[List[JumpStartBenchm bool: Whether the benchmark metric stats contains instance rate metric stat. """ if benchmark_metric_stats is None: - return False - + return True for benchmark_metric_stat in benchmark_metric_stats: if benchmark_metric_stat.name.lower() == "instance rate": return True - return False def get_metrics_from_deployment_configs( - deployment_configs: List[DeploymentConfigMetadata], + deployment_configs: Optional[List[DeploymentConfigMetadata]], ) -> Dict[str, List[str]]: """Extracts benchmark metrics from deployment configs metadata. Args: - deployment_configs (List[DeploymentConfigMetadata]): List of deployment configs metadata. + deployment_configs (Optional[List[DeploymentConfigMetadata]]): + List of deployment configs metadata. + Returns: + Dict[str, List[str]]: Deployment configs bench metrics dict. 
""" - data = {"Config Name": [], "Instance Type": []} - - for outer_index, deployment_config in enumerate(deployment_configs): - if deployment_config.deployment_args is None: - continue + if not deployment_configs: + return {} + data = {"Instance Type": [], "Config Name": []} + instance_rate_data = {} + for index, deployment_config in enumerate(deployment_configs): benchmark_metrics = deployment_config.benchmark_metrics - if benchmark_metrics is None: + if not deployment_config.deployment_args or not benchmark_metrics: continue for inner_index, current_instance_type in enumerate(benchmark_metrics): @@ -1136,23 +1140,108 @@ def get_metrics_from_deployment_configs( data["Config Name"].append(deployment_config.deployment_config_name) instance_type_to_display = ( f"{current_instance_type} (Default)" - if current_instance_type == deployment_config.deployment_args.default_instance_type + if index == 0 + and current_instance_type == deployment_config.deployment_args.default_instance_type else current_instance_type ) data["Instance Type"].append(instance_type_to_display) - if outer_index == 0 and inner_index == 0: - temp_data = {} - for metric in current_instance_type_metrics: - column_name = f"{metric.name.replace('_', ' ').title()} ({metric.unit})" - if metric.name.lower() == "instance rate": - data[column_name] = [] - else: - temp_data[column_name] = [] - data = {**data, **temp_data} - for metric in current_instance_type_metrics: - column_name = f"{metric.name.replace('_', ' ').title()} ({metric.unit})" - if column_name in data: + column_name = f"{metric.name} ({metric.unit})" + + if metric.name.lower() == "instance rate": + if column_name not in instance_rate_data: + instance_rate_data[column_name] = [] + instance_rate_data[column_name].append(metric.value) + else: + if column_name not in data: + data[column_name] = [] + for _ in range(len(data[column_name]), inner_index): + data[column_name].append(" - ") data[column_name].append(metric.value) + + data = {**data, **instance_rate_data} return data + + +def deployment_config_response_data( + deployment_configs: Optional[List[DeploymentConfigMetadata]], +) -> List[Dict[str, Any]]: + """Deployment config api response data. + + Args: + deployment_configs (Optional[List[DeploymentConfigMetadata]]): + List of deployment configs metadata. + Returns: + List[Dict[str, Any]]: List of deployment config api response data. + """ + configs = [] + if not deployment_configs: + return configs + + for deployment_config in deployment_configs: + deployment_config_json = deployment_config.to_json() + benchmark_metrics = deployment_config_json.get("BenchmarkMetrics") + if benchmark_metrics and deployment_config.deployment_args: + deployment_config_json["BenchmarkMetrics"] = { + deployment_config.deployment_args.instance_type: benchmark_metrics.get( + deployment_config.deployment_args.instance_type + ) + } + + configs.append(deployment_config_json) + return configs + + +def _deployment_config_lru_cache(_func=None, *, maxsize: int = 128, typed: bool = False): + """LRU cache for deployment configs.""" + + def has_instance_rate_metric(config: DeploymentConfigMetadata) -> bool: + """Determines whether metadata config contains instance rate metric stat. + + Args: + config (DeploymentConfigMetadata): Metadata config metadata. + Returns: + bool: Whether the metadata config contains instance rate metric stat. 
+        """
+        if config.benchmark_metrics is None:
+            return True
+        for benchmark_metric_stats in config.benchmark_metrics.values():
+            if not has_instance_rate_stat(benchmark_metric_stats):
+                return False
+        return True
+
+    def wrapper_cache(f):
+        f = lru_cache(maxsize=maxsize, typed=typed)(f)
+
+        @wraps(f)
+        def wrapped_f(*args, **kwargs):
+            res = f(*args, **kwargs)
+
+            # Clear the cache on the first call if the output lacks instance
+            # rate metrics, since that indicates a missing pricing policy.
+            if f.cache_info().hits == 0 and f.cache_info().misses == 1:
+                if isinstance(res, list):
+                    for item in res:
+                        if isinstance(
+                            item, DeploymentConfigMetadata
+                        ) and not has_instance_rate_metric(item):
+                            f.cache_clear()
+                            break
+                elif isinstance(res, dict):
+                    keys = list(res.keys())
+                    if "Instance Rate" not in keys[-1]:
+                        f.cache_clear()
+                    elif len(res[keys[1]]) > len(res[keys[-1]]):
+                        del res[keys[-1]]
+                        f.cache_clear()
+            return res
+
+        wrapped_f.cache_info = f.cache_info
+        wrapped_f.cache_clear = f.cache_clear
+        return wrapped_f
+
+    if _func is None:
+        return wrapper_cache
+    return wrapper_cache(_func)
diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py
index ec987dd9fe..f6a4d165df 100644
--- a/src/sagemaker/serve/builder/jumpstart_builder.py
+++ b/src/sagemaker/serve/builder/jumpstart_builder.py
@@ -454,14 +454,14 @@ def get_deployment_config(self) -> Optional[Dict[str, Any]]:
             Optional[Dict[str, Any]]: Deployment config to apply to this model.
         """
         if not hasattr(self, "pysdk_model") or self.pysdk_model is None:
-            self.pysdk_model = self._create_pre_trained_js_model()
+            self._build_for_jumpstart()
 
         return self.pysdk_model.deployment_config
 
     def display_benchmark_metrics(self):
         """Display Markdown Benchmark Metrics for deployment configs."""
         if not hasattr(self, "pysdk_model") or self.pysdk_model is None:
-            self.pysdk_model = self._create_pre_trained_js_model()
+            self._build_for_jumpstart()
 
         self.pysdk_model.display_benchmark_metrics()
 
@@ -472,18 +472,20 @@ def list_deployment_configs(self) -> List[Dict[str, Any]]:
             List[Dict[str, Any]]: A list of deployment configs.
         """
         if not hasattr(self, "pysdk_model") or self.pysdk_model is None:
-            self.pysdk_model = self._create_pre_trained_js_model()
+            self._build_for_jumpstart()
 
         return self.pysdk_model.list_deployment_configs()
 
     def _build_for_jumpstart(self):
         """Placeholder docstring"""
+        if hasattr(self, "pysdk_model") and self.pysdk_model is not None:
+            return self.pysdk_model
+
         # we do not pickle for jumpstart. 
set to none self.secret_key = None self.jumpstart = True - if not hasattr(self, "pysdk_model") or self.pysdk_model is None: - self.pysdk_model = self._create_pre_trained_js_model() + self.pysdk_model = self._create_pre_trained_js_model() logger.info( "JumpStart ID %s is packaged with Image URI: %s", self.model, self.pysdk_model.image_uri diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index f32687bd99..75b3fd7300 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -1733,7 +1733,7 @@ def test_model_list_deployment_configs( mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id) mock_verify_model_region_and_return_specs.side_effect = ( - lambda *args, **kwargs: get_base_spec_with_prototype_configs() + lambda *args, **kwargs: get_base_spec_with_prototype_configs_with_missing_benchmarks() ) mock_add_instance_rate_stats_to_benchmark_metrics.side_effect = lambda region, metrics: ( None, @@ -1750,7 +1750,7 @@ def test_model_list_deployment_configs( configs = model.list_deployment_configs() - self.assertEqual(configs, get_base_deployment_configs()) + self.assertEqual(configs, get_base_deployment_configs(True)) @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @@ -1803,7 +1803,7 @@ def test_model_retrieve_deployment_config( model_id, _ = "pytorch-eqa-bert-base-cased", "*" mock_verify_model_region_and_return_specs.side_effect = ( - lambda *args, **kwargs: get_base_spec_with_prototype_configs_with_missing_benchmarks() + lambda *args, **kwargs: get_base_spec_with_prototype_configs() ) mock_add_instance_rate_stats_to_benchmark_metrics.side_effect = lambda region, metrics: ( None, @@ -1815,7 +1815,7 @@ def test_model_retrieve_deployment_config( ) mock_model_deploy.return_value = default_predictor - expected = get_base_deployment_configs(True)[0] + expected = get_base_deployment_configs()[0] config_name = expected.get("DeploymentConfigName") instance_type = expected.get("InstanceType") mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs( diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 5b30a94dd6..e6ea212994 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -52,6 +52,7 @@ get_special_model_spec, get_prototype_manifest, get_base_deployment_configs_metadata, + get_base_deployment_configs, ) from mock import MagicMock @@ -1869,29 +1870,15 @@ def test_add_instance_rate_stats_to_benchmark_metrics_client_ex( mock_get_instance_rate_per_hour, ): mock_get_instance_rate_per_hour.side_effect = ClientError( - {"Error": {"Message": "is not authorized to perform: pricing:GetProducts"}}, "GetProducts" - ) - - err, out = utils.add_instance_rate_stats_to_benchmark_metrics( - "us-west-2", { - "ml.p2.xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) - ], + "Error": { + "Message": "is not authorized to perform: pricing:GetProducts", + "Code": "AccessDenied", + }, }, + "GetProducts", ) - assert err == "is not authorized to perform: pricing:GetProducts" - for key in out: - assert len(out[key]) == 1 - - -@patch("sagemaker.jumpstart.utils.get_instance_rate_per_hour") -def test_add_instance_rate_stats_to_benchmark_metrics_ex( - 
mock_get_instance_rate_per_hour, -): - mock_get_instance_rate_per_hour.side_effect = Exception() - err, out = utils.add_instance_rate_stats_to_benchmark_metrics( "us-west-2", { @@ -1901,7 +1888,8 @@ def test_add_instance_rate_stats_to_benchmark_metrics_ex( }, ) - assert err == "Unable to get instance rate per hour for instance type: ml.p2.xlarge." + assert err["Message"] == "is not authorized to perform: pricing:GetProducts" + assert err["Code"] == "AccessDenied" for key in out: assert len(out[key]) == 1 @@ -1909,7 +1897,7 @@ def test_add_instance_rate_stats_to_benchmark_metrics_ex( @pytest.mark.parametrize( "stats, expected", [ - (None, False), + (None, True), ( [JumpStartBenchmarkStat({"name": "Instance Rate", "unit": "USD/Hrs", "value": "3.76"})], True, @@ -1919,3 +1907,14 @@ def test_add_instance_rate_stats_to_benchmark_metrics_ex( ) def test_has_instance_rate_stat(stats, expected): assert utils.has_instance_rate_stat(stats) is expected + + +@pytest.mark.parametrize( + "data, expected", + [(None, []), ([], []), (get_base_deployment_configs_metadata(), get_base_deployment_configs())], +) +def test_deployment_config_response_data(data, expected): + out = utils.deployment_config_response_data(data) + + print(out) + assert out == expected diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index e8a93dff6c..63b964e16e 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -358,7 +358,8 @@ def get_base_deployment_configs_metadata( else get_base_spec_with_prototype_configs() ) configs = [] - for config_name, jumpstart_config in specs.inference_configs.configs.items(): + for config_name in specs.inference_configs.config_rankings.get("overall").rankings: + jumpstart_config = specs.inference_configs.configs.get(config_name) benchmark_metrics = jumpstart_config.benchmark_metrics if benchmark_metrics: @@ -388,9 +389,17 @@ def get_base_deployment_configs_metadata( def get_base_deployment_configs( omit_benchmark_metrics: bool = False, ) -> List[Dict[str, Any]]: - return [ - config.to_json() for config in get_base_deployment_configs_metadata(omit_benchmark_metrics) - ] + configs = [] + for config in get_base_deployment_configs_metadata(omit_benchmark_metrics): + config_json = config.to_json() + if config_json["BenchmarkMetrics"]: + config_json["BenchmarkMetrics"] = { + config.deployment_args.instance_type: config_json["BenchmarkMetrics"].get( + config.deployment_args.instance_type + ) + } + configs.append(config_json) + return configs def append_instance_stat_metrics( diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py index 56b01cd9e3..4ec96e88e3 100644 --- a/tests/unit/sagemaker/serve/builder/test_js_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py @@ -866,6 +866,12 @@ def test_display_benchmark_metrics_initial( model="facebook/galactica-mock-model-id", schema_builder=mock_schema_builder, ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + mock_pre_trained_model.return_value.list_deployment_configs.side_effect = ( + lambda: DEPLOYMENT_CONFIGS + ) + builder.display_benchmark_metrics() mock_pre_trained_model.return_value.display_benchmark_metrics.assert_called_once() From 8bdf8400246af822c0797f4c0e0ff2183b1a2dc1 Mon Sep 17 00:00:00 2001 From: Haotian An <33510317+Captainia@users.noreply.github.com> Date: Mon, 6 May 2024 09:14:16 -0400 Subject: [PATCH 30/32] Merge master --- 
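A minimal sketch, assuming a placeholder model ID and role ARN, of how the
new kwarg surfaces to callers once this commit is applied:

    from sagemaker.jumpstart.estimator import JumpStartEstimator

    # Hypothetical identifiers; enable_session_tag_chaining is forwarded
    # through the factory kwargs to the underlying training job.
    estimator = JumpStartEstimator(
        model_id="pytorch-eqa-bert-base-cased",
        role="arn:aws:iam::111122223333:role/ExampleRole",
        enable_session_tag_chaining=True,
    )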
src/sagemaker/jumpstart/estimator.py | 4 ++++ src/sagemaker/jumpstart/factory/estimator.py | 2 ++ src/sagemaker/jumpstart/session_utils.py | 4 ++-- src/sagemaker/jumpstart/types.py | 7 +++++-- tests/unit/sagemaker/jumpstart/test_types.py | 8 ++++++++ 5 files changed, 21 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 3132ea4d26..5f7e0ed82c 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -112,6 +112,7 @@ def __init__( disable_output_compression: Optional[bool] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, config_name: Optional[str] = None, + enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, ): """Initializes a ``JumpStartEstimator``. @@ -505,6 +506,8 @@ def __init__( Specifies whether RemoteDebug is enabled for the training job config_name (Optional[str]): Name of the training configuration to apply to the Estimator. (Default: None). + enable_session_tag_chaining (bool or PipelineVariable): Optional. + Specifies whether SessionTagChaining is enabled for the training job Raises: ValueError: If the model ID is not recognized by JumpStart. @@ -584,6 +587,7 @@ def _validate_model_id_and_get_type_hook(): enable_infra_check=enable_infra_check, enable_remote_debug=enable_remote_debug, config_name=config_name, + enable_session_tag_chaining=enable_session_tag_chaining, ) self.model_id = estimator_init_kwargs.model_id diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 9177265d74..e171dcd99c 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -131,6 +131,7 @@ def get_init_kwargs( enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, config_name: Optional[str] = None, + enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, ) -> JumpStartEstimatorInitKwargs: """Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object.""" @@ -190,6 +191,7 @@ def get_init_kwargs( enable_infra_check=enable_infra_check, enable_remote_debug=enable_remote_debug, config_name=config_name, + enable_session_tag_chaining=enable_session_tag_chaining, ) estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs) diff --git a/src/sagemaker/jumpstart/session_utils.py b/src/sagemaker/jumpstart/session_utils.py index 7953b67913..0955ae9480 100644 --- a/src/sagemaker/jumpstart/session_utils.py +++ b/src/sagemaker/jumpstart/session_utils.py @@ -219,7 +219,7 @@ def get_model_info_from_training_job( model_id, inferred_model_version, inference_config_name, - trainig_config_name, + training_config_name, ) = get_jumpstart_model_info_from_resource_arn(training_job_arn, sagemaker_session) model_version = inferred_model_version or None @@ -231,4 +231,4 @@ def get_model_info_from_training_job( "for this training job." 
) - return model_id, model_version, inference_config_name, trainig_config_name + return model_id, model_version, inference_config_name, training_config_name diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index f85f23c361..ab2eeed7f0 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1078,7 +1078,7 @@ class JumpStartMetadataConfig(JumpStartDataHolderType): "resolved_metadata_config", "config_name", "default_inference_config", - "default_incremental_trainig_config", + "default_incremental_training_config", "supported_inference_configs", "supported_incremental_training_configs", ] @@ -1114,7 +1114,7 @@ def __init__( self.resolved_metadata_config: Optional[Dict[str, Any]] = None self.config_name: Optional[str] = config_name self.default_inference_config: Optional[str] = config.get("default_inference_config") - self.default_incremental_trainig_config: Optional[str] = config.get( + self.default_incremental_training_config: Optional[str] = config.get( "default_incremental_training_config" ) self.supported_inference_configs: Optional[List[str]] = config.get( @@ -1775,6 +1775,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "enable_infra_check", "enable_remote_debug", "config_name", + "enable_session_tag_chaining", ] SERIALIZATION_EXCLUSION_SET = { @@ -1844,6 +1845,7 @@ def __init__( enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, config_name: Optional[str] = None, + enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, ) -> None: """Instantiates JumpStartEstimatorInitKwargs object.""" @@ -1904,6 +1906,7 @@ def __init__( self.enable_infra_check = enable_infra_check self.enable_remote_debug = enable_remote_debug self.config_name = config_name + self.enable_session_tag_chaining = enable_session_tag_chaining class JumpStartEstimatorFitKwargs(JumpStartKwargs): diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index b7e5f16ae1..f8c9c81a38 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -1055,6 +1055,14 @@ def test_inference_configs_parsing(): ) assert list(config.config_components.keys()) == ["neuron-inference"] + spec = { + **BASE_SPEC, + **INFERENCE_CONFIGS, + **INFERENCE_CONFIG_RANKINGS, + "unrecognized-field": "blah", # New fields in base metadata fields should be ignored + } + specs1 = JumpStartModelSpecs(spec) + def test_set_inference_configs(): spec = {**BASE_SPEC, **INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS} From 9dc9b54e8ffbba36d1aa530bea14fbeaf7df2d52 Mon Sep 17 00:00:00 2001 From: Haotian An <33510317+Captainia@users.noreply.github.com> Date: Tue, 7 May 2024 09:14:58 -0400 Subject: [PATCH 31/32] Merge master into benchmark feature (#4652) --- src/sagemaker/jumpstart/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index ab2eeed7f0..7a51d075ae 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1097,7 +1097,7 @@ def __init__( config (Dict[str, Any]): Dictionary representation of the config. base_fields (Dict[str, Any]): - The default base fields that are used to construct the final resolved config. + The default base fields that are used to construct the resolved config. 
config_components (Dict[str, JumpStartConfigComponent]): The list of components that are used to construct the resolved config. """ From c6edf95372a60c9a5d1a20ed5ee5790f849da3f3 Mon Sep 17 00:00:00 2001 From: Haotian An <33510317+Captainia@users.noreply.github.com> Date: Tue, 7 May 2024 10:32:04 -0400 Subject: [PATCH 32/32] Merge master into master-benchmark-feature (#4656)