22
22
JUMPSTART_REGION_NAME_SET ,
23
23
)
24
24
from sagemaker .jumpstart .types import (
25
+ HubDataType ,
25
26
JumpStartCachedContentKey ,
26
27
JumpStartCachedContentValue ,
27
28
JumpStartModelSpecs ,
@@ -218,12 +219,10 @@ def patched_retrieval_function(
218
219
datatype , id_info = key .data_type , key .id_info
219
220
if datatype == JumpStartS3FileType .OPEN_WEIGHT_MANIFEST :
220
221
221
- return JumpStartCachedContentValue (
222
- formatted_content = get_formatted_manifest (BASE_MANIFEST )
223
- )
222
+ return JumpStartCachedContentValue (formatted_content = get_formatted_manifest (BASE_MANIFEST ))
224
223
225
- if datatype == JumpStartCachedContentValue .OPEN_WEIGHT_SPECS :
226
- _ , model_id , specs_version = s3_key .split ("/" )
224
+ if datatype == JumpStartS3FileType .OPEN_WEIGHT_SPECS :
225
+ _ , model_id , specs_version = id_info .split ("/" )
227
226
version = specs_version .replace ("specs_v" , "" ).replace (".json" , "" )
228
227
return JumpStartCachedContentValue (
229
228
formatted_content = get_spec_from_base_spec (model_id = model_id , version = version )
@@ -245,7 +244,7 @@ def patched_retrieval_function(
245
244
)
246
245
247
246
if datatype == JumpStartS3FileType .PROPRIETARY_SPECS :
248
- _ , model_id , specs_version = s3_key .split ("/" )
247
+ _ , model_id , specs_version = id_info .split ("/" )
249
248
version = specs_version .replace ("proprietary_specs_" , "" ).replace (".json" , "" )
250
249
return JumpStartCachedContentValue (
251
250
formatted_content = get_spec_from_base_spec (
0 commit comments