Skip to content

Commit 1af132e

Browse files
committed
linting
1 parent b50c557 commit 1af132e

File tree

4 files changed

+4
-23
lines changed

4 files changed

+4
-23
lines changed

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,7 @@ def _sync_public_model_to_hub(self, model: JumpStartModelInfo, thread_num: int):
394394
self._sagemaker_session.import_hub_content(
395395
document_schema_version=HubContentDocument_v2.SCHEMA_VERSION,
396396
hub_content_name=model.model_id,
397+
hub_content_version=model.version,
397398
hub_name=self.hub_name,
398399
hub_content_document=hub_content_document,
399400
hub_content_type=HubContentType.MODEL,

src/sagemaker/jumpstart/types.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1434,8 +1434,6 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
14341434
"model_version",
14351435
"model_type",
14361436
"hub_arn",
1437-
"model_type",
1438-
"hub_arn",
14391437
"region",
14401438
"tolerate_deprecated_model",
14411439
"tolerate_vulnerable_model",

tests/unit/sagemaker/jumpstart/test_accessors.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@ def test_jumpstart_models_cache_get_model_specs(mock_cache):
8383
accessors.JumpStartModelsAccessor.get_model_specs(
8484
region=region, model_id=model_id, version=version
8585
)
86-
mock_cache.get_specs.assert_called_once_with(model_id=model_id, version_str=version, model_type=JumpStartModelType.OPEN_WEIGHTS)
86+
mock_cache.get_specs.assert_called_once_with(
87+
model_id=model_id, version_str=version, model_type=JumpStartModelType.OPEN_WEIGHTS
88+
)
8789
mock_cache.get_hub_model.assert_not_called()
8890

8991
accessors.JumpStartModelsAccessor.get_model_specs(

tests/unit/sagemaker/jumpstart/utils.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -254,26 +254,6 @@ def patched_retrieval_function(
254254
)
255255
)
256256

257-
if datatype == HubContentType.MODEL:
258-
_, _, _, model_name, model_version = id_info.split("/")
259-
return JumpStartCachedContentValue(
260-
formatted_content=get_spec_from_base_spec(model_id=model_name, version=model_version)
261-
)
262-
263-
# TODO: Implement
264-
if datatype == HubType.HUB:
265-
return None
266-
267-
if datatype == HubContentType.MODEL:
268-
_, _, _, model_name, model_version = id_info.split("/")
269-
return JumpStartCachedContentValue(
270-
formatted_content=get_spec_from_base_spec(model_id=model_name, version=model_version)
271-
)
272-
273-
# TODO: Implement
274-
if datatype == HubType.HUB:
275-
return None
276-
277257
raise ValueError(f"Bad value for datatype: {datatype}")
278258

279259

0 commit comments

Comments
 (0)