21
21
import botocore
22
22
from packaging .version import Version
23
23
from packaging .specifiers import SpecifierSet , InvalidSpecifier
24
+ from sagemaker .session import Session
25
+ from sagemaker .utilities .cache import LRUCache
24
26
from sagemaker .jumpstart .constants import (
25
27
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE ,
26
28
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE ,
31
33
MODEL_ID_LIST_WEB_URL ,
32
34
MODEL_TYPE_TO_MANIFEST_MAP ,
33
35
MODEL_TYPE_TO_SPECS_MAP ,
36
+ DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
34
37
)
35
38
from sagemaker .jumpstart .exceptions import (
36
39
get_wildcard_model_version_msg ,
37
40
get_wildcard_proprietary_model_version_msg ,
38
41
)
39
- from sagemaker .jumpstart .curated_hub .curated_hub import CuratedHub
40
- from sagemaker .jumpstart .curated_hub .utils import get_info_from_hub_resource_arn
41
42
from sagemaker .jumpstart .exceptions import get_wildcard_model_version_msg
42
43
from sagemaker .jumpstart .parameters import (
43
44
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS ,
44
45
JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS ,
45
46
JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON ,
46
47
JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON ,
47
48
)
49
+ from sagemaker .jumpstart import utils
48
50
from sagemaker .jumpstart .types import (
49
51
JumpStartCachedContentKey ,
50
52
JumpStartCachedContentValue ,
51
53
JumpStartModelHeader ,
52
54
JumpStartModelSpecs ,
53
55
JumpStartS3FileType ,
54
56
JumpStartVersionedModelId ,
57
+ DescribeHubResponse ,
58
+ DescribeHubContentsResponse ,
59
+ HubType ,
55
60
HubContentType ,
56
61
)
57
62
from sagemaker .jumpstart .enums import JumpStartModelType
58
63
from sagemaker .jumpstart import utils
59
64
from sagemaker .utilities .cache import LRUCache
65
+ from sagemaker .jumpstart .curated_hub import utils as hub_utils
60
66
61
67
62
68
class JumpStartModelsCache :
@@ -83,6 +89,7 @@ def __init__(
83
89
s3_bucket_name : Optional [str ] = None ,
84
90
s3_client_config : Optional [botocore .config .Config ] = None ,
85
91
s3_client : Optional [boto3 .client ] = None ,
92
+ sagemaker_session : Optional [Session ] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
86
93
) -> None : # fmt: on
87
94
"""Initialize a ``JumpStartModelsCache`` instance.
88
95
@@ -104,6 +111,8 @@ def __init__(
104
111
s3_client_config (Optional[botocore.config.Config]): s3 client config to use for cache.
105
112
Default: None (no config).
106
113
s3_client (Optional[boto3.client]): s3 client to use. Default: None.
114
+ sagemaker_session (Optional[sagemaker.session.Session]): A SageMaker Session object,
115
+ used for SageMaker interactions. Default: Session in region associated with boto3 session.
107
116
"""
108
117
109
118
self ._region = region
@@ -142,6 +151,7 @@ def __init__(
142
151
if s3_client_config
143
152
else boto3 .client ("s3" , region_name = self ._region )
144
153
)
154
+ self ._sagemaker_session = sagemaker_session
145
155
146
156
def set_region (self , region : str ) -> None :
147
157
"""Set region for cache. Clears cache after new region is set."""
@@ -445,30 +455,31 @@ def _retrieval_function(
445
455
formatted_content = model_specs
446
456
)
447
457
if data_type == HubContentType .MODEL :
448
- info = get_info_from_hub_resource_arn (
458
+ hub_name , _ , model_name , model_version = hub_utils . get_info_from_hub_resource_arn (
449
459
id_info
450
460
)
451
- hub = CuratedHub (hub_name = info .hub_name , region = info .region )
452
- hub_content = hub .describe_model (
453
- model_name = info .hub_content_name , model_version = info .hub_content_version
461
+ hub_model_description : Dict [str , Any ] = self ._sagemaker_session .describe_hub_content (
462
+ hub_name = hub_name ,
463
+ hub_content_name = model_name ,
464
+ hub_content_version = model_version ,
465
+ hub_content_type = data_type
454
466
)
467
+
468
+ model_specs = JumpStartModelSpecs (DescribeHubContentsResponse (hub_model_description ), is_hub_content = True )
469
+
455
470
utils .emit_logs_based_on_model_specs (
456
- hub_content . content_document ,
471
+ model_specs ,
457
472
self .get_region (),
458
473
self ._s3_client
459
474
)
460
- model_specs = JumpStartModelSpecs (hub_content .content_document , is_hub_content = True )
461
475
return JumpStartCachedContentValue (
462
476
formatted_content = model_specs
463
477
)
464
- if data_type == HubContentType .HUB :
465
- info = get_info_from_hub_resource_arn (
466
- id_info
467
- )
468
- hub = CuratedHub (hub_name = info .hub_name , region = info .region )
469
- hub_info = hub .describe ()
470
- return JumpStartCachedContentValue (formatted_content = hub_info )
471
-
478
+ if data_type == HubType .HUB :
479
+ hub_name , _ , _ , _ = hub_utils .get_info_from_hub_resource_arn (id_info )
480
+ response : Dict [str , Any ] = self ._sagemaker_session .describe_hub (hub_name = hub_name )
481
+ hub_description = DescribeHubResponse (response )
482
+ return JumpStartCachedContentValue (formatted_content = DescribeHubResponse (hub_description ))
472
483
raise ValueError (
473
484
self ._file_type_error_msg (data_type )
474
485
)
@@ -630,7 +641,7 @@ def get_hub(self, hub_arn: str) -> Dict[str, Any]:
630
641
hub_arn (str): Arn for the Hub to get info for
631
642
"""
632
643
633
- details , _ = self ._content_cache .get (JumpStartCachedContentKey (HubContentType .HUB , hub_arn ))
644
+ details , _ = self ._content_cache .get (JumpStartCachedContentKey (HubType .HUB , hub_arn ))
634
645
return details .formatted_content
635
646
636
647
def clear (self ) -> None :
0 commit comments