Skip to content

Commit 151350c

Browse files
committed
finish tests
1 parent 8a3160a commit 151350c

File tree

1 file changed

+32
-32
lines changed

1 file changed

+32
-32
lines changed

tests/unit/sagemaker/jumpstart/test_artifacts.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -470,9 +470,9 @@ def test_retrieve_uri_from_gated_bucket(self, patched_get_model_specs):
470470

471471

472472
class HubModelTest(unittest.TestCase):
473-
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
474-
def test_retrieve_default_environment_variables(self, mock_cache):
475-
mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC))
473+
@patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model")
474+
def test_retrieve_default_environment_variables(self, mock_get_hub_model):
475+
mock_get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC))
476476

477477
model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2"
478478
hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub"
@@ -495,15 +495,15 @@ def test_retrieve_default_environment_variables(self, mock_cache):
495495
"SAGEMAKER_MODEL_SERVER_WORKERS": "1",
496496
},
497497
)
498-
mock_cache.get_hub_model.assert_called_once_with(
498+
mock_get_hub_model.assert_called_once_with(
499499
hub_model_arn=(
500500
f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}"
501501
)
502502
)
503503

504-
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
505-
def test_retrieve_image_uri(self, mock_cache):
506-
mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC))
504+
@patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model")
505+
def test_retrieve_image_uri(self, mock_get_hub_model):
506+
mock_get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC))
507507

508508
model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2"
509509
hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub"
@@ -518,15 +518,15 @@ def test_retrieve_image_uri(self, mock_cache):
518518
),
519519
"763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3",
520520
)
521-
mock_cache.get_hub_model.assert_called_once_with(
521+
mock_get_hub_model.assert_called_once_with(
522522
hub_model_arn=(
523523
f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}"
524524
)
525525
)
526526

527-
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
528-
def test_retrieve_default_hyperparameters(self, mock_cache):
529-
mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC))
527+
@patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model")
528+
def test_retrieve_default_hyperparameters(self, mock_get_hub_model):
529+
mock_get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC))
530530

531531
model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2"
532532
hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub"
@@ -541,15 +541,15 @@ def test_retrieve_default_hyperparameters(self, mock_cache):
541541
"batch-size": "4",
542542
},
543543
)
544-
mock_cache.get_hub_model.assert_called_once_with(
544+
mock_get_hub_model.assert_called_once_with(
545545
hub_model_arn=(
546546
f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}"
547547
)
548548
)
549549

550-
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
551-
def test_model_supports_incremental_training(self, mock_cache):
552-
mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC))
550+
@patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model")
551+
def test_model_supports_incremental_training(self, mock_get_hub_model):
552+
mock_get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC))
553553

554554
model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2"
555555
hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub"
@@ -560,15 +560,15 @@ def test_model_supports_incremental_training(self, mock_cache):
560560
),
561561
True,
562562
)
563-
mock_cache.get_hub_model.assert_called_once_with(
563+
mock_get_hub_model.assert_called_once_with(
564564
hub_model_arn=(
565565
f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}"
566566
)
567567
)
568568

569-
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
570-
def test_retrieve_default_instance_type(self, mock_cache):
571-
mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC))
569+
@patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model")
570+
def test_retrieve_default_instance_type(self, mock_get_hub_model):
571+
mock_get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC))
572572

573573
model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2"
574574
hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub"
@@ -582,7 +582,7 @@ def test_retrieve_default_instance_type(self, mock_cache):
582582
),
583583
"ml.p3.2xlarge",
584584
)
585-
mock_cache.get_hub_model.assert_called_once_with(
585+
mock_get_hub_model.assert_called_once_with(
586586
hub_model_arn=(
587587
f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}"
588588
)
@@ -598,9 +598,9 @@ def test_retrieve_default_instance_type(self, mock_cache):
598598
"ml.p2.xlarge",
599599
)
600600

601-
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
602-
def test_retrieve_default_training_metric_definitions(self, mock_cache):
603-
mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC))
601+
@patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model")
602+
def test_retrieve_default_training_metric_definitions(self, mock_get_hub_model):
603+
mock_get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC))
604604

605605
model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2"
606606
hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub"
@@ -611,15 +611,15 @@ def test_retrieve_default_training_metric_definitions(self, mock_cache):
611611
),
612612
[{"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"}],
613613
)
614-
mock_cache.get_hub_model.assert_called_once_with(
614+
mock_get_hub_model.assert_called_once_with(
615615
hub_model_arn=(
616616
f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}"
617617
)
618618
)
619619

620-
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
621-
def test_retrieve_model_uri(self, mock_cache):
622-
mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC))
620+
@patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model")
621+
def test_retrieve_model_uri(self, mock_get_hub_model):
622+
mock_get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC))
623623

624624
model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2"
625625
hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub"
@@ -630,7 +630,7 @@ def test_retrieve_model_uri(self, mock_cache):
630630
),
631631
"s3://jumpstart-cache-prod-us-west-2/pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz",
632632
)
633-
mock_cache.get_hub_model.assert_called_once_with(
633+
mock_get_hub_model.assert_called_once_with(
634634
hub_model_arn=(
635635
f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}"
636636
)
@@ -643,9 +643,9 @@ def test_retrieve_model_uri(self, mock_cache):
643643
"s3://jumpstart-cache-prod-us-west-2/pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz",
644644
)
645645

646-
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
647-
def test_retrieve_script_uri(self, mock_cache):
648-
mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC))
646+
@patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model")
647+
def test_retrieve_script_uri(self, mock_get_hub_model):
648+
mock_get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC))
649649

650650
model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2"
651651
hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub"
@@ -660,7 +660,7 @@ def test_retrieve_script_uri(self, mock_cache):
660660
"s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/"
661661
"transfer_learning/ic/v1.0.0/sourcedir.tar.gz",
662662
)
663-
mock_cache.get_hub_model.assert_called_once_with(
663+
mock_get_hub_model.assert_called_once_with(
664664
hub_model_arn=(
665665
f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}"
666666
)

0 commit comments

Comments
 (0)