Skip to content

Commit a8b718e

Browse files
committed
fix merge artifact
1 parent 9c6370c commit a8b718e

File tree

2 files changed

+24
-32
lines changed

2 files changed

+24
-32
lines changed

src/sagemaker/jumpstart/hub/utils.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -219,43 +219,42 @@ def get_hub_model_version(
219219
except Exception as ex:
220220
raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}")
221221

222-
<<<<<<< HEAD
223-
<<<<<<< HEAD
224-
=======
225-
marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version(
226-
hub_content_summaries, hub_model_version
227-
)
228222

229-
>>>>>>> ff3eae05 (feat: Adding Bedrock Store model support for HubService (#1539))
230-
=======
231-
>>>>>>> 42acb4f4 (chore: Merge from main (#1600))
223+
def get_hub_model_version(
224+
hub_name: str,
225+
hub_model_name: str,
226+
hub_model_type: str,
227+
hub_model_version: Optional[str] = None,
228+
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
229+
) -> str:
230+
"""Returns available Jumpstart hub model version.
231+
232+
It will attempt both a semantic HubContent version search and Marketplace version search.
233+
If the Marketplace version is also semantic, this function will default to HubContent version.
234+
235+
Raises:
236+
ClientError: If the specified model is not found in the hub.
237+
KeyError: If the specified model version is not found.
238+
"""
239+
240+
try:
241+
hub_content_summaries = sagemaker_session.list_hub_content_versions(
242+
hub_name=hub_name, hub_content_name=hub_model_name, hub_content_type=hub_model_type
243+
).get("HubContentSummaries")
244+
except Exception as ex:
245+
raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}")
246+
232247
try:
233248
return _get_hub_model_version_for_open_weight_version(
234249
hub_content_summaries, hub_model_version
235250
)
236-
<<<<<<< HEAD
237-
<<<<<<< HEAD
238-
=======
239-
>>>>>>> 42acb4f4 (chore: Merge from main (#1600))
240251
except KeyError:
241252
marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version(
242253
hub_content_summaries, hub_model_version
243254
)
244-
<<<<<<< HEAD
245-
if marketplace_hub_content_version:
246-
return marketplace_hub_content_version
247-
raise
248-
=======
249-
except KeyError as e:
250-
if marketplace_hub_content_version:
251-
return marketplace_hub_content_version
252-
raise e
253-
>>>>>>> ff3eae05 (feat: Adding Bedrock Store model support for HubService (#1539))
254-
=======
255255
if marketplace_hub_content_version:
256256
return marketplace_hub_content_version
257257
raise
258-
>>>>>>> 42acb4f4 (chore: Merge from main (#1600))
259258

260259

261260
def _get_hub_model_version_for_open_weight_version(

tests/integ/sagemaker/jumpstart/utils.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,6 @@ def get_sm_session() -> Session:
5353
return Session(boto_session=boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME))
5454

5555

56-
<<<<<<< HEAD
57-
<<<<<<< HEAD
58-
=======
5956
def get_sm_session_with_override() -> Session:
6057
# [TODO]: Remove service endpoint override before GA
6158
# boto3.set_stream_logger(name='botocore', level=logging.DEBUG)
@@ -69,10 +66,6 @@ def get_sm_session_with_override() -> Session:
6966
sagemaker_client=sagemaker,
7067
)
7168

72-
73-
>>>>>>> ff3eae05 (feat: Adding Bedrock Store model support for HubService (#1539))
74-
=======
75-
>>>>>>> 42acb4f4 (chore: Merge from main (#1600))
7669
def get_training_dataset_for_model_and_version(model_id: str, version: str) -> dict:
7770
return TRAINING_DATASET_MODEL_DICT[(model_id, version)]
7871

0 commit comments

Comments
 (0)