15
15
import datetime
16
16
from difflib import get_close_matches
17
17
import os
18
- from typing import List , Optional , Tuple , Union
18
+ from typing import Any , Dict , List , Optional , Tuple , Union
19
19
import json
20
20
import boto3
21
21
import botocore
43
43
JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON ,
44
44
)
45
45
from sagemaker .jumpstart .types import (
46
- JumpStartCachedS3ContentKey ,
47
- JumpStartCachedS3ContentValue ,
46
+ JumpStartCachedContentKey ,
47
+ JumpStartCachedContentValue ,
48
48
JumpStartModelHeader ,
49
49
JumpStartModelSpecs ,
50
50
JumpStartS3FileType ,
51
51
JumpStartVersionedModelId ,
52
+ HubDataType ,
52
53
)
53
54
from sagemaker .jumpstart .enums import JumpStartModelType
54
55
from sagemaker .jumpstart import utils
@@ -103,7 +104,7 @@ def __init__(
103
104
"""
104
105
105
106
self ._region = region
106
- self ._s3_cache = LRUCache [JumpStartCachedS3ContentKey , JumpStartCachedS3ContentValue ](
107
+ self ._content_cache = LRUCache [JumpStartCachedContentKey , JumpStartCachedContentValue ](
107
108
max_cache_items = max_s3_cache_items ,
108
109
expiration_horizon = s3_cache_expiration_horizon ,
109
110
retrieval_function = self ._retrieval_function ,
@@ -234,7 +235,7 @@ def _model_id_retrieval_function(
234
235
model_id , version = key .model_id , key .version
235
236
sm_version = utils .get_sagemaker_version ()
236
237
manifest = self ._s3_cache .get (
237
- JumpStartCachedS3ContentKey (
238
+ JumpStartCachedContentKey (
238
239
MODEL_TYPE_TO_MANIFEST_MAP [model_type ], self ._manifest_file_s3_map [model_type ])
239
240
)[0 ].formatted_content
240
241
@@ -399,46 +400,69 @@ def _get_json_file_from_local_override(
399
400
400
401
def _retrieval_function (
401
402
self ,
402
- key : JumpStartCachedS3ContentKey ,
403
- value : Optional [JumpStartCachedS3ContentValue ],
404
- ) -> JumpStartCachedS3ContentValue :
405
- """Return s3 content given a file type and s3_key in ``JumpStartCachedS3ContentKey ``.
403
+ key : JumpStartCachedContentKey ,
404
+ value : Optional [JumpStartCachedContentValue ],
405
+ ) -> JumpStartCachedContentValue :
406
+ """Return s3 content given a data type and s3_key in ``JumpStartCachedContentKey ``.
406
407
407
408
If a manifest file is being fetched, we only download the object if the md5 hash in
408
409
``head_object`` does not match the current md5 hash for the stored value. This prevents
409
410
unnecessarily downloading the full manifest when it hasn't changed.
410
411
411
412
Args:
412
- key (JumpStartCachedS3ContentKey ): key for which to fetch s3 content.
413
+ key (JumpStartCachedContentKey ): key for which to fetch JumpStart content.
413
414
value (Optional[JumpStartVersionedModelId]): Current value of old cached
414
415
s3 content. This is used for the manifest file, so that it is only
415
416
downloaded when its content changes.
416
417
"""
417
418
418
- file_type , s3_key = key .file_type , key .s3_key
419
- if file_type in {
419
+ data_type , id_info = key .data_type , key .id_info
420
+
421
+ if data_type in {
420
422
JumpStartS3FileType .OPEN_WEIGHT_MANIFEST ,
421
423
JumpStartS3FileType .PROPRIETARY_MANIFEST ,
422
424
}:
423
425
if value is not None and not self ._is_local_metadata_mode ():
424
- etag = self ._get_json_md5_hash (s3_key )
426
+ etag = self ._get_json_md5_hash (id_info )
425
427
if etag == value .md5_hash :
426
428
return value
427
- formatted_body , etag = self ._get_json_file (s3_key , file_type )
428
- return JumpStartCachedS3ContentValue (
429
+ formatted_body , etag = self ._get_json_file (id_info , data_type )
430
+ return JumpStartCachedContentValue (
429
431
formatted_content = utils .get_formatted_manifest (formatted_body ),
430
432
md5_hash = etag ,
431
433
)
432
- if file_type in {
434
+ if data_type in {
433
435
JumpStartS3FileType .OPEN_WEIGHT_SPECS ,
434
436
JumpStartS3FileType .PROPRIETARY_SPECS ,
435
437
}:
436
- formatted_body , _ = self ._get_json_file (s3_key , file_type )
438
+ formatted_body , _ = self ._get_json_file (id_info , data_type )
437
439
model_specs = JumpStartModelSpecs (formatted_body )
438
440
utils .emit_logs_based_on_model_specs (model_specs , self .get_region (), self ._s3_client )
439
- return JumpStartCachedS3ContentValue (formatted_content = model_specs )
441
+ return JumpStartCachedContentValue (formatted_content = model_specs )
442
+
443
+ if data_type == HubDataType .MODEL :
444
+ hub_name , region , model_name , model_version = utils .extract_info_from_hub_content_arn (
445
+ id_info
446
+ )
447
+ hub = CuratedHub (hub_name = hub_name , region = region )
448
+ hub_content = hub .describe_model (model_name = model_name , model_version = model_version )
449
+ utils .emit_logs_based_on_model_specs (
450
+ hub_content .content_document ,
451
+ self .get_region (),
452
+ self ._s3_client
453
+ )
454
+ model_specs = JumpStartModelSpecs (hub_content .content_document , is_hub_content = True )
455
+ return JumpStartCachedContentValue (
456
+ formatted_content = model_specs
457
+ )
458
+ if data_type == HubDataType .HUB :
459
+ hub_name , region , _ , _ = utils .extract_info_from_hub_content_arn (id_info )
460
+ hub = CuratedHub (hub_name = hub_name , region = region )
461
+ hub_info = hub .describe ()
462
+ return JumpStartCachedContentValue (formatted_content = hub_info )
463
+
440
464
raise ValueError (
441
- self ._file_type_error_msg (file_type )
465
+ self ._file_type_error_msg (data_type )
442
466
)
443
467
444
468
def get_manifest (
@@ -447,7 +471,7 @@ def get_manifest(
447
471
) -> List [JumpStartModelHeader ]:
448
472
"""Return entire JumpStart models manifest."""
449
473
manifest_dict = self ._s3_cache .get (
450
- JumpStartCachedS3ContentKey (
474
+ JumpStartCachedContentKey (
451
475
MODEL_TYPE_TO_MANIFEST_MAP [model_type ], self ._manifest_file_s3_map [model_type ])
452
476
)[0 ].formatted_content
453
477
manifest = list (manifest_dict .values ()) # type: ignore
@@ -536,7 +560,7 @@ def _get_header_impl(
536
560
)[0 ]
537
561
538
562
manifest = self ._s3_cache .get (
539
- JumpStartCachedS3ContentKey (
563
+ JumpStartCachedContentKey (
540
564
MODEL_TYPE_TO_MANIFEST_MAP [model_type ], self ._manifest_file_s3_map [model_type ])
541
565
)[0 ].formatted_content
542
566
@@ -566,7 +590,7 @@ def get_specs(
566
590
header = self .get_header (model_id , version_str , model_type )
567
591
spec_key = header .spec_key
568
592
specs , cache_hit = self ._s3_cache .get (
569
- JumpStartCachedS3ContentKey (
593
+ JumpStartCachedContentKey (
570
594
MODEL_TYPE_TO_SPECS_MAP [model_type ], spec_key
571
595
)
572
596
)
@@ -579,8 +603,30 @@ def get_specs(
579
603
)
580
604
return specs .formatted_content
581
605
606
+ def get_hub_model (self , hub_model_arn : str ) -> JumpStartModelSpecs :
607
+ """Return JumpStart-compatible specs for a given Hub model
608
+
609
+ Args:
610
+ hub_model_arn (str): Arn for the Hub model to get specs for
611
+ """
612
+
613
+ details , _ = self ._content_cache .get (
614
+ JumpStartCachedContentKey (HubDataType .MODEL , hub_model_arn )
615
+ )
616
+ return details .formatted_content
617
+
618
+ def get_hub (self , hub_arn : str ) -> Dict [str , Any ]:
619
+ """Return descriptive info for a given Hub
620
+
621
+ Args:
622
+ hub_arn (str): Arn for the Hub to get info for
623
+ """
624
+
625
+ details , _ = self ._content_cache .get (JumpStartCachedContentKey (HubDataType .HUB , hub_arn ))
626
+ return details .formatted_content
627
+
582
628
def clear (self ) -> None :
583
629
"""Clears the model ID/version and s3 cache."""
584
- self ._s3_cache .clear ()
630
+ self ._content_cache .clear ()
585
631
self ._open_weight_model_id_manifest_key_cache .clear ()
586
632
self ._proprietary_model_id_manifest_key_cache .clear ()
0 commit comments