Skip to content

Commit 8a3160a

Browse files
committed
black styles
1 parent 424254e commit 8a3160a

File tree

10 files changed

+83
-71
lines changed

10 files changed

+83
-71
lines changed

src/sagemaker/jumpstart/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -870,7 +870,7 @@ def get_info_from_hub_resource_arn(
870870
account_id=account_id,
871871
hub_name=hub_name,
872872
)
873-
873+
874874
return None
875875

876876

tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,7 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs):
4343
assert params == {"adam-learning-rate": "0.05", "batch-size": "4", "epochs": "3"}
4444

4545
patched_get_model_specs.assert_called_once_with(
46-
region=region,
47-
model_id=model_id,
48-
version="*",
49-
s3_client=mock_client,
50-
hub_arn=None
46+
region=region, model_id=model_id, version="*", s3_client=mock_client, hub_arn=None
5147
)
5248

5349
patched_get_model_specs.reset_mock()
@@ -61,10 +57,7 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs):
6157
assert params == {"adam-learning-rate": "0.05", "batch-size": "4", "epochs": "3"}
6258

6359
patched_get_model_specs.assert_called_once_with(
64-
region=region,
65-
model_id=model_id,
66-
version="1.*",
67-
s3_client=mock_client, hub_arn=None
60+
region=region, model_id=model_id, version="1.*", s3_client=mock_client, hub_arn=None
6861
)
6962

7063
patched_get_model_specs.reset_mock()
@@ -86,10 +79,7 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs):
8679
}
8780

8881
patched_get_model_specs.assert_called_once_with(
89-
region=region,
90-
model_id=model_id,
91-
version="1.*",
92-
s3_client=mock_client, hub_arn=None
82+
region=region, model_id=model_id, version="1.*", s3_client=mock_client, hub_arn=None
9383
)
9484

9585
patched_get_model_specs.reset_mock()

tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,7 @@ def add_options_to_hyperparameter(*largs, **kwargs):
136136
)
137137

138138
patched_get_model_specs.assert_called_once_with(
139-
region=region,
140-
model_id=model_id,
141-
version=model_version,
142-
s3_client=mock_client, hub_arn=None
139+
region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None
143140
)
144141

145142
patched_get_model_specs.reset_mock()

tests/unit/sagemaker/image_uris/jumpstart/test_common.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ def test_jumpstart_common_image_uri(
4848
region="us-west-2",
4949
model_id="pytorch-ic-mobilenet-v2",
5050
version="*",
51-
s3_client=mock_client, hub_arn=None
51+
s3_client=mock_client,
52+
hub_arn=None,
5253
)
5354
patched_verify_model_region_and_return_specs.assert_called_once()
5455

@@ -68,7 +69,8 @@ def test_jumpstart_common_image_uri(
6869
region="us-west-2",
6970
model_id="pytorch-ic-mobilenet-v2",
7071
version="1.*",
71-
s3_client=mock_client, hub_arn=None
72+
s3_client=mock_client,
73+
hub_arn=None,
7274
)
7375
patched_verify_model_region_and_return_specs.assert_called_once()
7476

@@ -88,7 +90,8 @@ def test_jumpstart_common_image_uri(
8890
region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME,
8991
model_id="pytorch-ic-mobilenet-v2",
9092
version="*",
91-
s3_client=mock_client, hub_arn=None
93+
s3_client=mock_client,
94+
hub_arn=None,
9295
)
9396
patched_verify_model_region_and_return_specs.assert_called_once()
9497

@@ -108,7 +111,8 @@ def test_jumpstart_common_image_uri(
108111
region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME,
109112
model_id="pytorch-ic-mobilenet-v2",
110113
version="1.*",
111-
s3_client=mock_client, hub_arn=None
114+
s3_client=mock_client,
115+
hub_arn=None,
112116
)
113117
patched_verify_model_region_and_return_specs.assert_called_once()
114118

tests/unit/sagemaker/jumpstart/test_artifacts.py

Lines changed: 45 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14-
from importlib import reload
1514
import unittest
1615
from unittest.mock import Mock
1716

@@ -21,12 +20,16 @@
2120

2221
import copy
2322
from sagemaker.jumpstart import artifacts
24-
from sagemaker.jumpstart.artifacts.environment_variables import _retrieve_default_environment_variables
23+
from sagemaker.jumpstart.artifacts.environment_variables import (
24+
_retrieve_default_environment_variables,
25+
)
2526
from sagemaker.jumpstart.artifacts.hyperparameters import _retrieve_default_hyperparameters
2627
from sagemaker.jumpstart.artifacts.image_uris import _retrieve_image_uri
2728
from sagemaker.jumpstart.artifacts.incremental_training import _model_supports_incremental_training
2829
from sagemaker.jumpstart.artifacts.instance_types import _retrieve_default_instance_type
29-
from sagemaker.jumpstart.artifacts.metric_definitions import _retrieve_default_training_metric_definitions
30+
from sagemaker.jumpstart.artifacts.metric_definitions import (
31+
_retrieve_default_training_metric_definitions,
32+
)
3033
from sagemaker.jumpstart.artifacts.model_uris import (
3134
_retrieve_hosting_prepacked_artifact_key,
3235
_retrieve_hosting_artifact_key,
@@ -467,7 +470,6 @@ def test_retrieve_uri_from_gated_bucket(self, patched_get_model_specs):
467470

468471

469472
class HubModelTest(unittest.TestCase):
470-
471473
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
472474
def test_retrieve_default_environment_variables(self, mock_cache):
473475
mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC))
@@ -477,7 +479,10 @@ def test_retrieve_default_environment_variables(self, mock_cache):
477479

478480
self.assertEqual(
479481
_retrieve_default_environment_variables(
480-
model_id=model_id, model_version=version, hub_arn=hub_arn, script=JumpStartScriptScope.INFERENCE
482+
model_id=model_id,
483+
model_version=version,
484+
hub_arn=hub_arn,
485+
script=JumpStartScriptScope.INFERENCE,
481486
),
482487
{
483488
"SAGEMAKER_PROGRAM": "inference.py",
@@ -487,16 +492,15 @@ def test_retrieve_default_environment_variables(self, mock_cache):
487492
"ENDPOINT_SERVER_TIMEOUT": "3600",
488493
"MODEL_CACHE_ROOT": "/opt/ml/model",
489494
"SAGEMAKER_ENV": "1",
490-
"SAGEMAKER_MODEL_SERVER_WORKERS": "1"
491-
}
495+
"SAGEMAKER_MODEL_SERVER_WORKERS": "1",
496+
},
492497
)
493498
mock_cache.get_hub_model.assert_called_once_with(
494499
hub_model_arn=(
495500
f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}"
496501
)
497502
)
498503

499-
500504
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
501505
def test_retrieve_image_uri(self, mock_cache):
502506
mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC))
@@ -506,9 +510,13 @@ def test_retrieve_image_uri(self, mock_cache):
506510

507511
self.assertEqual(
508512
_retrieve_image_uri(
509-
model_id=model_id, model_version=version, hub_arn=hub_arn, instance_type="ml.p3.2xlarge", image_scope=JumpStartScriptScope.TRAINING
513+
model_id=model_id,
514+
model_version=version,
515+
hub_arn=hub_arn,
516+
instance_type="ml.p3.2xlarge",
517+
image_scope=JumpStartScriptScope.TRAINING,
510518
),
511-
"763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3"
519+
"763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3",
512520
)
513521
mock_cache.get_hub_model.assert_called_once_with(
514522
hub_model_arn=(
@@ -531,7 +539,7 @@ def test_retrieve_default_hyperparameters(self, mock_cache):
531539
"epochs": "3",
532540
"adam-learning-rate": "0.05",
533541
"batch-size": "4",
534-
}
542+
},
535543
)
536544
mock_cache.get_hub_model.assert_called_once_with(
537545
hub_model_arn=(
@@ -550,7 +558,7 @@ def test_model_supports_incremental_training(self, mock_cache):
550558
_model_supports_incremental_training(
551559
model_id=model_id, model_version=version, hub_arn=hub_arn, region="us-west-2"
552560
),
553-
True
561+
True,
554562
)
555563
mock_cache.get_hub_model.assert_called_once_with(
556564
hub_model_arn=(
@@ -567,9 +575,12 @@ def test_retrieve_default_instance_type(self, mock_cache):
567575

568576
self.assertEqual(
569577
_retrieve_default_instance_type(
570-
model_id=model_id, model_version=version, hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING
578+
model_id=model_id,
579+
model_version=version,
580+
hub_arn=hub_arn,
581+
scope=JumpStartScriptScope.TRAINING,
571582
),
572-
"ml.p3.2xlarge"
583+
"ml.p3.2xlarge",
573584
)
574585
mock_cache.get_hub_model.assert_called_once_with(
575586
hub_model_arn=(
@@ -579,9 +590,12 @@ def test_retrieve_default_instance_type(self, mock_cache):
579590

580591
self.assertEqual(
581592
_retrieve_default_instance_type(
582-
model_id=model_id, model_version=version, hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE
593+
model_id=model_id,
594+
model_version=version,
595+
hub_arn=hub_arn,
596+
scope=JumpStartScriptScope.INFERENCE,
583597
),
584-
"ml.p2.xlarge"
598+
"ml.p2.xlarge",
585599
)
586600

587601
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
@@ -595,15 +609,14 @@ def test_retrieve_default_training_metric_definitions(self, mock_cache):
595609
_retrieve_default_training_metric_definitions(
596610
model_id=model_id, model_version=version, hub_arn=hub_arn, region="us-west-2"
597611
),
598-
[{"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"}]
612+
[{"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"}],
599613
)
600614
mock_cache.get_hub_model.assert_called_once_with(
601615
hub_model_arn=(
602616
f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}"
603617
)
604618
)
605619

606-
607620
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
608621
def test_retrieve_model_uri(self, mock_cache):
609622
mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC))
@@ -615,7 +628,7 @@ def test_retrieve_model_uri(self, mock_cache):
615628
_retrieve_model_uri(
616629
model_id=model_id, model_version=version, hub_arn=hub_arn, model_scope="training"
617630
),
618-
"s3://jumpstart-cache-prod-us-west-2/pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz"
631+
"s3://jumpstart-cache-prod-us-west-2/pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz",
619632
)
620633
mock_cache.get_hub_model.assert_called_once_with(
621634
hub_model_arn=(
@@ -627,7 +640,7 @@ def test_retrieve_model_uri(self, mock_cache):
627640
_retrieve_model_uri(
628641
model_id=model_id, model_version=version, hub_arn=hub_arn, model_scope="inference"
629642
),
630-
"s3://jumpstart-cache-prod-us-west-2/pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz"
643+
"s3://jumpstart-cache-prod-us-west-2/pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz",
631644
)
632645

633646
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
@@ -639,9 +652,13 @@ def test_retrieve_script_uri(self, mock_cache):
639652

640653
self.assertEqual(
641654
_retrieve_script_uri(
642-
model_id=model_id, model_version=version, hub_arn=hub_arn, script_scope=JumpStartScriptScope.TRAINING
655+
model_id=model_id,
656+
model_version=version,
657+
hub_arn=hub_arn,
658+
script_scope=JumpStartScriptScope.TRAINING,
643659
),
644-
"s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz"
660+
"s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/"
661+
"transfer_learning/ic/v1.0.0/sourcedir.tar.gz",
645662
)
646663
mock_cache.get_hub_model.assert_called_once_with(
647664
hub_model_arn=(
@@ -651,7 +668,11 @@ def test_retrieve_script_uri(self, mock_cache):
651668

652669
self.assertEqual(
653670
_retrieve_script_uri(
654-
model_id=model_id, model_version=version, hub_arn=hub_arn, script_scope=JumpStartScriptScope.INFERENCE
671+
model_id=model_id,
672+
model_version=version,
673+
hub_arn=hub_arn,
674+
script_scope=JumpStartScriptScope.INFERENCE,
655675
),
656-
"s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz"
676+
"s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/"
677+
"inference/ic/v1.0.0/sourcedir.tar.gz",
657678
)

tests/unit/sagemaker/jumpstart/test_notebook_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -696,5 +696,6 @@ def test_get_model_url(
696696
model_id=model_id,
697697
version=version,
698698
region=region,
699-
s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, hub_arn=None
699+
s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client,
700+
hub_arn=None,
700701
)

tests/unit/sagemaker/model_uris/jumpstart/test_common.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ def test_jumpstart_common_model_uri(
4646
region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME,
4747
model_id="pytorch-ic-mobilenet-v2",
4848
version="*",
49-
s3_client=mock_client, hub_arn=None
49+
s3_client=mock_client,
50+
hub_arn=None,
5051
)
5152
patched_verify_model_region_and_return_specs.assert_called_once()
5253

@@ -63,7 +64,8 @@ def test_jumpstart_common_model_uri(
6364
region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME,
6465
model_id="pytorch-ic-mobilenet-v2",
6566
version="1.*",
66-
s3_client=mock_client, hub_arn=None
67+
s3_client=mock_client,
68+
hub_arn=None,
6769
)
6870
patched_verify_model_region_and_return_specs.assert_called_once()
6971

@@ -81,7 +83,8 @@ def test_jumpstart_common_model_uri(
8183
region="us-west-2",
8284
model_id="pytorch-ic-mobilenet-v2",
8385
version="*",
84-
s3_client=mock_client, hub_arn=None
86+
s3_client=mock_client,
87+
hub_arn=None,
8588
)
8689
patched_verify_model_region_and_return_specs.assert_called_once()
8790

@@ -99,7 +102,8 @@ def test_jumpstart_common_model_uri(
99102
region="us-west-2",
100103
model_id="pytorch-ic-mobilenet-v2",
101104
version="1.*",
102-
s3_client=mock_client, hub_arn=None
105+
s3_client=mock_client,
106+
hub_arn=None,
103107
)
104108
patched_verify_model_region_and_return_specs.assert_called_once()
105109

tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,7 @@ def test_jumpstart_resource_requirements(patched_get_model_specs):
4242
assert default_inference_resource_requirements.requests["memory"] == 34360
4343

4444
patched_get_model_specs.assert_called_once_with(
45-
region=region,
46-
model_id=model_id,
47-
version=model_version,
48-
s3_client=mock_client, hub_arn=None
45+
region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None
4946
)
5047
patched_get_model_specs.reset_mock()
5148

@@ -69,10 +66,7 @@ def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs):
6966
assert default_inference_resource_requirements is None
7067

7168
patched_get_model_specs.assert_called_once_with(
72-
region=region,
73-
model_id=model_id,
74-
version=model_version,
75-
s3_client=mock_client, hub_arn=None
69+
region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None
7670
)
7771
patched_get_model_specs.reset_mock()
7872

0 commit comments

Comments
 (0)