Skip to content

Commit c62c78e

Browse files
evakravinavinsonimufaddal-rohawalaqidewenwhenHappyAmazonian
committed
feat: jumpstart model id suggestions (#2899)
Co-authored-by: Navin Soni <[email protected]> Co-authored-by: Mufaddal Rohawala <[email protected]> Co-authored-by: qidewenwhen <[email protected]> Co-authored-by: HappyAmazonian <[email protected]>
1 parent 89ef9ee commit c62c78e

File tree

5 files changed

+123
-43
lines changed

5 files changed

+123
-43
lines changed

src/sagemaker/jumpstart/cache.py

+26-5
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""This module defines the JumpStartModelsCache class."""
1414
from __future__ import absolute_import
1515
import datetime
16+
from difflib import get_close_matches
1617
from typing import List, Optional
1718
import json
1819
import boto3
@@ -204,14 +205,34 @@ def _get_manifest_key_from_model_id_semantic_version(
204205
sm_version_to_use = sm_version_to_use_list[0]
205206

206207
error_msg = (
207-
f"Unable to find model manifest for {model_id} with version {version} "
208-
f"compatible with your SageMaker version ({sm_version}). "
208+
f"Unable to find model manifest for '{model_id}' with version '{version}' "
209+
f"compatible with your SageMaker version ('{sm_version}'). "
209210
f"Consider upgrading your SageMaker library to at least version "
210-
f"{sm_version_to_use} so you can use version "
211-
f"{model_version_to_use_incompatible_with_sagemaker} of {model_id}."
211+
f"'{sm_version_to_use}' so you can use version "
212+
f"'{model_version_to_use_incompatible_with_sagemaker}' of '{model_id}'."
212213
)
213214
raise KeyError(error_msg)
214-
error_msg = f"Unable to find model manifest for {model_id} with version {version}."
215+
216+
error_msg = f"Unable to find model manifest for '{model_id}' with version '{version}'. "
217+
error_msg += (
218+
"Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/jumpstart.html"
219+
" for updated list of models. "
220+
)
221+
222+
other_model_id_version = self._select_version(
223+
"*", versions_incompatible_with_sagemaker
224+
) # all versions here are incompatible with sagemaker
225+
if other_model_id_version is not None:
226+
error_msg += (
227+
f"Consider using model ID '{model_id}' with version "
228+
f"'{other_model_id_version}'."
229+
)
230+
231+
else:
232+
possible_model_ids = [header.model_id for header in manifest.values()]
233+
closest_model_id = get_close_matches(model_id, possible_model_ids, n=1, cutoff=0)[0]
234+
error_msg += f"Did you mean to use model ID '{closest_model_id}'?"
235+
215236
raise KeyError(error_msg)
216237

217238
def _get_file_from_s3(

src/sagemaker/jumpstart/types.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -135,12 +135,12 @@ def from_json(self, json_obj: Dict[str, str]) -> None:
135135
class JumpStartECRSpecs(JumpStartDataHolderType):
136136
"""Data class for JumpStart ECR specs."""
137137

138-
__slots__ = {
138+
__slots__ = [
139139
"framework",
140140
"framework_version",
141141
"py_version",
142142
"huggingface_transformers_version",
143-
}
143+
]
144144

145145
def __init__(self, spec: Dict[str, Any]):
146146
"""Initializes a JumpStartECRSpecs object from its json representation.
@@ -173,7 +173,7 @@ def to_json(self) -> Dict[str, Any]:
173173
class JumpStartHyperparameter(JumpStartDataHolderType):
174174
"""Data class for JumpStart hyperparameter definition in the training container."""
175175

176-
__slots__ = {
176+
__slots__ = [
177177
"name",
178178
"type",
179179
"options",
@@ -183,7 +183,7 @@ class JumpStartHyperparameter(JumpStartDataHolderType):
183183
"max",
184184
"exclusive_min",
185185
"exclusive_max",
186-
}
186+
]
187187

188188
def __init__(self, spec: Dict[str, Any]):
189189
"""Initializes a JumpStartHyperparameter object from its json representation.
@@ -234,12 +234,12 @@ def to_json(self) -> Dict[str, Any]:
234234
class JumpStartEnvironmentVariable(JumpStartDataHolderType):
235235
"""Data class for JumpStart environment variable definitions in the hosting container."""
236236

237-
__slots__ = {
237+
__slots__ = [
238238
"name",
239239
"type",
240240
"default",
241241
"scope",
242-
}
242+
]
243243

244244
def __init__(self, spec: Dict[str, Any]):
245245
"""Initializes a JumpStartEnvironmentVariable object from its json representation.

src/sagemaker/jumpstart/validators.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _validate_hyperparameter(
4949

5050
if len(hyperparameter_spec) > 1:
5151
raise JumpStartHyperparametersError(
52-
f"Unable to perform validation -- found multiple hyperparameter "
52+
"Unable to perform validation -- found multiple hyperparameter "
5353
f"'{hyperparameter_name}' in model specs."
5454
)
5555

@@ -76,35 +76,35 @@ def _validate_hyperparameter(
7676
if hyperparameter_value not in hyperparameter_spec.options:
7777
raise JumpStartHyperparametersError(
7878
f"Hyperparameter '{hyperparameter_name}' must have one of the following "
79-
f"values: {', '.join(hyperparameter_spec.options)}"
79+
f"values: {', '.join(hyperparameter_spec.options)}."
8080
)
8181

8282
if hasattr(hyperparameter_spec, "min"):
8383
if len(hyperparameter_value) < hyperparameter_spec.min:
8484
raise JumpStartHyperparametersError(
8585
f"Hyperparameter '{hyperparameter_name}' must have length no less than "
86-
f"{hyperparameter_spec.min}"
86+
f"{hyperparameter_spec.min}."
8787
)
8888

8989
if hasattr(hyperparameter_spec, "exclusive_min"):
9090
if len(hyperparameter_value) <= hyperparameter_spec.exclusive_min:
9191
raise JumpStartHyperparametersError(
9292
f"Hyperparameter '{hyperparameter_name}' must have length greater than "
93-
f"{hyperparameter_spec.exclusive_min}"
93+
f"{hyperparameter_spec.exclusive_min}."
9494
)
9595

9696
if hasattr(hyperparameter_spec, "max"):
9797
if len(hyperparameter_value) > hyperparameter_spec.max:
9898
raise JumpStartHyperparametersError(
9999
f"Hyperparameter '{hyperparameter_name}' must have length no greater than "
100-
f"{hyperparameter_spec.max}"
100+
f"{hyperparameter_spec.max}."
101101
)
102102

103103
if hasattr(hyperparameter_spec, "exclusive_max"):
104104
if len(hyperparameter_value) >= hyperparameter_spec.exclusive_max:
105105
raise JumpStartHyperparametersError(
106106
f"Hyperparameter '{hyperparameter_name}' must have length less than "
107-
f"{hyperparameter_spec.exclusive_max}"
107+
f"{hyperparameter_spec.exclusive_max}."
108108
)
109109

110110
# validate numeric types
@@ -125,35 +125,35 @@ def _validate_hyperparameter(
125125
if not hyperparameter_value_str[start_index:].isdigit():
126126
raise JumpStartHyperparametersError(
127127
f"Hyperparameter '{hyperparameter_name}' must be integer type "
128-
"('{hyperparameter_value}')."
128+
f"('{hyperparameter_value}')."
129129
)
130130

131131
if hasattr(hyperparameter_spec, "min"):
132132
if numeric_hyperparam_value < hyperparameter_spec.min:
133133
raise JumpStartHyperparametersError(
134134
f"Hyperparameter '{hyperparameter_name}' can be no less than "
135-
"{hyperparameter_spec.min}."
135+
f"{hyperparameter_spec.min}."
136136
)
137137

138138
if hasattr(hyperparameter_spec, "max"):
139139
if numeric_hyperparam_value > hyperparameter_spec.max:
140140
raise JumpStartHyperparametersError(
141141
f"Hyperparameter '{hyperparameter_name}' can be no greater than "
142-
"{hyperparameter_spec.max}."
142+
f"{hyperparameter_spec.max}."
143143
)
144144

145145
if hasattr(hyperparameter_spec, "exclusive_min"):
146146
if numeric_hyperparam_value <= hyperparameter_spec.exclusive_min:
147147
raise JumpStartHyperparametersError(
148148
f"Hyperparameter '{hyperparameter_name}' must be greater than "
149-
"{hyperparameter_spec.exclusive_min}."
149+
f"{hyperparameter_spec.exclusive_min}."
150150
)
151151

152152
if hasattr(hyperparameter_spec, "exclusive_max"):
153153
if numeric_hyperparam_value >= hyperparameter_spec.exclusive_max:
154154
raise JumpStartHyperparametersError(
155155
f"Hyperparameter '{hyperparameter_name}' must be less than "
156-
"{hyperparameter_spec.exclusive_max}."
156+
f"{hyperparameter_spec.exclusive_max}."
157157
)
158158

159159

0 commit comments

Comments
 (0)