Skip to content

Commit 8df4478

Browse files
committed
add hub and hubcontent support in retrieval function for jumpstart model cache
1 parent ecd1f97 commit 8df4478

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,26 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS
467467
)
468468
)
469469
return specs.formatted_content
470+
471+
def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs:
472+
"""Return JumpStart-compatible specs for a given Hub model
473+
474+
Args:
475+
hub_model_arn (str): Arn for the Hub model to get specs for
476+
"""
477+
478+
specs, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn))
479+
return specs.formatted_content
480+
481+
def get_hub(self, hub_arn: str) -> Dict[str, Any]:
482+
"""Return descriptive info for a given Hub
483+
484+
Args:
485+
hub_arn (str): Arn for the Hub to get info for
486+
"""
487+
488+
manifest, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn))
489+
return manifest.formatted_content
470490

471491
def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs:
472492
"""Return JumpStart-compatible specs for a given Hub model
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from botocore.config import Config
14+
15+
DEFAULT_CLIENT_CONFIG = Config(retries={"max_attempts": 10, "mode": "standard"})

0 commit comments

Comments
 (0)