|
13 | 13 | """This module provides the JumpStart Curated Hub class."""
|
14 | 14 | from __future__ import absolute_import
|
15 | 15 | from datetime import datetime
|
16 |
| -from typing import Optional, Dict, List, Any |
| 16 | +from typing import Optional, Dict, List, Any, Tuple, Union, Set |
17 | 17 | from botocore import exceptions
|
18 | 18 |
|
19 | 19 | from sagemaker.jumpstart.hub.constants import JUMPSTART_MODEL_HUB_NAME
|
|
27 | 27 | from sagemaker.jumpstart.types import (
|
28 | 28 | HubContentType,
|
29 | 29 | )
|
| 30 | +from sagemaker.jumpstart.filters import Constant, ModelFilter, Operator, BooleanValues |
30 | 31 | from sagemaker.jumpstart.hub.utils import (
|
31 | 32 | create_hub_bucket_if_it_does_not_exist,
|
32 | 33 | generate_default_hub_bucket_name,
|
33 | 34 | create_s3_object_reference_from_uri,
|
| 35 | + construct_hub_arn_from_name, |
| 36 | + construct_hub_model_arn_from_inputs |
| 37 | +) |
| 38 | + |
| 39 | +from sagemaker.jumpstart.notebook_utils import ( |
| 40 | + list_jumpstart_models, |
34 | 41 | )
|
35 | 42 |
|
36 | 43 | from sagemaker.jumpstart.hub.types import (
|
@@ -158,25 +165,35 @@ def list_models(self, clear_cache: bool = True, **kwargs) -> List[Dict[str, Any]
|
158 | 165 | self._list_hubs_cache = hub_content_summaries
|
159 | 166 | return self._list_hubs_cache
|
160 | 167 |
|
161 |
| - # TODO: Update to use S3 source for listing the public models |
162 |
| - def list_jumpstart_service_hub_models(self, filter_name: Optional[str] = None, clear_cache: bool = True, **kwargs) -> List[Dict[str, Any]]: |
163 |
| - """Lists the models from AmazonSageMakerJumpStart Public Hub. |
164 |
| -
|
165 |
| - This function caches the models in local memory |
| 168 | + def list_jumpstart_service_hub_models(self, filter: Union[Operator, str] = Constant(BooleanValues.TRUE)) -> Dict[str, str]: |
| 169 | + """Lists the models and model arns from AmazonSageMakerJumpStart Public Hub. |
166 | 170 |
|
167 |
| - **kwargs: Passed to invocation of ``Session:list_hub_contents``. |
| 171 | + Args: |
| 172 | + filter (Union[Operator, str]): Optional. The filter to apply to list models. This can be |
| 173 | + either an ``Operator`` type filter (e.g. ``And("task == ic", "framework == pytorch")``), |
| 174 | + or simply a string filter which will get serialized into an Identity filter. |
| 175 | + (e.g. ``"task == ic"``). If this argument is not supplied, all models will be listed. |
| 176 | + (Default: Constant(BooleanValues.TRUE)). |
168 | 177 | """
|
169 |
| - if clear_cache: |
170 |
| - self._list_hubs_cache = None |
171 |
| - if self._list_hubs_cache is None: |
172 |
| - hub_content_summaries = self._sagemaker_session.list_hub_contents( |
173 |
| - hub_name=JUMPSTART_MODEL_HUB_NAME, |
174 |
| - hub_content_type=HubContentType.MODEL_REFERENCE.value, |
175 |
| - name_contains=filter_name, |
176 |
| - **kwargs |
| 178 | + |
| 179 | + jumpstart_public_models = {} |
| 180 | + |
| 181 | + jumpstart_public_hub_arn = construct_hub_arn_from_name( |
| 182 | + JUMPSTART_MODEL_HUB_NAME, |
| 183 | + self.region, |
| 184 | + self._sagemaker_session |
177 | 185 | )
|
178 |
| - self._list_hubs_cache = hub_content_summaries |
179 |
| - return self._list_hubs_cache |
| 186 | + |
| 187 | + models = list_jumpstart_models(filter) |
| 188 | + for model in models: |
| 189 | + if len(model[0])<=63: |
| 190 | + jumpstart_public_models[model[0]] = construct_hub_model_arn_from_inputs( |
| 191 | + jumpstart_public_hub_arn, |
| 192 | + model[0], |
| 193 | + model[1] |
| 194 | + ) |
| 195 | + |
| 196 | + return jumpstart_public_models |
180 | 197 |
|
181 | 198 | def delete(self) -> None:
|
182 | 199 | """Deletes this Curated Hub"""
|
|
0 commit comments