21
21
import botocore
22
22
from packaging .version import Version
23
23
from packaging .specifiers import SpecifierSet , InvalidSpecifier
24
+ from sagemaker .session import Session
25
+ from sagemaker .utilities .cache import LRUCache
24
26
from sagemaker .jumpstart .constants import (
25
27
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE ,
26
28
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE ,
27
29
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY ,
28
30
JUMPSTART_DEFAULT_REGION_NAME ,
29
31
JUMPSTART_LOGGER ,
30
32
MODEL_ID_LIST_WEB_URL ,
33
+ DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
31
34
)
32
- from sagemaker .jumpstart .curated_hub .curated_hub import CuratedHub
33
- from sagemaker .jumpstart .curated_hub .utils import get_info_from_hub_resource_arn
34
35
from sagemaker .jumpstart .exceptions import get_wildcard_model_version_msg
35
36
from sagemaker .jumpstart .parameters import (
36
37
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS ,
37
38
JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS ,
38
39
JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON ,
39
40
JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON ,
40
41
)
42
+ from sagemaker .jumpstart import utils
41
43
from sagemaker .jumpstart .types import (
42
44
JumpStartCachedContentKey ,
43
45
JumpStartCachedContentValue ,
44
46
JumpStartModelHeader ,
45
47
JumpStartModelSpecs ,
46
48
JumpStartS3FileType ,
47
49
JumpStartVersionedModelId ,
50
+ DescribeHubResponse ,
51
+ DescribeHubContentsResponse ,
52
+ HubType ,
48
53
HubContentType ,
49
54
)
50
- from sagemaker .jumpstart import utils
51
- from sagemaker .utilities .cache import LRUCache
55
+ from sagemaker .jumpstart .curated_hub import utils as hub_utils
52
56
53
57
54
58
class JumpStartModelsCache :
@@ -74,6 +78,7 @@ def __init__(
74
78
s3_bucket_name : Optional [str ] = None ,
75
79
s3_client_config : Optional [botocore .config .Config ] = None ,
76
80
s3_client : Optional [boto3 .client ] = None ,
81
+ sagemaker_session : Optional [Session ] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
77
82
) -> None : # fmt: on
78
83
"""Initialize a ``JumpStartModelsCache`` instance.
79
84
@@ -95,6 +100,8 @@ def __init__(
95
100
s3_client_config (Optional[botocore.config.Config]): s3 client config to use for cache.
96
101
Default: None (no config).
97
102
s3_client (Optional[boto3.client]): s3 client to use. Default: None.
103
+ sagemaker_session (Optional[sagemaker.session.Session]): A SageMaker Session object,
104
+ used for SageMaker interactions. Default: Session in region associated with boto3 session.
98
105
"""
99
106
100
107
self ._region = region
@@ -121,6 +128,7 @@ def __init__(
121
128
if s3_client_config
122
129
else boto3 .client ("s3" , region_name = self ._region )
123
130
)
131
+ self ._sagemaker_session = sagemaker_session
124
132
125
133
def set_region (self , region : str ) -> None :
126
134
"""Set region for cache. Clears cache after new region is set."""
@@ -340,32 +348,34 @@ def _retrieval_function(
340
348
formatted_content = model_specs
341
349
)
342
350
if data_type == HubContentType .MODEL :
343
- info = get_info_from_hub_resource_arn (
351
+ hub_name , _ , model_name , model_version = hub_utils . get_info_from_hub_resource_arn (
344
352
id_info
345
353
)
346
- hub = CuratedHub (hub_name = info .hub_name , region = info .region )
347
- hub_content = hub .describe_model (
348
- model_name = info .hub_content_name , model_version = info .hub_content_version
354
+ hub_model_description : Dict [str , Any ] = self ._sagemaker_session .describe_hub_content (
355
+ hub_name = hub_name ,
356
+ hub_content_name = model_name ,
357
+ hub_content_version = model_version ,
358
+ hub_content_type = data_type
349
359
)
360
+
361
+ model_specs = JumpStartModelSpecs (DescribeHubContentsResponse (hub_model_description ), is_hub_content = True )
362
+
350
363
utils .emit_logs_based_on_model_specs (
351
- hub_content . content_document ,
364
+ model_specs ,
352
365
self .get_region (),
353
366
self ._s3_client
354
367
)
355
- model_specs = JumpStartModelSpecs (hub_content .content_document , is_hub_content = True )
356
368
return JumpStartCachedContentValue (
357
369
formatted_content = model_specs
358
370
)
359
- if data_type == HubContentType .HUB :
360
- info = get_info_from_hub_resource_arn (
361
- id_info
362
- )
363
- hub = CuratedHub (hub_name = info .hub_name , region = info .region )
364
- hub_info = hub .describe ()
365
- return JumpStartCachedContentValue (formatted_content = hub_info )
371
+ if data_type == HubType .HUB :
372
+ hub_name , _ , _ , _ = hub_utils .get_info_from_hub_resource_arn (id_info )
373
+ response : Dict [str , Any ] = self ._sagemaker_session .describe_hub (hub_name = hub_name )
374
+ hub_description = DescribeHubResponse (response )
375
+ return JumpStartCachedContentValue (formatted_content = DescribeHubResponse (hub_description ))
366
376
raise ValueError (
367
- f"Bad value for key '{ key } ': must be in" ,
368
- f"{ [JumpStartS3FileType .MANIFEST , JumpStartS3FileType .SPECS , HubContentType .HUB , HubContentType .MODEL ]} "
377
+ f"Bad value for key '{ key } ': must be in " ,
378
+ f"{ [JumpStartS3FileType .MANIFEST , JumpStartS3FileType .SPECS , HubType .HUB , HubContentType .MODEL ]} "
369
379
)
370
380
371
381
def get_manifest (self ) -> List [JumpStartModelHeader ]:
@@ -490,7 +500,7 @@ def get_hub(self, hub_arn: str) -> Dict[str, Any]:
490
500
hub_arn (str): Arn for the Hub to get info for
491
501
"""
492
502
493
- details , _ = self ._content_cache .get (JumpStartCachedContentKey (HubContentType .HUB , hub_arn ))
503
+ details , _ = self ._content_cache .get (JumpStartCachedContentKey (HubType .HUB , hub_arn ))
494
504
return details .formatted_content
495
505
496
506
def clear (self ) -> None :
0 commit comments