File tree 2 files changed +38
-2
lines changed
tests/unit/sagemaker/jumpstart
2 files changed +38
-2
lines changed Original file line number Diff line number Diff line change 13
13
"""This module defines the JumpStartModelsCache class."""
14
14
from __future__ import absolute_import
15
15
import datetime
16
+ from difflib import get_close_matches
16
17
from typing import List , Optional
17
18
import json
18
19
import boto3
@@ -211,7 +212,23 @@ def _get_manifest_key_from_model_id_semantic_version(
211
212
f"{ model_version_to_use_incompatible_with_sagemaker } of { model_id } ."
212
213
)
213
214
raise KeyError (error_msg )
214
- error_msg = f"Unable to find model manifest for { model_id } with version { version } ."
215
+
216
+ error_msg = f"Unable to find model manifest for { model_id } with version { version } . "
217
+
218
+ other_model_id_version = self ._select_version (
219
+ "*" , versions_incompatible_with_sagemaker
220
+ )
221
+ if other_model_id_version is not None :
222
+ error_msg += (
223
+ f"Consider using model id { model_id } with version "
224
+ f"{ other_model_id_version } ."
225
+ )
226
+
227
+ else :
228
+ possible_model_ids = [header .model_id for header in manifest .values ()]
229
+ closest_model_id = get_close_matches (model_id , possible_model_ids , n = 1 , cutoff = 0 )[0 ]
230
+ error_msg += f"Did you mean to use model id { closest_model_id } ?"
231
+
215
232
raise KeyError (error_msg )
216
233
217
234
def _get_file_from_s3 (
Original file line number Diff line number Diff line change @@ -161,7 +161,26 @@ def test_jumpstart_cache_get_header():
161
161
cache .get_header (
162
162
model_id = "pytorch-ic-imagenet-inception-v3-classification-4" , semantic_version_str = "3.*"
163
163
)
164
- assert "Consider upgrading" not in str (e .value )
164
+ assert (
165
+ "Unable to find model manifest for pytorch-ic-imagenet-inception-v3-classification-4 with "
166
+ "version 3.*. Consider using model id pytorch-ic-imagenet-inception-v3-classification-4 with "
167
+ "version 2.0.0."
168
+ ) in str (e .value )
169
+
170
+ with pytest .raises (KeyError ) as e :
171
+ cache .get_header (model_id = "pytorch-ic-" , semantic_version_str = "*" )
172
+ assert (
173
+ "Unable to find model manifest for pytorch-ic- with version *. "
174
+ "Did you mean to use model id pytorch-ic-imagenet-inception-v3-classification-4?"
175
+ ) in str (e .value )
176
+
177
+ with pytest .raises (KeyError ) as e :
178
+ cache .get_header (model_id = "tensorflow-ic-" , semantic_version_str = "*" )
179
+ assert (
180
+ "Unable to find model manifest for tensorflow-ic- with version *. "
181
+ "Did you mean to use model id tensorflow-ic-imagenet-inception-"
182
+ "v3-classification-4?"
183
+ ) in str (e .value )
165
184
166
185
with pytest .raises (KeyError ):
167
186
cache .get_header (
You can’t perform that action at this time.
0 commit comments