Skip to content

Commit 8af63e6

Browse files
committed
MultiPartCopy with Sync Algorithm (aws#4475)
* first pass at sync function with util classes * adding tests and update clases * linting * file generator class inheritance * lint * multipart copy and algorithm updates * modularize sync * reformatting folders * testing for sync * do not tolerate vulnerable * remove prints * handle multithreading progress bar * update tests * optimize function and add hub bucket prefix * docstrings and linting
1 parent f4c72ca commit 8af63e6

File tree

3 files changed

+11
-6
lines changed

3 files changed

+11
-6
lines changed

src/sagemaker/jumpstart/cache.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,9 @@ def _retrieval_function(
474474
formatted_content=hub_description,
475475
)
476476

477-
raise ValueError(self._file_type_error_msg(data_type))
477+
raise ValueError(
478+
self._file_type_error_msg(data_type)
479+
)
478480

479481
def get_manifest(
480482
self,

tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,7 @@ def test_generate_hub_arn_for_init_kwargs():
147147
utils.generate_hub_arn_for_init_kwargs(hub_arn, "us-east-1", mock_custom_session) == hub_arn
148148
)
149149

150-
assert (
151-
utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, None, mock_custom_session)
152-
== hub_arn
153-
)
150+
assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn
154151

155152

156153
def test_generate_default_hub_bucket_name():
@@ -170,8 +167,14 @@ def test_create_hub_bucket_if_it_does_not_exist():
170167
mock_sagemaker_session.client("sts").get_caller_identity.return_value = {
171168
"Account": "123456789123"
172169
}
170+
hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub"
171+
# Mock custom session with custom values
172+
mock_custom_session = Mock()
173+
mock_custom_session.account_id.return_value = "000000000000"
174+
mock_custom_session.boto_region_name = "us-east-2"
173175
mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
174176
mock_sagemaker_session.boto_region_name = "us-east-1"
177+
175178
bucket_name = "sagemaker-hubs-us-east-1-123456789123"
176179
created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist(
177180
sagemaker_session=mock_sagemaker_session

tests/unit/sagemaker/jumpstart/test_cache.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1134,7 +1134,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
11341134

11351135
mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root")
11361136
assert mocked_is_dir.call_count == 2
1137-
assert mocked_open.call_count == 2
1137+
mocked_open.assert_not_called()
11381138
mocked_get_json_file_and_etag_from_s3.assert_has_calls(
11391139
calls=[
11401140
call("models_manifest.json"),

0 commit comments

Comments
 (0)