Skip to content

Commit e7979a5

Browse files
Captainiajiapinw
authored andcommitted
fix: list jumpstart models with invalid version strings (aws#4511)
* fix: list jumpstart models with invalid versions * docstyle * docstyle * pylint * add more test * fix
1 parent bcb2813 commit e7979a5

File tree

3 files changed

+34
-1
lines changed

3 files changed

+34
-1
lines changed

src/sagemaker/jumpstart/notebook_utils.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,15 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin
262262
return sorted(list(scripts))
263263

264264

265+
def _is_valid_version(version: str) -> bool:
266+
"""Checks if the version is convertable to Version class."""
267+
try:
268+
Version(version)
269+
return True
270+
except Exception: # pylint: disable=broad-except
271+
return False
272+
273+
265274
def list_jumpstart_models( # pylint: disable=redefined-builtin
266275
filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
267276
region: Optional[str] = None,
@@ -304,7 +313,8 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin
304313
):
305314
if model_id not in model_id_version_dict:
306315
model_id_version_dict[model_id] = list()
307-
model_id_version_dict[model_id].append(Version(version))
316+
model_version = Version(version) if _is_valid_version(version) else version
317+
model_id_version_dict[model_id].append(model_version)
308318

309319
if not list_versions:
310320
return sorted(list(model_id_version_dict.keys()))

tests/unit/sagemaker/jumpstart/constants.py

+14
Original file line numberDiff line numberDiff line change
@@ -7577,6 +7577,20 @@
75777577
"spec_key": "proprietary-models/ai21-paraphrase/proprietary_specs_1.0.005.json",
75787578
"search_keywords": ["Text2Text", "Generation"],
75797579
},
7580+
{
7581+
"model_id": "ai21-paraphrase",
7582+
"version": "v1.00-rc2-not-valid-version",
7583+
"min_version": "2.0.0",
7584+
"spec_key": "proprietary-models/ai21-paraphrase/proprietary_specs_1.0.005.json",
7585+
"search_keywords": ["Text2Text", "Generation"],
7586+
},
7587+
{
7588+
"model_id": "nc-soft-model-1",
7589+
"version": "v3.0-not-valid-version!",
7590+
"min_version": "2.0.0",
7591+
"spec_key": "proprietary-models/nc-soft-model-1/proprietary_specs_1.0.005.json",
7592+
"search_keywords": ["Text2Text", "Generation"],
7593+
},
75807594
]
75817595

75827596
BASE_PROPRIETARY_SPEC = {

tests/unit/sagemaker/jumpstart/test_notebook_utils.py

+9
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
list_jumpstart_models,
2626
list_jumpstart_scripts,
2727
list_jumpstart_tasks,
28+
_is_valid_version,
2829
)
2930

3031

@@ -185,6 +186,13 @@ def test_list_jumpstart_frameworks(
185186
patched_get_model_specs.assert_not_called()
186187

187188

189+
def test_is_valid_version():
190+
valid_version_strs = ["1.0", "1.0.0", "2012.4", "1!1.0", "1.dev0", "1.2.3+abc.dev1"]
191+
invalid_version_strs = ["1.1.053_m", "invalid version", "v1-1.0-v2", "@"]
192+
assert all(_is_valid_version(v) for v in valid_version_strs)
193+
assert not any(_is_valid_version(v) for v in invalid_version_strs)
194+
195+
188196
class ListJumpStartModels(TestCase):
189197
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
190198
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@@ -626,6 +634,7 @@ def test_list_jumpstart_proprietary_models(
626634
"ai21-paraphrase",
627635
"ai21-summarization",
628636
"lighton-mini-instruct40b",
637+
"nc-soft-model-1",
629638
]
630639

631640
all_open_weight_model_ids = [

0 commit comments

Comments
 (0)