Skip to content

Commit 424254e

Browse files
committed
update tests
1 parent 354b33e commit 424254e

File tree

21 files changed

+253
-49
lines changed

21 files changed

+253
-49
lines changed

src/sagemaker/jumpstart/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@
170170

171171
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
172172

173-
HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/(.*?)/(.*?)/(.*?)$"
173+
HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$"
174174
HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$"
175175

176176
INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py"

src/sagemaker/jumpstart/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -867,7 +867,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
867867
self.inference_enable_network_isolation: bool = json_obj.get(
868868
"inference_enable_network_isolation", False
869869
)
870-
self.resource_name_base: bool = json_obj.get("resource_name_base")
870+
self.resource_name_base: Optional[str] = json_obj.get("resource_name_base")
871871

872872
self.hosting_eula_key: Optional[str] = json_obj.get("hosting_eula_key")
873873

src/sagemaker/jumpstart/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -870,7 +870,7 @@ def get_info_from_hub_resource_arn(
870870
account_id=account_id,
871871
hub_name=hub_name,
872872
)
873-
873+
874874
return None
875875

876876

tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_jumpstart_default_accept_types(
4545
assert default_accept_type == "application/json"
4646

4747
patched_get_model_specs.assert_called_once_with(
48-
region=region, model_id=model_id, version=model_version, s3_client=mock_client
48+
region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None
4949
)
5050

5151

@@ -73,5 +73,5 @@ def test_jumpstart_supported_accept_types(
7373
]
7474

7575
patched_get_model_specs.assert_called_once_with(
76-
region=region, model_id=model_id, version=model_version, s3_client=mock_client
76+
region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None
7777
)

tests/unit/sagemaker/content_types/jumpstart/test_content_types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_jumpstart_default_content_types(
4545
assert default_content_type == "application/x-text"
4646

4747
patched_get_model_specs.assert_called_once_with(
48-
region=region, model_id=model_id, version=model_version, s3_client=mock_client
48+
region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None
4949
)
5050

5151

@@ -72,5 +72,5 @@ def test_jumpstart_supported_content_types(
7272
]
7373

7474
patched_get_model_specs.assert_called_once_with(
75-
region=region, model_id=model_id, version=model_version, s3_client=mock_client
75+
region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None
7676
)

tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def test_jumpstart_default_deserializers(
4747
assert isinstance(default_deserializer, base_deserializers.JSONDeserializer)
4848

4949
patched_get_model_specs.assert_called_once_with(
50-
region=region, model_id=model_id, version=model_version, s3_client=mock_client
50+
region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None
5151
)
5252

5353

@@ -79,5 +79,5 @@ def test_jumpstart_deserializer_options(
7979
)
8080

8181
patched_get_model_specs.assert_called_once_with(
82-
region=region, model_id=model_id, version=model_version, s3_client=mock_client
82+
region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None
8383
)

tests/unit/sagemaker/environment_variables/jumpstart/test_default.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_jumpstart_default_environment_variables(patched_get_model_specs):
4848
}
4949

5050
patched_get_model_specs.assert_called_once_with(
51-
region=region, model_id=model_id, version="*", s3_client=mock_client
51+
region=region, model_id=model_id, version="*", s3_client=mock_client, hub_arn=None
5252
)
5353

5454
patched_get_model_specs.reset_mock()
@@ -68,7 +68,7 @@ def test_jumpstart_default_environment_variables(patched_get_model_specs):
6868
}
6969

7070
patched_get_model_specs.assert_called_once_with(
71-
region=region, model_id=model_id, version="1.*", s3_client=mock_client
71+
region=region, model_id=model_id, version="1.*", s3_client=mock_client, hub_arn=None
7272
)
7373

7474
patched_get_model_specs.reset_mock()
@@ -122,7 +122,7 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs):
122122
}
123123

124124
patched_get_model_specs.assert_called_once_with(
125-
region=region, model_id=model_id, version="*", s3_client=mock_client
125+
region=region, model_id=model_id, version="*", s3_client=mock_client, hub_arn=None
126126
)
127127

128128
patched_get_model_specs.reset_mock()
@@ -143,7 +143,7 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs):
143143
}
144144

145145
patched_get_model_specs.assert_called_once_with(
146-
region=region, model_id=model_id, version="1.*", s3_client=mock_client
146+
region=region, model_id=model_id, version="1.*", s3_client=mock_client, hub_arn=None
147147
)
148148

149149
patched_get_model_specs.reset_mock()

tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs):
4747
model_id=model_id,
4848
version="*",
4949
s3_client=mock_client,
50+
hub_arn=None
5051
)
5152

5253
patched_get_model_specs.reset_mock()
@@ -63,7 +64,7 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs):
6364
region=region,
6465
model_id=model_id,
6566
version="1.*",
66-
s3_client=mock_client,
67+
s3_client=mock_client, hub_arn=None
6768
)
6869

6970
patched_get_model_specs.reset_mock()
@@ -88,7 +89,7 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs):
8889
region=region,
8990
model_id=model_id,
9091
version="1.*",
91-
s3_client=mock_client,
92+
s3_client=mock_client, hub_arn=None
9293
)
9394

9495
patched_get_model_specs.reset_mock()

tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def add_options_to_hyperparameter(*largs, **kwargs):
139139
region=region,
140140
model_id=model_id,
141141
version=model_version,
142-
s3_client=mock_client,
142+
s3_client=mock_client, hub_arn=None
143143
)
144144

145145
patched_get_model_specs.reset_mock()
@@ -437,7 +437,7 @@ def add_options_to_hyperparameter(*largs, **kwargs):
437437
)
438438

439439
patched_get_model_specs.assert_called_once_with(
440-
region=region, model_id=model_id, version=model_version, s3_client=mock_client
440+
region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None
441441
)
442442

443443
patched_get_model_specs.reset_mock()
@@ -491,7 +491,7 @@ def test_jumpstart_validate_all_hyperparameters(patched_get_model_specs):
491491
)
492492

493493
patched_get_model_specs.assert_called_once_with(
494-
region=region, model_id=model_id, version=model_version, s3_client=mock_client
494+
region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None
495495
)
496496

497497
patched_get_model_specs.reset_mock()

tests/unit/sagemaker/image_uris/jumpstart/test_common.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_jumpstart_common_image_uri(
4848
region="us-west-2",
4949
model_id="pytorch-ic-mobilenet-v2",
5050
version="*",
51-
s3_client=mock_client,
51+
s3_client=mock_client, hub_arn=None
5252
)
5353
patched_verify_model_region_and_return_specs.assert_called_once()
5454

@@ -68,7 +68,7 @@ def test_jumpstart_common_image_uri(
6868
region="us-west-2",
6969
model_id="pytorch-ic-mobilenet-v2",
7070
version="1.*",
71-
s3_client=mock_client,
71+
s3_client=mock_client, hub_arn=None
7272
)
7373
patched_verify_model_region_and_return_specs.assert_called_once()
7474

@@ -88,7 +88,7 @@ def test_jumpstart_common_image_uri(
8888
region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME,
8989
model_id="pytorch-ic-mobilenet-v2",
9090
version="*",
91-
s3_client=mock_client,
91+
s3_client=mock_client, hub_arn=None
9292
)
9393
patched_verify_model_region_and_return_specs.assert_called_once()
9494

@@ -108,7 +108,7 @@ def test_jumpstart_common_image_uri(
108108
region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME,
109109
model_id="pytorch-ic-mobilenet-v2",
110110
version="1.*",
111-
s3_client=mock_client,
111+
s3_client=mock_client, hub_arn=None
112112
)
113113
patched_verify_model_region_and_return_specs.assert_called_once()
114114

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6239,6 +6239,7 @@
62396239
"training_volume_size": 456,
62406240
"inference_enable_network_isolation": True,
62416241
"training_enable_network_isolation": False,
6242+
"resource_name_base": "dfsdfsds",
62426243
"hosting_resource_requirements": {"num_accelerators": 1, "min_memory_mb": 34360},
62436244
"dynamic_container_deployment_supported": True,
62446245
}

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -282,11 +282,12 @@ def test_prepacked(
282282
],
283283
)
284284

285+
@mock.patch("sagemaker.jumpstart.artifacts.resource_names._retrieve_resource_name_base")
285286
@mock.patch("sagemaker.session.Session.account_id")
286287
@mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs")
287288
@mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_deploy_kwargs")
288289
@mock.patch("sagemaker.jumpstart.factory.estimator._retrieve_estimator_fit_kwargs")
289-
@mock.patch("sagemaker.jumpstart.estimator.construct_hub_arn_from_name")
290+
@mock.patch("sagemaker.jumpstart.utils.construct_hub_arn_from_name")
290291
@mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id")
291292
@mock.patch("sagemaker.jumpstart.factory.model.Session")
292293
@mock.patch("sagemaker.jumpstart.factory.estimator.Session")
@@ -310,7 +311,9 @@ def test_hub_model(
310311
mock_retrieve_model_deploy_kwargs: mock.Mock,
311312
mock_retrieve_model_init_kwargs: mock.Mock,
312313
mock_get_caller_identity: mock.Mock,
314+
mock_retrieve_resource_name_base: mock.Mock,
313315
):
316+
mock_retrieve_resource_name_base.return_value = "go-blue"
314317
mock_get_caller_identity.return_value = "123456789123"
315318
mock_estimator_deploy.return_value = default_predictor
316319

@@ -372,11 +375,11 @@ def test_hub_model(
372375
f"some-training-dataset-doesn't-matter",
373376
}
374377

375-
estimator.fit(channels)
378+
estimator.fit(channels, job_name="go-blue")
376379

377-
mock_estimator_fit.assert_called_once_with(inputs=channels, wait=True)
380+
mock_estimator_fit.assert_called_once_with(inputs=channels, wait=True, job_name="go-blue")
378381

379-
estimator.deploy()
382+
estimator.deploy(endpoint_name="go-blue", model_name="go-blue")
380383

381384
mock_estimator_deploy.assert_called_once_with(
382385
instance_type="ml.p2.xlarge",
@@ -386,6 +389,8 @@ def test_hub_model(
386389
source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/"
387390
"pytorch/inference/ic/v1.0.0/sourcedir.tar.gz",
388391
entry_point="inference.py",
392+
endpoint_name="go-blue",
393+
model_name="go-blue",
389394
env={
390395
"SAGEMAKER_PROGRAM": "inference.py",
391396
"ENDPOINT_SERVER_TIMEOUT": "3600",
@@ -414,7 +419,7 @@ def test_hub_model(
414419
)
415420

416421
mock_construct_hub_arn_from_name.assert_called_once_with(
417-
hub_name="my-mock-hub", region=None, sagemaker_session=None
422+
hub_name="my-mock-hub", region=None, session=None
418423
)
419424

420425
@mock.patch("sagemaker.utils.sagemaker_timestamp")
@@ -1475,6 +1480,7 @@ def test_incremental_training_with_unsupported_model_logs_warning(
14751480
model_id=model_id,
14761481
model_version="*",
14771482
region=region,
1483+
hub_arn=None,
14781484
tolerate_deprecated_model=False,
14791485
tolerate_vulnerable_model=False,
14801486
sagemaker_session=sagemaker_session,
@@ -1526,6 +1532,7 @@ def test_incremental_training_with_supported_model_doesnt_log_warning(
15261532
model_id=model_id,
15271533
model_version="*",
15281534
region=region,
1535+
hub_arn=None,
15291536
tolerate_deprecated_model=False,
15301537
tolerate_vulnerable_model=False,
15311538
sagemaker_session=sagemaker_session,

0 commit comments

Comments
 (0)