|
22 | 22 | from mock.mock import MagicMock
|
23 | 23 | import pytest
|
24 | 24 | from mock import patch
|
| 25 | +from packaging.version import Version |
25 | 26 |
|
| 27 | + |
| 28 | +from sagemaker.jumpstart import utils |
26 | 29 | from sagemaker.jumpstart.cache import (
|
27 | 30 | JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
|
28 | 31 | JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY,
|
|
33 | 36 | ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
|
34 | 37 | )
|
35 | 38 | from sagemaker.jumpstart.types import (
|
| 39 | + JumpStartCachedContentValue, |
36 | 40 | JumpStartModelHeader,
|
37 | 41 | JumpStartModelSpecs,
|
38 | 42 | JumpStartVersionedModelId,
|
@@ -1119,3 +1123,124 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
|
1119 | 1123 | ),
|
1120 | 1124 | ]
|
1121 | 1125 | )
|
| 1126 | + |
| 1127 | + |
| 1128 | +@patch.object(JumpStartModelsCache, "_retrieval_function") |
| 1129 | +def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights( |
| 1130 | + retrieval_function: Mock, |
| 1131 | +): |
| 1132 | + sm_version = Version(utils.get_sagemaker_version()) |
| 1133 | + new_sm_version = Version(str(sm_version.major + 1) + ".0.0") |
| 1134 | + print(str(new_sm_version)) |
| 1135 | + versions = ["1.0.0", "2.9.1", "2.16.0"] |
| 1136 | + manifest = [ |
| 1137 | + { |
| 1138 | + "model_id": "test-model", |
| 1139 | + "version": version, |
| 1140 | + "min_version": "2.49.0", |
| 1141 | + "spec_key": "spec_key", |
| 1142 | + } |
| 1143 | + for version in versions |
| 1144 | + ] |
| 1145 | + |
| 1146 | + manifest.append( |
| 1147 | + { |
| 1148 | + "model_id": "test-model", |
| 1149 | + "version": "3.0.0", |
| 1150 | + "min_version": str(new_sm_version), |
| 1151 | + "spec_key": "spec_key", |
| 1152 | + } |
| 1153 | + ) |
| 1154 | + |
| 1155 | + manifest_dict = {} |
| 1156 | + for header in manifest: |
| 1157 | + header_obj = JumpStartModelHeader(header) |
| 1158 | + manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = ( |
| 1159 | + header_obj |
| 1160 | + ) |
| 1161 | + retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict) |
| 1162 | + key = JumpStartVersionedModelId("test-model", "*") |
| 1163 | + |
| 1164 | + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") |
| 1165 | + result = cache._get_open_weight_manifest_key_from_model_id(key=key, value=None) |
| 1166 | + |
| 1167 | + assert_key = JumpStartVersionedModelId("test-model", "2.16.0") |
| 1168 | + |
| 1169 | + assert result == assert_key |
| 1170 | + |
| 1171 | + |
| 1172 | +@patch.object(JumpStartModelsCache, "_retrieval_function") |
| 1173 | +def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights( |
| 1174 | + retrieval_function: Mock, |
| 1175 | +): |
| 1176 | + sm_version = Version(utils.get_sagemaker_version()) |
| 1177 | + new_sm_version = Version(str(sm_version.major + 1) + ".0.0") |
| 1178 | + print(str(new_sm_version)) |
| 1179 | + versions = ["1.0.0", "2.9.1", "2.16.0"] |
| 1180 | + manifest = [ |
| 1181 | + { |
| 1182 | + "model_id": "test-model", |
| 1183 | + "version": version, |
| 1184 | + "min_version": "2.49.0", |
| 1185 | + "spec_key": "spec_key", |
| 1186 | + } |
| 1187 | + for version in versions |
| 1188 | + ] |
| 1189 | + |
| 1190 | + manifest.append( |
| 1191 | + { |
| 1192 | + "model_id": "test-model", |
| 1193 | + "version": "3.0.0", |
| 1194 | + "min_version": str(new_sm_version), |
| 1195 | + "spec_key": "spec_key", |
| 1196 | + } |
| 1197 | + ) |
| 1198 | + |
| 1199 | + manifest_dict = {} |
| 1200 | + for header in manifest: |
| 1201 | + header_obj = JumpStartModelHeader(header) |
| 1202 | + manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = ( |
| 1203 | + header_obj |
| 1204 | + ) |
| 1205 | + retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict) |
| 1206 | + key = JumpStartVersionedModelId("test-model", "*") |
| 1207 | + |
| 1208 | + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") |
| 1209 | + result = cache._get_proprietary_manifest_key_from_model_id(key=key, value=None) |
| 1210 | + |
| 1211 | + assert_key = JumpStartVersionedModelId("test-model", "2.16.0") |
| 1212 | + |
| 1213 | + assert result == assert_key |
| 1214 | + |
| 1215 | + |
| 1216 | +@patch.object(JumpStartModelsCache, "_retrieval_function") |
| 1217 | +def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(retrieval_function: Mock): |
| 1218 | + sm_version = Version(utils.get_sagemaker_version()) |
| 1219 | + new_sm_version = Version(str(sm_version.major + 1) + ".0.0") |
| 1220 | + print(str(new_sm_version)) |
| 1221 | + versions = ["abc", "2.9.1", "2.16.0"] |
| 1222 | + manifest = [ |
| 1223 | + { |
| 1224 | + "model_id": "test-model", |
| 1225 | + "version": version, |
| 1226 | + "min_version": "2.49.0", |
| 1227 | + "spec_key": "spec_key", |
| 1228 | + } |
| 1229 | + for version in versions |
| 1230 | + ] |
| 1231 | + |
| 1232 | + manifest_dict = {} |
| 1233 | + for header in manifest: |
| 1234 | + header_obj = JumpStartModelHeader(header) |
| 1235 | + manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = ( |
| 1236 | + header_obj |
| 1237 | + ) |
| 1238 | + retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict) |
| 1239 | + key = JumpStartVersionedModelId("test-model", "*") |
| 1240 | + |
| 1241 | + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") |
| 1242 | + result = cache._get_open_weight_manifest_key_from_model_id(key=key, value=None) |
| 1243 | + |
| 1244 | + assert_key = JumpStartVersionedModelId("test-model", "abc") |
| 1245 | + |
| 1246 | + assert result == assert_key |
0 commit comments