Skip to content

feat: jumpstart model id suggestions #2899

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Feb 18, 2022
31 changes: 26 additions & 5 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"""This module defines the JumpStartModelsCache class."""
from __future__ import absolute_import
import datetime
from difflib import get_close_matches
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice find!

from typing import List, Optional
import json
import boto3
Expand Down Expand Up @@ -204,14 +205,34 @@ def _get_manifest_key_from_model_id_semantic_version(
sm_version_to_use = sm_version_to_use_list[0]

error_msg = (
f"Unable to find model manifest for {model_id} with version {version} "
f"compatible with your SageMaker version ({sm_version}). "
f"Unable to find model manifest for '{model_id}' with version '{version}' "
f"compatible with your SageMaker version ('{sm_version}'). "
f"Consider upgrading your SageMaker library to at least version "
f"{sm_version_to_use} so you can use version "
f"{model_version_to_use_incompatible_with_sagemaker} of {model_id}."
f"'{sm_version_to_use}' so you can use version "
f"'{model_version_to_use_incompatible_with_sagemaker}' of '{model_id}'."
)
raise KeyError(error_msg)
error_msg = f"Unable to find model manifest for {model_id} with version {version}."

error_msg = f"Unable to find model manifest for '{model_id}' with version '{version}'. "
error_msg += (
"Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/jumpstart.html"
" for updated list of models. "
)

other_model_id_version = self._select_version(
"*", versions_incompatible_with_sagemaker
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you please add a comment as to why you use this variable.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reiterate: please explain why we use a variable called versions_incompatitble_with_sagemaker here.

) # all versions here are incompatible with sagemaker
if other_model_id_version is not None:
error_msg += (
f"Consider using model ID '{model_id}' with version "
f"'{other_model_id_version}'."
)

else:
possible_model_ids = [header.model_id for header in manifest.values()]
closest_model_id = get_close_matches(model_id, possible_model_ids, n=1, cutoff=0)[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: should you show only one? or say for example: at least 1, up to 3 when score > xx ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's just do 1 and make it simple. If we want to make this really good, we're better off just having a separate utility for searching model ids.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

another question, can you get an IndexError here if get_close_matches returns an empty list?

If that's possible, could you handle this edge case and add a unit test for it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, because the cutoff is 0, there will always be a match.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I know this is super edge case, but could there be a case where possible_model_ids is an empty list?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only if the manifest is empty

error_msg += f"Did you mean to use model ID '{closest_model_id}'?"

raise KeyError(error_msg)

def _get_file_from_s3(
Expand Down
12 changes: 6 additions & 6 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,12 @@ def from_json(self, json_obj: Dict[str, str]) -> None:
class JumpStartECRSpecs(JumpStartDataHolderType):
"""Data class for JumpStart ECR specs."""

__slots__ = {
__slots__ = [
"framework",
"framework_version",
"py_version",
"huggingface_transformers_version",
}
]

def __init__(self, spec: Dict[str, Any]):
"""Initializes a JumpStartECRSpecs object from its json representation.
Expand Down Expand Up @@ -173,7 +173,7 @@ def to_json(self) -> Dict[str, Any]:
class JumpStartHyperparameter(JumpStartDataHolderType):
"""Data class for JumpStart hyperparameter definition in the training container."""

__slots__ = {
__slots__ = [
"name",
"type",
"options",
Expand All @@ -183,7 +183,7 @@ class JumpStartHyperparameter(JumpStartDataHolderType):
"max",
"exclusive_min",
"exclusive_max",
}
]

def __init__(self, spec: Dict[str, Any]):
"""Initializes a JumpStartHyperparameter object from its json representation.
Expand Down Expand Up @@ -234,12 +234,12 @@ def to_json(self) -> Dict[str, Any]:
class JumpStartEnvironmentVariable(JumpStartDataHolderType):
"""Data class for JumpStart environment variable definitions in the hosting container."""

__slots__ = {
__slots__ = [
"name",
"type",
"default",
"scope",
}
]

def __init__(self, spec: Dict[str, Any]):
"""Initializes a JumpStartEnvironmentVariable object from its json representation.
Expand Down
22 changes: 11 additions & 11 deletions src/sagemaker/jumpstart/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _validate_hyperparameter(

if len(hyperparameter_spec) > 1:
raise JumpStartHyperparametersError(
f"Unable to perform validation -- found multiple hyperparameter "
"Unable to perform validation -- found multiple hyperparameter "
f"'{hyperparameter_name}' in model specs."
)

Expand All @@ -76,35 +76,35 @@ def _validate_hyperparameter(
if hyperparameter_value not in hyperparameter_spec.options:
raise JumpStartHyperparametersError(
f"Hyperparameter '{hyperparameter_name}' must have one of the following "
f"values: {', '.join(hyperparameter_spec.options)}"
f"values: {', '.join(hyperparameter_spec.options)}."
)

if hasattr(hyperparameter_spec, "min"):
if len(hyperparameter_value) < hyperparameter_spec.min:
raise JumpStartHyperparametersError(
f"Hyperparameter '{hyperparameter_name}' must have length no less than "
f"{hyperparameter_spec.min}"
f"{hyperparameter_spec.min}."
)

if hasattr(hyperparameter_spec, "exclusive_min"):
if len(hyperparameter_value) <= hyperparameter_spec.exclusive_min:
raise JumpStartHyperparametersError(
f"Hyperparameter '{hyperparameter_name}' must have length greater than "
f"{hyperparameter_spec.exclusive_min}"
f"{hyperparameter_spec.exclusive_min}."
)

if hasattr(hyperparameter_spec, "max"):
if len(hyperparameter_value) > hyperparameter_spec.max:
raise JumpStartHyperparametersError(
f"Hyperparameter '{hyperparameter_name}' must have length no greater than "
f"{hyperparameter_spec.max}"
f"{hyperparameter_spec.max}."
)

if hasattr(hyperparameter_spec, "exclusive_max"):
if len(hyperparameter_value) >= hyperparameter_spec.exclusive_max:
raise JumpStartHyperparametersError(
f"Hyperparameter '{hyperparameter_name}' must have length less than "
f"{hyperparameter_spec.exclusive_max}"
f"{hyperparameter_spec.exclusive_max}."
)

# validate numeric types
Expand All @@ -125,35 +125,35 @@ def _validate_hyperparameter(
if not hyperparameter_value_str[start_index:].isdigit():
raise JumpStartHyperparametersError(
f"Hyperparameter '{hyperparameter_name}' must be integer type "
"('{hyperparameter_value}')."
f"('{hyperparameter_value}')."
)

if hasattr(hyperparameter_spec, "min"):
if numeric_hyperparam_value < hyperparameter_spec.min:
raise JumpStartHyperparametersError(
f"Hyperparameter '{hyperparameter_name}' can be no less than "
"{hyperparameter_spec.min}."
f"{hyperparameter_spec.min}."
)

if hasattr(hyperparameter_spec, "max"):
if numeric_hyperparam_value > hyperparameter_spec.max:
raise JumpStartHyperparametersError(
f"Hyperparameter '{hyperparameter_name}' can be no greater than "
"{hyperparameter_spec.max}."
f"{hyperparameter_spec.max}."
)

if hasattr(hyperparameter_spec, "exclusive_min"):
if numeric_hyperparam_value <= hyperparameter_spec.exclusive_min:
raise JumpStartHyperparametersError(
f"Hyperparameter '{hyperparameter_name}' must be greater than "
"{hyperparameter_spec.exclusive_min}."
f"{hyperparameter_spec.exclusive_min}."
)

if hasattr(hyperparameter_spec, "exclusive_max"):
if numeric_hyperparam_value >= hyperparameter_spec.exclusive_max:
raise JumpStartHyperparametersError(
f"Hyperparameter '{hyperparameter_name}' must be less than "
"{hyperparameter_spec.exclusive_max}."
f"{hyperparameter_spec.exclusive_max}."
)


Expand Down
Loading