Skip to content

Commit 8e69cc1

Browse files
malav-shastriMalav Shastri
and
Malav Shastri
committed
feat: implement list_jumpstart_service_hub_models function to fetch JumpStart public hub models (aws#1456)
* Implement CuratedHub Admin APIs * making some parameters optional in create_hub_content_reference as per the API design * add describe_hub and list_hubs APIs * implement delete_hub API * Implement list_hub_contents API * create CuratedHub class and supported utils * implement list_models and address comments * Add unit tests * add describe_model function * cache retrieval for describeHubContent changes * fix curated hub class unit tests * add utils needed for curatedHub * Cache retrieval * implement get_hub_model_reference() * cleanup HUB type datatype * cleanup constants * rename list_public_models to list_jumpstart_service_hub_models * implement describe_model_reference * Rename CuratedHub to Hub * address nit * address nits and fix failing tests * implement list_jumpstart_service_hub_models function --------- Co-authored-by: Malav Shastri <[email protected]>
1 parent 68d5c8a commit 8e69cc1

File tree

1 file changed

+34
-17
lines changed
  • src/sagemaker/jumpstart/hub

1 file changed

+34
-17
lines changed

src/sagemaker/jumpstart/hub/hub.py

+34-17
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""This module provides the JumpStart Curated Hub class."""
1414
from __future__ import absolute_import
1515
from datetime import datetime
16-
from typing import Optional, Dict, List, Any
16+
from typing import Optional, Dict, List, Any, Tuple, Union, Set
1717
from botocore import exceptions
1818

1919
from sagemaker.jumpstart.hub.constants import JUMPSTART_MODEL_HUB_NAME
@@ -27,10 +27,17 @@
2727
from sagemaker.jumpstart.types import (
2828
HubContentType,
2929
)
30+
from sagemaker.jumpstart.filters import Constant, ModelFilter, Operator, BooleanValues
3031
from sagemaker.jumpstart.hub.utils import (
3132
create_hub_bucket_if_it_does_not_exist,
3233
generate_default_hub_bucket_name,
3334
create_s3_object_reference_from_uri,
35+
construct_hub_arn_from_name,
36+
construct_hub_model_arn_from_inputs
37+
)
38+
39+
from sagemaker.jumpstart.notebook_utils import (
40+
list_jumpstart_models,
3441
)
3542

3643
from sagemaker.jumpstart.hub.types import (
@@ -158,25 +165,35 @@ def list_models(self, clear_cache: bool = True, **kwargs) -> List[Dict[str, Any]
158165
self._list_hubs_cache = hub_content_summaries
159166
return self._list_hubs_cache
160167

161-
# TODO: Update to use S3 source for listing the public models
162-
def list_jumpstart_service_hub_models(self, filter_name: Optional[str] = None, clear_cache: bool = True, **kwargs) -> List[Dict[str, Any]]:
163-
"""Lists the models from AmazonSageMakerJumpStart Public Hub.
164-
165-
This function caches the models in local memory
168+
def list_jumpstart_service_hub_models(self, filter: Union[Operator, str] = Constant(BooleanValues.TRUE)) -> Dict[str, str]:
169+
"""Lists the models and model arns from AmazonSageMakerJumpStart Public Hub.
166170
167-
**kwargs: Passed to invocation of ``Session:list_hub_contents``.
171+
Args:
172+
filter (Union[Operator, str]): Optional. The filter to apply to list models. This can be
173+
either an ``Operator`` type filter (e.g. ``And("task == ic", "framework == pytorch")``),
174+
or simply a string filter which will get serialized into an Identity filter.
175+
(e.g. ``"task == ic"``). If this argument is not supplied, all models will be listed.
176+
(Default: Constant(BooleanValues.TRUE)).
168177
"""
169-
if clear_cache:
170-
self._list_hubs_cache = None
171-
if self._list_hubs_cache is None:
172-
hub_content_summaries = self._sagemaker_session.list_hub_contents(
173-
hub_name=JUMPSTART_MODEL_HUB_NAME,
174-
hub_content_type=HubContentType.MODEL_REFERENCE.value,
175-
name_contains=filter_name,
176-
**kwargs
178+
179+
jumpstart_public_models = {}
180+
181+
jumpstart_public_hub_arn = construct_hub_arn_from_name(
182+
JUMPSTART_MODEL_HUB_NAME,
183+
self.region,
184+
self._sagemaker_session
177185
)
178-
self._list_hubs_cache = hub_content_summaries
179-
return self._list_hubs_cache
186+
187+
models = list_jumpstart_models(filter)
188+
for model in models:
189+
if len(model[0])<=63:
190+
jumpstart_public_models[model[0]] = construct_hub_model_arn_from_inputs(
191+
jumpstart_public_hub_arn,
192+
model[0],
193+
model[1]
194+
)
195+
196+
return jumpstart_public_models
180197

181198
def delete(self) -> None:
182199
"""Deletes this Curated Hub"""

0 commit comments

Comments
 (0)