Skip to content

Commit 49ae11b

Browse files
committed
update linter
1 parent ef042d9 commit 49ae11b

File tree

3 files changed

+20
-16
lines changed

3 files changed

+20
-16
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -339,17 +339,23 @@ def _retrieval_function(
339339
formatted_content=model_specs
340340
)
341341
if data_type == HubDataType.MODEL:
342-
hub_name, hub_region, model_id, model_version = utils.extract_info_from_hub_content_arn(id_info)
343-
hub = CuratedHub(hub_name=hub_name, hub_region=hub_region)
344-
hub_content = hub.describe_model(model_id=model_id, model_version=model_version)
345-
utils.emit_logs_based_on_model_specs(hub_content.content_document, self.get_region(), self._s3_client)
342+
hub_name, hub_region, model_name, model_version = utils.extract_info_from_hub_content_arn(
343+
id_info
344+
)
345+
hub = CuratedHub(hub_name=hub_name, region=hub_region)
346+
hub_content = hub.describe_model(model_name=model_name, model_version=model_version)
347+
utils.emit_logs_based_on_model_specs(
348+
hub_content.content_document,
349+
self.get_region(),
350+
self._s3_client
351+
)
346352
model_specs = JumpStartModelSpecs(hub_content.content_document, is_hub_content=True)
347353
return JumpStartCachedContentValue(
348354
formatted_content=model_specs
349355
)
350356
if data_type == HubDataType.HUB:
351357
hub_name, hub_region, _, _ = utils.extract_info_from_hub_content_arn(id_info)
352-
hub = CuratedHub(hub_name=hub_name, hub_region=hub_region)
358+
hub = CuratedHub(hub_name=hub_name, region=hub_region)
353359
hub_info = hub.describe()
354360
return JumpStartCachedContentValue(formatted_content=hub_info)
355361
raise ValueError(
@@ -467,7 +473,9 @@ def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs:
467473
hub_model_arn (str): Arn for the Hub model to get specs for
468474
"""
469475

470-
details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn))
476+
details, _ = self._content_cache.get(
477+
JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)
478+
)
471479
return details.formatted_content
472480

473481
def get_hub(self, hub_arn: str) -> Dict[str, Any]:

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
class CuratedHub:
2121
"""Class for creating and managing a curated JumpStart hub"""
2222

23-
def __init__(self, hub_name: str, region: str, session: Optional[Session]):
23+
def __init__(self, hub_name: str, region: str, session: Optional[Session] = None):
2424
self.hub_name = hub_name
2525
self.region = region
2626
self.session = session

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -860,16 +860,12 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
860860
def test_jumpstart_cache_get_hub_model():
861861
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
862862

863-
model_arn = (
864-
"arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/huggingface-mock-model-123/1.2.3"
865-
)
863+
model_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/huggingface-mock-model-123/1.2.3"
866864
assert get_spec_from_base_spec(
867865
model_id="huggingface-mock-model-123", version="1.2.3"
868866
) == cache.get_hub_model(hub_model_arn=model_arn)
869867

870-
model_arn = (
871-
"arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/pytorch-mock-model-123/*"
872-
)
873-
assert get_spec_from_base_spec(model_id="pytorch-mock-model-123", version="*") == cache.get_hub_model(
874-
hub_model_arn=model_arn
875-
)
868+
model_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/pytorch-mock-model-123/*"
869+
assert get_spec_from_base_spec(
870+
model_id="pytorch-mock-model-123", version="*"
871+
) == cache.get_hub_model(hub_model_arn=model_arn)

0 commit comments

Comments
 (0)