Skip to content

Commit a39ae5f

Browse files
committed
linter
1 parent 195b84b commit a39ae5f

File tree

4 files changed

+18
-22
lines changed

4 files changed

+18
-22
lines changed

src/sagemaker/jumpstart/artifacts/model_uris.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def _retrieve_training_artifact_key(model_specs: JumpStartModelSpecs, instance_t
8989
def _retrieve_model_uri(
9090
model_id: str,
9191
model_version: str,
92+
hub_arn: Optional[str] = None,
9293
model_scope: Optional[str] = None,
9394
instance_type: Optional[str] = None,
9495
region: Optional[str] = None,
@@ -105,6 +106,8 @@ def _retrieve_model_uri(
105106
the model artifact S3 URI.
106107
model_version (str): Version of the JumpStart model for which to retrieve the model
107108
artifact S3 URI.
109+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
110+
model details from (default: None).
108111
model_scope (str): The model type, i.e. what it is used for.
109112
Valid values: "training" and "inference".
110113
instance_type (str): The ML compute instance type for the specified scope. (Default: None).
@@ -135,6 +138,7 @@ def _retrieve_model_uri(
135138
model_specs = verify_model_region_and_return_specs(
136139
model_id=model_id,
137140
version=model_version,
141+
hub_arn=hub_arn,
138142
scope=model_scope,
139143
region=region,
140144
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/cache.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,9 @@ def _retrieval_function(
343343
id_info
344344
)
345345
hub = CuratedHub(hub_name=info.hub_name, region=info.region)
346-
hub_content = hub.describe_model(model_name=info.hub_content_name, model_version=info.hub_content_version)
346+
hub_content = hub.describe_model(
347+
model_name=info.hub_content_name, model_version=info.hub_content_version
348+
)
347349
utils.emit_logs_based_on_model_specs(
348350
hub_content.content_document,
349351
self.get_region(),
@@ -467,10 +469,10 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS
467469
)
468470
)
469471
return specs.formatted_content
470-
472+
471473
def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs:
472474
"""Return JumpStart-compatible specs for a given Hub model
473-
475+
474476
Args:
475477
hub_model_arn (str): Arn for the Hub model to get specs for
476478
"""
@@ -479,14 +481,14 @@ def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs:
479481
JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)
480482
)
481483
return details.formatted_content
482-
484+
483485
def get_hub(self, hub_arn: str) -> Dict[str, Any]:
484486
"""Return descriptive info for a given Hub
485-
487+
486488
Args:
487489
hub_arn (str): Arn for the Hub to get info for
488490
"""
489-
491+
490492
details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn))
491493
return details.formatted_content
492494

src/sagemaker/jumpstart/curated_hub/constants.py

Lines changed: 0 additions & 15 deletions
This file was deleted.

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@
2222
class CuratedHub:
2323
"""Class for creating and managing a curated JumpStart hub"""
2424

25-
def __init__(self, hub_name: str, region: str, session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION):
25+
def __init__(
26+
self,
27+
hub_name: str,
28+
region: str,
29+
session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
30+
):
2631
self.hub_name = hub_name
2732
self.region = region
2833
self._sm_session = session

0 commit comments

Comments
 (0)