diff --git a/src/sagemaker/jumpstart/artifacts/metric_definitions.py b/src/sagemaker/jumpstart/artifacts/metric_definitions.py index 0a9cfa00ae..b6f6019641 100644 --- a/src/sagemaker/jumpstart/artifacts/metric_definitions.py +++ b/src/sagemaker/jumpstart/artifacts/metric_definitions.py @@ -34,6 +34,7 @@ def _retrieve_default_training_metric_definitions( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + instance_type: Optional[str] = None, ) -> Optional[List[Dict[str, str]]]: """Retrieves the default training metric definitions for the model. @@ -55,6 +56,8 @@ def _retrieve_default_training_metric_definitions( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + instance_type (str): An instance type to optionally supply in order to get + metric definitions specific for the instance type. Returns: list: the default training metric definitions to use for the model or None. """ @@ -72,4 +75,29 @@ def _retrieve_default_training_metric_definitions( sagemaker_session=sagemaker_session, ) - return deepcopy(model_specs.metrics) if model_specs.metrics else None + default_metric_definitions = ( + deepcopy(model_specs.metrics) if getattr(model_specs, "metrics") else [] + ) + + instance_specific_metric_definitions = ( + model_specs.training_instance_type_variants.get_instance_specific_metric_definitions( + instance_type + ) + if instance_type + and getattr(model_specs, "training_instance_type_variants", None) is not None + else [] + ) + + instance_specific_metric_name: str + for instance_specific_metric_definition in instance_specific_metric_definitions: + instance_specific_metric_name = instance_specific_metric_definition["Name"] + default_metric_definitions = list( + filter( + lambda metric_definition: metric_definition["Name"] + != instance_specific_metric_name, + default_metric_definitions, + ) + ) + default_metric_definitions.append(instance_specific_metric_definition) + + return default_metric_definitions diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 1ee6bb71db..5b9ca76a85 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -633,6 +633,7 @@ def _add_metric_definitions_to_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + instance_type=kwargs.instance_type, ) or [] ) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index d5e0526a55..4f5a8489f0 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -403,6 +403,42 @@ def to_json(self) -> Dict[str, Any]: json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} return json_obj + def get_instance_specific_metric_definitions( + self, instance_type: str + ) -> List[JumpStartHyperparameter]: + """Returns instance specific metric definitions. + + Returns empty list if a model, instance type tuple does not have specific + metric definitions. + """ + + if self.variants is None: + return [] + + instance_specific_metric_definitions: List[Dict[str, Union[str, Any]]] = ( + self.variants.get(instance_type, {}).get("properties", {}).get("metrics", []) + ) + + instance_type_family = get_instance_type_family(instance_type) + + instance_family_metric_definitions: List[Dict[str, Union[str, Any]]] = ( + self.variants.get(instance_type_family, {}).get("properties", {}).get("metrics", []) + if instance_type_family not in {"", None} + else [] + ) + + instance_specific_metric_names = { + metric_definition["Name"] for metric_definition in instance_specific_metric_definitions + } + + metric_definitions_to_return = deepcopy(instance_specific_metric_definitions) + + for instance_family_metric_definition in instance_family_metric_definitions: + if instance_family_metric_definition["Name"] not in instance_specific_metric_names: + metric_definitions_to_return.append(instance_family_metric_definition) + + return metric_definitions_to_return + def get_instance_specific_prepacked_artifact_key(self, instance_type: str) -> Optional[str]: """Returns instance specific model artifact key. diff --git a/src/sagemaker/metric_definitions.py b/src/sagemaker/metric_definitions.py index 648c6e0cb4..71dd26db45 100644 --- a/src/sagemaker/metric_definitions.py +++ b/src/sagemaker/metric_definitions.py @@ -29,6 +29,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + instance_type: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -42,6 +43,8 @@ def retrieve_default( retrieve the default training metric definitions. (Default: None). model_version (str): The version of the model for which to retrieve the default training metric definitions. (Default: None). + instance_type (str): An instance type to optionally supply in order to get + metric definitions specific for the instance type. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -66,10 +69,11 @@ def retrieve_default( ) return artifacts._retrieve_default_training_metric_definitions( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + instance_type=instance_type, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, ) diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 1fe5eb1663..910a99f442 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -190,59 +190,8 @@ "framework_version": "1.5.0", "py_version": "py3", }, - "hosting_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", - "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", - "inf_model_package_arn": "us-west-2/blah/blah/blah/inf", - "gpu_model_package_arn": "us-west-2/blah/blah/blah/gpu", - } - }, - "variants": { - "p2": { - "regional_properties": { - "image_uri": "$gpu_image_uri", - "model_package_arn": "$gpu_model_package_arn", - } - }, - "p3": { - "regional_properties": { - "image_uri": "$gpu_image_uri", - "model_package_arn": "$gpu_model_package_arn", - } - }, - "p4": { - "regional_properties": { - "image_uri": "$gpu_image_uri", - "model_package_arn": "$gpu_model_package_arn", - } - }, - "g4dn": { - "regional_properties": { - "image_uri": "$gpu_image_uri", - "model_package_arn": "$gpu_model_package_arn", - } - }, - "m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, - "c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, - "ml.g5.48xlarge": { - "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}} - }, - "ml.g5.12xlarge": { - "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}} - }, - "inf1": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}}, - "inf2": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}}, - }, - }, - "training_ecr_specs": { - "framework": "pytorch", - "framework_version": "1.5.0", - "py_version": "py3", - }, "training_instance_type_variants": { + "regional_aliases": {}, "variants": { "ml.p2.12xlarge": { "properties": { @@ -305,9 +254,28 @@ "scope": "algorithm", }, ], + "metrics": [ + { + "Name": "huggingface-textgeneration:instance-typemetric-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'instance type specific': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:noneyourbusiness-loss", + "Regex": "'loss-noyb instance specific': ([0-9]+\\.[0-9]+)", + }, + ], } }, "p2": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_2"}, "properties": { "hyperparameters": [ { @@ -372,10 +340,80 @@ "default": "20", "scope": "container", }, - ] + ], + "metrics": [ + { + "Name": "huggingface-textgeneration:wtafigo", + "Regex": "'evasadfasdl_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'instance family specific': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:noneyourbusiness-loss", + "Regex": "'loss-noyb': ([0-9]+\\.[0-9]+)", + }, + ], + }, + }, + }, + }, + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", + "inf_model_package_arn": "us-west-2/blah/blah/blah/inf", + "gpu_model_package_arn": "us-west-2/blah/blah/blah/gpu", + } + }, + "variants": { + "p2": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + "model_package_arn": "$gpu_model_package_arn", } }, - } + "p3": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + "model_package_arn": "$gpu_model_package_arn", + } + }, + "p4": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + "model_package_arn": "$gpu_model_package_arn", + } + }, + "g4dn": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + "model_package_arn": "$gpu_model_package_arn", + } + }, + "m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, + "c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, + "ml.g5.48xlarge": { + "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}} + }, + "ml.g5.12xlarge": { + "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}} + }, + "inf1": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}}, + "inf2": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}}, + }, + }, + "training_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.5.0", + "py_version": "py3", }, "hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", @@ -515,7 +553,16 @@ "ml.c5.2xlarge", ], "hosting_use_script_uri": True, - "metrics": [{"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"}], + "metrics": [ + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'loss default': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeyyyuyuyuyneration:train-loss", + "Regex": "'loss default': ([0-9]+\\.[0-9]+)", + }, + ], "model_kwargs": {"some-model-kwarg-key": "some-model-kwarg-value"}, "deploy_kwargs": {"some-model-deploy-kwarg-key": "some-model-deploy-kwarg-value"}, "estimator_kwargs": { diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 5d42b47707..4aff263e96 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -32,7 +32,52 @@ } }, "variants": { - "p2": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, + "ml.p2.12xlarge": { + "properties": { + "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, + "metrics": [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:instance-typemetric-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'instance type specific': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:noneyourbusiness-loss", + "Regex": "'loss-noyb instance specific': ([0-9]+\\.[0-9]+)", + }, + ], + } + }, + "p2": { + "regional_properties": {"image_uri": "$gpu_image_uri"}, + "properties": { + "metrics": [ + { + "Name": "huggingface-textgeneration:wtafigo", + "Regex": "'evasadfasdl_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'instance family specific': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:noneyourbusiness-loss", + "Regex": "'loss-noyb': ([0-9]+\\.[0-9]+)", + }, + ] + }, + }, "p3": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, "ml.p3.200xlarge": {"regional_properties": {"image_uri": "$gpu_image_uri_2"}}, "p4": { @@ -652,6 +697,61 @@ def test_jumpstart_environment_variables_instance_variants(): ) +def test_jumpstart_metric_definitions_instance_variants(): + + metric_definitions = INSTANCE_TYPE_VARIANT.get_instance_specific_metric_definitions( + instance_type="ml.p2.2xlarge" + ) + assert metric_definitions == [ + { + "Name": "huggingface-textgeneration:wtafigo", + "Regex": "'evasadfasdl_loss': ([0-9]+\\.[0-9]+)", + }, + {"Name": "huggingface-textgeneration:eval-loss", "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)"}, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'instance family specific': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:noneyourbusiness-loss", + "Regex": "'loss-noyb': ([0-9]+\\.[0-9]+)", + }, + ] + + metric_definitions = INSTANCE_TYPE_VARIANT.get_instance_specific_metric_definitions( + instance_type="ml.p2.12xlarge" + ) + assert metric_definitions == [ + {"Name": "huggingface-textgeneration:eval-loss", "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)"}, + { + "Name": "huggingface-textgeneration:instance-typemetric-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'instance type specific': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:noneyourbusiness-loss", + "Regex": "'loss-noyb instance specific': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:wtafigo", + "Regex": "'evasadfasdl_loss': ([0-9]+\\.[0-9]+)", + }, + ] + + metric_definitions = INSTANCE_TYPE_VARIANT.get_instance_specific_metric_definitions( + instance_type="ml.g77.2xlarge" + ) + assert metric_definitions == [] + + metric_definitions = INSTANCE_TYPE_VARIANT.get_instance_specific_metric_definitions( + instance_type="ml.p3.2xlarge" + ) + assert metric_definitions == [] + + def test_jumpstart_hosting_prepacked_artifact_key_instance_variants(): assert ( INSTANCE_TYPE_VARIANT.get_instance_specific_prepacked_artifact_key( diff --git a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py index bea68dd713..1895acc95b 100644 --- a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py +++ b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py @@ -19,7 +19,10 @@ from sagemaker import metric_definitions -from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec +from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec + +mock_client = boto3.client("s3") +mock_session = Mock(s3_client=mock_client) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -27,9 +30,6 @@ def test_jumpstart_default_metric_definitions(patched_get_model_specs): patched_get_model_specs.side_effect = get_spec_from_base_spec - mock_client = boto3.client("s3") - mock_session = Mock(s3_client=mock_client) - model_id = "pytorch-ic-mobilenet-v2" region = "us-west-2" @@ -88,3 +88,93 @@ def test_jumpstart_default_metric_definitions(patched_get_model_specs): metric_definitions.retrieve_default( model_id=model_id, ) + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_sdk_metric_definitions_instance_type_overrides(patched_get_model_specs): + + patched_get_model_specs.side_effect = get_special_model_spec + + model_id = "variant-model" + region = "us-west-2" + + # assert that we can add metric definitions to default + metrics = metric_definitions.retrieve_default( + region=region, + model_id=model_id, + model_version="*", + sagemaker_session=mock_session, + instance_type="ml.p2.48xlarge", + ) + assert metrics == [ + { + "Name": "huggingface-textgeyyyuyuyuyneration:train-loss", + "Regex": "'loss default': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:wtafigo", + "Regex": "'evasadfasdl_loss': ([0-9]+\\.[0-9]+)", + }, + {"Name": "huggingface-textgeneration:eval-loss", "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)"}, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'instance family specific': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:noneyourbusiness-loss", + "Regex": "'loss-noyb': ([0-9]+\\.[0-9]+)", + }, + ] + + # assert that we can override default metric definitions (instance family + instance type + # specific) + metrics = metric_definitions.retrieve_default( + region=region, + model_id=model_id, + model_version="*", + sagemaker_session=mock_session, + instance_type="ml.p2.12xlarge", + ) + assert metrics == [ + { + "Name": "huggingface-textgeyyyuyuyuyneration:train-loss", + "Regex": "'loss default': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:instance-typemetric-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + {"Name": "huggingface-textgeneration:eval-loss", "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)"}, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'instance type specific': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:noneyourbusiness-loss", + "Regex": "'loss-noyb instance specific': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:wtafigo", + "Regex": "'evasadfasdl_loss': ([0-9]+\\.[0-9]+)", + }, + ] + + # assert that we can return default metric definitions for unrecognized instance + metrics = metric_definitions.retrieve_default( + region=region, + model_id=model_id, + model_version="*", + sagemaker_session=mock_session, + instance_type="ml.p9999.48xlarge", + ) + + assert metrics == [ + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'loss default': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeyyyuyuyuyneration:train-loss", + "Regex": "'loss default': ([0-9]+\\.[0-9]+)", + }, + ]