Skip to content

Commit 049a57c

Browse files
committed
feat: jumpstart model id suggestions
1 parent e3398d9 commit 049a57c

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

src/sagemaker/jumpstart/cache.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""This module defines the JumpStartModelsCache class."""
1414
from __future__ import absolute_import
1515
import datetime
16+
from difflib import get_close_matches
1617
from typing import List, Optional
1718
import json
1819
import boto3
@@ -211,7 +212,23 @@ def _get_manifest_key_from_model_id_semantic_version(
211212
f"{model_version_to_use_incompatible_with_sagemaker} of {model_id}."
212213
)
213214
raise KeyError(error_msg)
214-
error_msg = f"Unable to find model manifest for {model_id} with version {version}."
215+
216+
error_msg = f"Unable to find model manifest for {model_id} with version {version}. "
217+
218+
other_model_id_version = self._select_version(
219+
"*", versions_incompatible_with_sagemaker
220+
)
221+
if other_model_id_version is not None:
222+
error_msg += (
223+
f"Consider using model id {model_id} with version "
224+
f"{other_model_id_version}."
225+
)
226+
227+
else:
228+
possible_model_ids = [header.model_id for header in manifest.values()]
229+
closest_model_id = get_close_matches(model_id, possible_model_ids, n=1, cutoff=0)[0]
230+
error_msg += f"Did you mean to use model id {closest_model_id}?"
231+
215232
raise KeyError(error_msg)
216233

217234
def _get_file_from_s3(

tests/unit/sagemaker/jumpstart/test_cache.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,26 @@ def test_jumpstart_cache_get_header():
161161
cache.get_header(
162162
model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="3.*"
163163
)
164-
assert "Consider upgrading" not in str(e.value)
164+
assert (
165+
"Unable to find model manifest for pytorch-ic-imagenet-inception-v3-classification-4 with "
166+
"version 3.*. Consider using model id pytorch-ic-imagenet-inception-v3-classification-4 with "
167+
"version 2.0.0."
168+
) in str(e.value)
169+
170+
with pytest.raises(KeyError) as e:
171+
cache.get_header(model_id="pytorch-ic-", semantic_version_str="*")
172+
assert (
173+
"Unable to find model manifest for pytorch-ic- with version *. "
174+
"Did you mean to use model id pytorch-ic-imagenet-inception-v3-classification-4?"
175+
) in str(e.value)
176+
177+
with pytest.raises(KeyError) as e:
178+
cache.get_header(model_id="tensorflow-ic-", semantic_version_str="*")
179+
assert (
180+
"Unable to find model manifest for tensorflow-ic- with version *. "
181+
"Did you mean to use model id tensorflow-ic-imagenet-inception-"
182+
"v3-classification-4?"
183+
) in str(e.value)
165184

166185
with pytest.raises(KeyError):
167186
cache.get_header(

0 commit comments

Comments
 (0)