Skip to content

Commit 133c61d

Browse files
e-davidsonEli Davidsonsage-maker
authored
fix: bug in get latest version was getting the max sorted alphabetically (#5014)
* fix: bug in get latest version was getting the max sorted alphabetically instead of sem-ver * handle invalid sev ver and incompatible sagemaker versions --------- Co-authored-by: Eli Davidson <[email protected]> Co-authored-by: parknate@ <[email protected]>
1 parent e7ce13c commit 133c61d

File tree

4 files changed

+152
-5
lines changed

4 files changed

+152
-5
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def _model_id_retrieval_function(
262262
return JumpStartVersionedModelId(model_id, sm_compatible_model_version)
263263

264264
versions_incompatible_with_sagemaker = [
265-
Version(header.version)
265+
header.version
266266
for header in manifest.values() # type: ignore
267267
if header.model_id == model_id
268268
]
@@ -540,9 +540,7 @@ def _select_version(
540540
"""
541541

542542
if version_str == "*":
543-
if len(available_versions) == 0:
544-
return None
545-
return str(max(available_versions))
543+
return utils.get_latest_version(available_versions)
546544

547545
if model_type == JumpStartModelType.PROPRIETARY:
548546
if "*" in version_str:

src/sagemaker/jumpstart/utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from urllib.parse import urlparse
2222
import boto3
2323
from botocore.exceptions import ClientError
24-
from packaging.version import Version
24+
from packaging.version import Version, InvalidVersion
2525
import botocore
2626
from sagemaker_core.shapes import ModelAccessConfig
2727
import sagemaker
@@ -1630,3 +1630,11 @@ def get_draft_model_content_bucket(provider: Dict, region: str) -> str:
16301630
return get_jumpstart_gated_content_bucket(region=region)
16311631
return get_jumpstart_content_bucket(region=region)
16321632
return neo_bucket
1633+
1634+
1635+
def get_latest_version(versions: List[str]) -> Optional[str]:
1636+
"""Returns the latest version using sem-ver when possible."""
1637+
try:
1638+
return None if not versions else max(versions, key=Version)
1639+
except InvalidVersion:
1640+
return max(versions)

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
from mock.mock import MagicMock
2323
import pytest
2424
from mock import patch
25+
from packaging.version import Version
2526

27+
28+
from sagemaker.jumpstart import utils
2629
from sagemaker.jumpstart.cache import (
2730
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
2831
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY,
@@ -33,6 +36,7 @@
3336
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
3437
)
3538
from sagemaker.jumpstart.types import (
39+
JumpStartCachedContentValue,
3640
JumpStartModelHeader,
3741
JumpStartModelSpecs,
3842
JumpStartVersionedModelId,
@@ -1119,3 +1123,124 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
11191123
),
11201124
]
11211125
)
1126+
1127+
1128+
@patch.object(JumpStartModelsCache, "_retrieval_function")
1129+
def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights(
1130+
retrieval_function: Mock,
1131+
):
1132+
sm_version = Version(utils.get_sagemaker_version())
1133+
new_sm_version = Version(str(sm_version.major + 1) + ".0.0")
1134+
print(str(new_sm_version))
1135+
versions = ["1.0.0", "2.9.1", "2.16.0"]
1136+
manifest = [
1137+
{
1138+
"model_id": "test-model",
1139+
"version": version,
1140+
"min_version": "2.49.0",
1141+
"spec_key": "spec_key",
1142+
}
1143+
for version in versions
1144+
]
1145+
1146+
manifest.append(
1147+
{
1148+
"model_id": "test-model",
1149+
"version": "3.0.0",
1150+
"min_version": str(new_sm_version),
1151+
"spec_key": "spec_key",
1152+
}
1153+
)
1154+
1155+
manifest_dict = {}
1156+
for header in manifest:
1157+
header_obj = JumpStartModelHeader(header)
1158+
manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = (
1159+
header_obj
1160+
)
1161+
retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict)
1162+
key = JumpStartVersionedModelId("test-model", "*")
1163+
1164+
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
1165+
result = cache._get_open_weight_manifest_key_from_model_id(key=key, value=None)
1166+
1167+
assert_key = JumpStartVersionedModelId("test-model", "2.16.0")
1168+
1169+
assert result == assert_key
1170+
1171+
1172+
@patch.object(JumpStartModelsCache, "_retrieval_function")
1173+
def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights(
1174+
retrieval_function: Mock,
1175+
):
1176+
sm_version = Version(utils.get_sagemaker_version())
1177+
new_sm_version = Version(str(sm_version.major + 1) + ".0.0")
1178+
print(str(new_sm_version))
1179+
versions = ["1.0.0", "2.9.1", "2.16.0"]
1180+
manifest = [
1181+
{
1182+
"model_id": "test-model",
1183+
"version": version,
1184+
"min_version": "2.49.0",
1185+
"spec_key": "spec_key",
1186+
}
1187+
for version in versions
1188+
]
1189+
1190+
manifest.append(
1191+
{
1192+
"model_id": "test-model",
1193+
"version": "3.0.0",
1194+
"min_version": str(new_sm_version),
1195+
"spec_key": "spec_key",
1196+
}
1197+
)
1198+
1199+
manifest_dict = {}
1200+
for header in manifest:
1201+
header_obj = JumpStartModelHeader(header)
1202+
manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = (
1203+
header_obj
1204+
)
1205+
retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict)
1206+
key = JumpStartVersionedModelId("test-model", "*")
1207+
1208+
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
1209+
result = cache._get_proprietary_manifest_key_from_model_id(key=key, value=None)
1210+
1211+
assert_key = JumpStartVersionedModelId("test-model", "2.16.0")
1212+
1213+
assert result == assert_key
1214+
1215+
1216+
@patch.object(JumpStartModelsCache, "_retrieval_function")
1217+
def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(retrieval_function: Mock):
1218+
sm_version = Version(utils.get_sagemaker_version())
1219+
new_sm_version = Version(str(sm_version.major + 1) + ".0.0")
1220+
print(str(new_sm_version))
1221+
versions = ["abc", "2.9.1", "2.16.0"]
1222+
manifest = [
1223+
{
1224+
"model_id": "test-model",
1225+
"version": version,
1226+
"min_version": "2.49.0",
1227+
"spec_key": "spec_key",
1228+
}
1229+
for version in versions
1230+
]
1231+
1232+
manifest_dict = {}
1233+
for header in manifest:
1234+
header_obj = JumpStartModelHeader(header)
1235+
manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = (
1236+
header_obj
1237+
)
1238+
retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict)
1239+
key = JumpStartVersionedModelId("test-model", "*")
1240+
1241+
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
1242+
result = cache._get_open_weight_manifest_key_from_model_id(key=key, value=None)
1243+
1244+
assert_key = JumpStartVersionedModelId("test-model", "abc")
1245+
1246+
assert result == assert_key

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2144,6 +2144,22 @@ def test_has_instance_rate_stat(stats, expected):
21442144
assert utils.has_instance_rate_stat(stats) is expected
21452145

21462146

2147+
def test_get_latest_version():
2148+
assert utils.get_latest_version(["2.9.1", "2.16.0", "1.0.0"]) == "2.16.0"
2149+
2150+
2151+
def test_get_latest_version_empty_list_is_none():
2152+
assert utils.get_latest_version([]) is None
2153+
2154+
2155+
def test_get_latest_version_none_is_none():
2156+
assert utils.get_latest_version(None) is None
2157+
2158+
2159+
def test_get_latest_version_with_invalid_sem_ver():
2160+
assert utils.get_latest_version(["2.9.1", "2.16.0", "1.0.0", "abc"]) == "abc"
2161+
2162+
21472163
@pytest.mark.parametrize(
21482164
"data, expected",
21492165
[(None, []), ([], []), (get_base_deployment_configs_metadata(), get_base_deployment_configs())],

0 commit comments

Comments
 (0)