Skip to content

Commit 97001cc

Browse files
committed
update tests
1 parent 28c9186 commit 97001cc

File tree

3 files changed

+16
-12
lines changed

3 files changed

+16
-12
lines changed

src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def __init__(
6969
region: str,
7070
sync_request: HubSyncRequest,
7171
label: Optional[str] = None,
72-
thread_num: Optional[int] = 0
72+
thread_num: Optional[int] = 0,
7373
):
7474
"""Multi-part S3:Copy Handler initializer.
7575

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from botocore import exceptions
2222
from botocore.client import BaseClient
2323
from packaging.version import Version
24-
import tqdm
2524

2625
from sagemaker.jumpstart import utils
2726
from sagemaker.jumpstart.curated_hub.accessors import file_generator
@@ -315,9 +314,9 @@ def sync(self, model_list: List[Dict[str, str]]):
315314
max_workers=self._default_thread_pool_size,
316315
thread_name_prefix="import-models-to-curated-hub",
317316
) as deploy_executor:
318-
for thread_num, model in enumerate(models_to_sync):
319-
task = deploy_executor.submit(self._sync_public_model_to_hub, model, thread_num)
320-
tasks.append(task)
317+
for thread_num, model in enumerate(models_to_sync):
318+
task = deploy_executor.submit(self._sync_public_model_to_hub, model, thread_num)
319+
tasks.append(task)
321320

322321
# Handle failed imports
323322
results = futures.wait(tasks)
@@ -365,7 +364,12 @@ def _sync_public_model_to_hub(self, model: JumpStartModelInfo, thread_num: int):
365364
).create()
366365

367366
if len(sync_request.files) > 0:
368-
MultiPartCopyHandler(thread_num=thread_num, sync_request=sync_request, region=self.region, label=dest_location.key).execute()
367+
MultiPartCopyHandler(
368+
thread_num=thread_num,
369+
sync_request=sync_request,
370+
region=self.region,
371+
label=dest_location.key,
372+
).execute()
369373
else:
370374
JUMPSTART_LOGGER.warning("Nothing to copy for %s v%s", model.model_id, model.version)
371375

tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,8 @@ def test_sync_kicks_off_parallel_syncs(
163163

164164
mock_sync_public_models.assert_has_calls(
165165
[
166-
mock.call(JumpStartModelInfo("mock-model-one-huggingface", "*")),
167-
mock.call(JumpStartModelInfo("mock-model-two-pytorch", "1.0.2")),
166+
mock.call(JumpStartModelInfo("mock-model-one-huggingface", "*"), 0),
167+
mock.call(JumpStartModelInfo("mock-model-two-pytorch", "1.0.2"), 1),
168168
]
169169
)
170170

@@ -206,7 +206,7 @@ def test_sync_filters_models_that_exist_in_hub(
206206
hub.sync([model_one, model_two])
207207

208208
mock_sync_public_models.assert_called_once_with(
209-
JumpStartModelInfo("mock-model-one-huggingface", "*")
209+
JumpStartModelInfo("mock-model-one-huggingface", "*"), 0
210210
)
211211

212212

@@ -252,8 +252,8 @@ def test_sync_updates_old_models_in_hub(
252252

253253
mock_sync_public_models.assert_has_calls(
254254
[
255-
mock.call(JumpStartModelInfo("mock-model-one-huggingface", "*")),
256-
mock.call(JumpStartModelInfo("mock-model-two-pytorch", "1.0.2")),
255+
mock.call(JumpStartModelInfo("mock-model-one-huggingface", "*"), 0),
256+
mock.call(JumpStartModelInfo("mock-model-two-pytorch", "1.0.2"), 1),
257257
]
258258
)
259259

@@ -299,7 +299,7 @@ def test_sync_passes_newer_hub_models(
299299
hub.sync([model_one, model_two])
300300

301301
mock_sync_public_models.assert_called_once_with(
302-
JumpStartModelInfo("mock-model-one-huggingface", "*")
302+
JumpStartModelInfo("mock-model-one-huggingface", "*"), 0
303303
)
304304

305305

0 commit comments

Comments
 (0)