@@ -339,10 +339,10 @@ def _retrieval_function(
339
339
formatted_content = model_specs
340
340
)
341
341
if data_type == HubDataType .MODEL :
342
- hub_name , hub_region , model_name , model_version = utils .extract_info_from_hub_content_arn (
342
+ hub_name , region , model_name , model_version = utils .extract_info_from_hub_content_arn (
343
343
id_info
344
344
)
345
- hub = CuratedHub (hub_name = hub_name , region = hub_region )
345
+ hub = CuratedHub (hub_name = hub_name , region = region )
346
346
hub_content = hub .describe_model (model_name = model_name , model_version = model_version )
347
347
utils .emit_logs_based_on_model_specs (
348
348
hub_content .content_document ,
@@ -354,8 +354,8 @@ def _retrieval_function(
354
354
formatted_content = model_specs
355
355
)
356
356
if data_type == HubDataType .HUB :
357
- hub_name , hub_region , _ , _ = utils .extract_info_from_hub_content_arn (id_info )
358
- hub = CuratedHub (hub_name = hub_name , region = hub_region )
357
+ hub_name , region , _ , _ = utils .extract_info_from_hub_content_arn (id_info )
358
+ hub = CuratedHub (hub_name = hub_name , region = region )
359
359
hub_info = hub .describe ()
360
360
return JumpStartCachedContentValue (formatted_content = hub_info )
361
361
raise ValueError (
@@ -465,10 +465,10 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS
465
465
)
466
466
)
467
467
return specs .formatted_content
468
-
468
+
469
469
def get_hub_model (self , hub_model_arn : str ) -> JumpStartModelSpecs :
470
470
"""Return JumpStart-compatible specs for a given Hub model
471
-
471
+
472
472
Args:
473
473
hub_model_arn (str): Arn for the Hub model to get specs for
474
474
"""
@@ -477,14 +477,14 @@ def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs:
477
477
JumpStartCachedContentKey (HubDataType .MODEL , hub_model_arn )
478
478
)
479
479
return details .formatted_content
480
-
480
+
481
481
def get_hub (self , hub_arn : str ) -> Dict [str , Any ]:
482
482
"""Return descriptive info for a given Hub
483
-
483
+
484
484
Args:
485
485
hub_arn (str): Arn for the Hub to get info for
486
486
"""
487
-
487
+
488
488
details , _ = self ._content_cache .get (JumpStartCachedContentKey (HubDataType .HUB , hub_arn ))
489
489
return details .formatted_content
490
490
0 commit comments