Skip to content

feat: jumpstart instance specific metric definitions #4200

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion src/sagemaker/jumpstart/artifacts/metric_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def _retrieve_default_training_metric_definitions(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
) -> Optional[List[Dict[str, str]]]:
"""Retrieves the default training metric definitions for the model.

Expand All @@ -55,6 +56,8 @@ def _retrieve_default_training_metric_definitions(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
instance_type (str): An instance type to optionally supply in order to get
metric definitions specific for the instance type.
Returns:
list: the default training metric definitions to use for the model or None.
"""
Expand All @@ -72,4 +75,29 @@ def _retrieve_default_training_metric_definitions(
sagemaker_session=sagemaker_session,
)

return deepcopy(model_specs.metrics) if model_specs.metrics else None
default_metric_definitions = (
deepcopy(model_specs.metrics) if getattr(model_specs, "metrics") else []
)

instance_specific_metric_definitions = (
model_specs.training_instance_type_variants.get_instance_specific_metric_definitions(
instance_type
)
if instance_type
and getattr(model_specs, "training_instance_type_variants", None) is not None
else []
)

instance_specific_metric_name: str
for instance_specific_metric_definition in instance_specific_metric_definitions:
instance_specific_metric_name = instance_specific_metric_definition["Name"]
default_metric_definitions = list(
filter(
lambda metric_definition: metric_definition["Name"]
!= instance_specific_metric_name,
default_metric_definitions,
)
)
default_metric_definitions.append(instance_specific_metric_definition)

return default_metric_definitions
1 change: 1 addition & 0 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,7 @@ def _add_metric_definitions_to_kwargs(
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
instance_type=kwargs.instance_type,
)
or []
)
Expand Down
36 changes: 36 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,42 @@ def to_json(self) -> Dict[str, Any]:
json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
return json_obj

def get_instance_specific_metric_definitions(
self, instance_type: str
) -> List[JumpStartHyperparameter]:
"""Returns instance specific metric definitions.

Returns empty list if a model, instance type tuple does not have specific
metric definitions.
"""

if self.variants is None:
return []

instance_specific_metric_definitions: List[Dict[str, Union[str, Any]]] = (
self.variants.get(instance_type, {}).get("properties", {}).get("metrics", [])
)

instance_type_family = get_instance_type_family(instance_type)

instance_family_metric_definitions: List[Dict[str, Union[str, Any]]] = (
self.variants.get(instance_type_family, {}).get("properties", {}).get("metrics", [])
if instance_type_family not in {"", None}
else []
)

instance_specific_metric_names = {
metric_definition["Name"] for metric_definition in instance_specific_metric_definitions
}

metric_definitions_to_return = deepcopy(instance_specific_metric_definitions)

for instance_family_metric_definition in instance_family_metric_definitions:
if instance_family_metric_definition["Name"] not in instance_specific_metric_names:
metric_definitions_to_return.append(instance_family_metric_definition)

return metric_definitions_to_return

def get_instance_specific_prepacked_artifact_key(self, instance_type: str) -> Optional[str]:
"""Returns instance specific model artifact key.

Expand Down
14 changes: 9 additions & 5 deletions src/sagemaker/metric_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def retrieve_default(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
instance_type: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
Expand All @@ -42,6 +43,8 @@ def retrieve_default(
retrieve the default training metric definitions. (Default: None).
model_version (str): The version of the model for which to retrieve the
default training metric definitions. (Default: None).
instance_type (str): An instance type to optionally supply in order to get
metric definitions specific for the instance type.
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
Expand All @@ -66,10 +69,11 @@ def retrieve_default(
)

return artifacts._retrieve_default_training_metric_definitions(
model_id,
model_version,
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
model_id=model_id,
model_version=model_version,
instance_type=instance_type,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
)
157 changes: 102 additions & 55 deletions tests/unit/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,59 +190,8 @@
"framework_version": "1.5.0",
"py_version": "py3",
},
"hosting_instance_type_variants": {
"regional_aliases": {
"us-west-2": {
"gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/"
"huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04",
"cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah",
"inf_model_package_arn": "us-west-2/blah/blah/blah/inf",
"gpu_model_package_arn": "us-west-2/blah/blah/blah/gpu",
}
},
"variants": {
"p2": {
"regional_properties": {
"image_uri": "$gpu_image_uri",
"model_package_arn": "$gpu_model_package_arn",
}
},
"p3": {
"regional_properties": {
"image_uri": "$gpu_image_uri",
"model_package_arn": "$gpu_model_package_arn",
}
},
"p4": {
"regional_properties": {
"image_uri": "$gpu_image_uri",
"model_package_arn": "$gpu_model_package_arn",
}
},
"g4dn": {
"regional_properties": {
"image_uri": "$gpu_image_uri",
"model_package_arn": "$gpu_model_package_arn",
}
},
"m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
"c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
"ml.g5.48xlarge": {
"properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}}
},
"ml.g5.12xlarge": {
"properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}}
},
"inf1": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}},
"inf2": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}},
},
},
"training_ecr_specs": {
"framework": "pytorch",
"framework_version": "1.5.0",
"py_version": "py3",
},
"training_instance_type_variants": {
"regional_aliases": {},
"variants": {
"ml.p2.12xlarge": {
"properties": {
Expand Down Expand Up @@ -305,9 +254,28 @@
"scope": "algorithm",
},
],
"metrics": [
{
"Name": "huggingface-textgeneration:instance-typemetric-loss",
"Regex": "'eval_loss': ([0-9]+\\.[0-9]+)",
},
{
"Name": "huggingface-textgeneration:eval-loss",
"Regex": "'eval_loss': ([0-9]+\\.[0-9]+)",
},
{
"Name": "huggingface-textgeneration:train-loss",
"Regex": "'instance type specific': ([0-9]+\\.[0-9]+)",
},
{
"Name": "huggingface-textgeneration:noneyourbusiness-loss",
"Regex": "'loss-noyb instance specific': ([0-9]+\\.[0-9]+)",
},
],
}
},
"p2": {
"regional_properties": {"image_uri": "$gpu_ecr_uri_2"},
"properties": {
"hyperparameters": [
{
Expand Down Expand Up @@ -372,10 +340,80 @@
"default": "20",
"scope": "container",
},
]
],
"metrics": [
{
"Name": "huggingface-textgeneration:wtafigo",
"Regex": "'evasadfasdl_loss': ([0-9]+\\.[0-9]+)",
},
{
"Name": "huggingface-textgeneration:eval-loss",
"Regex": "'eval_loss': ([0-9]+\\.[0-9]+)",
},
{
"Name": "huggingface-textgeneration:train-loss",
"Regex": "'instance family specific': ([0-9]+\\.[0-9]+)",
},
{
"Name": "huggingface-textgeneration:noneyourbusiness-loss",
"Regex": "'loss-noyb': ([0-9]+\\.[0-9]+)",
},
],
},
},
},
},
"hosting_instance_type_variants": {
"regional_aliases": {
"us-west-2": {
"gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/"
"huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04",
"cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah",
"inf_model_package_arn": "us-west-2/blah/blah/blah/inf",
"gpu_model_package_arn": "us-west-2/blah/blah/blah/gpu",
}
},
"variants": {
"p2": {
"regional_properties": {
"image_uri": "$gpu_image_uri",
"model_package_arn": "$gpu_model_package_arn",
}
},
}
"p3": {
"regional_properties": {
"image_uri": "$gpu_image_uri",
"model_package_arn": "$gpu_model_package_arn",
}
},
"p4": {
"regional_properties": {
"image_uri": "$gpu_image_uri",
"model_package_arn": "$gpu_model_package_arn",
}
},
"g4dn": {
"regional_properties": {
"image_uri": "$gpu_image_uri",
"model_package_arn": "$gpu_model_package_arn",
}
},
"m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
"c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
"ml.g5.48xlarge": {
"properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}}
},
"ml.g5.12xlarge": {
"properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}}
},
"inf1": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}},
"inf2": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}},
},
},
"training_ecr_specs": {
"framework": "pytorch",
"framework_version": "1.5.0",
"py_version": "py3",
},
"hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz",
"training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz",
Expand Down Expand Up @@ -515,7 +553,16 @@
"ml.c5.2xlarge",
],
"hosting_use_script_uri": True,
"metrics": [{"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"}],
"metrics": [
{
"Name": "huggingface-textgeneration:train-loss",
"Regex": "'loss default': ([0-9]+\\.[0-9]+)",
},
{
"Name": "huggingface-textgeyyyuyuyuyneration:train-loss",
"Regex": "'loss default': ([0-9]+\\.[0-9]+)",
},
],
"model_kwargs": {"some-model-kwarg-key": "some-model-kwarg-value"},
"deploy_kwargs": {"some-model-deploy-kwarg-key": "some-model-deploy-kwarg-value"},
"estimator_kwargs": {
Expand Down
Loading