@@ -470,9 +470,9 @@ def test_retrieve_uri_from_gated_bucket(self, patched_get_model_specs):
470
470
471
471
472
472
class HubModelTest (unittest .TestCase ):
473
- @patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache " )
474
- def test_retrieve_default_environment_variables (self , mock_cache ):
475
- mock_cache . get_hub_model .return_value = JumpStartModelSpecs (spec = copy .deepcopy (BASE_SPEC ))
473
+ @patch ("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model " )
474
+ def test_retrieve_default_environment_variables (self , mock_get_hub_model ):
475
+ mock_get_hub_model .return_value = JumpStartModelSpecs (spec = copy .deepcopy (BASE_SPEC ))
476
476
477
477
model_id , version = "pytorch-ic-mobilenet-v2" , "1.0.2"
478
478
hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub"
@@ -495,15 +495,15 @@ def test_retrieve_default_environment_variables(self, mock_cache):
495
495
"SAGEMAKER_MODEL_SERVER_WORKERS" : "1" ,
496
496
},
497
497
)
498
- mock_cache . get_hub_model .assert_called_once_with (
498
+ mock_get_hub_model .assert_called_once_with (
499
499
hub_model_arn = (
500
500
f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{ model_id } /{ version } "
501
501
)
502
502
)
503
503
504
- @patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache " )
505
- def test_retrieve_image_uri (self , mock_cache ):
506
- mock_cache . get_hub_model .return_value = JumpStartModelSpecs (spec = copy .deepcopy (BASE_SPEC ))
504
+ @patch ("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model " )
505
+ def test_retrieve_image_uri (self , mock_get_hub_model ):
506
+ mock_get_hub_model .return_value = JumpStartModelSpecs (spec = copy .deepcopy (BASE_SPEC ))
507
507
508
508
model_id , version = "pytorch-ic-mobilenet-v2" , "1.0.2"
509
509
hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub"
@@ -518,15 +518,15 @@ def test_retrieve_image_uri(self, mock_cache):
518
518
),
519
519
"763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3" ,
520
520
)
521
- mock_cache . get_hub_model .assert_called_once_with (
521
+ mock_get_hub_model .assert_called_once_with (
522
522
hub_model_arn = (
523
523
f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{ model_id } /{ version } "
524
524
)
525
525
)
526
526
527
- @patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache " )
528
- def test_retrieve_default_hyperparameters (self , mock_cache ):
529
- mock_cache . get_hub_model .return_value = JumpStartModelSpecs (spec = copy .deepcopy (BASE_SPEC ))
527
+ @patch ("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model " )
528
+ def test_retrieve_default_hyperparameters (self , mock_get_hub_model ):
529
+ mock_get_hub_model .return_value = JumpStartModelSpecs (spec = copy .deepcopy (BASE_SPEC ))
530
530
531
531
model_id , version = "pytorch-ic-mobilenet-v2" , "1.0.2"
532
532
hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub"
@@ -541,15 +541,15 @@ def test_retrieve_default_hyperparameters(self, mock_cache):
541
541
"batch-size" : "4" ,
542
542
},
543
543
)
544
- mock_cache . get_hub_model .assert_called_once_with (
544
+ mock_get_hub_model .assert_called_once_with (
545
545
hub_model_arn = (
546
546
f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{ model_id } /{ version } "
547
547
)
548
548
)
549
549
550
- @patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache " )
551
- def test_model_supports_incremental_training (self , mock_cache ):
552
- mock_cache . get_hub_model .return_value = JumpStartModelSpecs (spec = copy .deepcopy (BASE_SPEC ))
550
+ @patch ("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model " )
551
+ def test_model_supports_incremental_training (self , mock_get_hub_model ):
552
+ mock_get_hub_model .return_value = JumpStartModelSpecs (spec = copy .deepcopy (BASE_SPEC ))
553
553
554
554
model_id , version = "pytorch-ic-mobilenet-v2" , "1.0.2"
555
555
hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub"
@@ -560,15 +560,15 @@ def test_model_supports_incremental_training(self, mock_cache):
560
560
),
561
561
True ,
562
562
)
563
- mock_cache . get_hub_model .assert_called_once_with (
563
+ mock_get_hub_model .assert_called_once_with (
564
564
hub_model_arn = (
565
565
f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{ model_id } /{ version } "
566
566
)
567
567
)
568
568
569
- @patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache " )
570
- def test_retrieve_default_instance_type (self , mock_cache ):
571
- mock_cache . get_hub_model .return_value = JumpStartModelSpecs (spec = copy .deepcopy (BASE_SPEC ))
569
+ @patch ("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model " )
570
+ def test_retrieve_default_instance_type (self , mock_get_hub_model ):
571
+ mock_get_hub_model .return_value = JumpStartModelSpecs (spec = copy .deepcopy (BASE_SPEC ))
572
572
573
573
model_id , version = "pytorch-ic-mobilenet-v2" , "1.0.2"
574
574
hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub"
@@ -582,7 +582,7 @@ def test_retrieve_default_instance_type(self, mock_cache):
582
582
),
583
583
"ml.p3.2xlarge" ,
584
584
)
585
- mock_cache . get_hub_model .assert_called_once_with (
585
+ mock_get_hub_model .assert_called_once_with (
586
586
hub_model_arn = (
587
587
f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{ model_id } /{ version } "
588
588
)
@@ -598,9 +598,9 @@ def test_retrieve_default_instance_type(self, mock_cache):
598
598
"ml.p2.xlarge" ,
599
599
)
600
600
601
- @patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache " )
602
- def test_retrieve_default_training_metric_definitions (self , mock_cache ):
603
- mock_cache . get_hub_model .return_value = JumpStartModelSpecs (spec = copy .deepcopy (BASE_SPEC ))
601
+ @patch ("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model " )
602
+ def test_retrieve_default_training_metric_definitions (self , mock_get_hub_model ):
603
+ mock_get_hub_model .return_value = JumpStartModelSpecs (spec = copy .deepcopy (BASE_SPEC ))
604
604
605
605
model_id , version = "pytorch-ic-mobilenet-v2" , "1.0.2"
606
606
hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub"
@@ -611,15 +611,15 @@ def test_retrieve_default_training_metric_definitions(self, mock_cache):
611
611
),
612
612
[{"Regex" : "val_accuracy: ([0-9\\ .]+)" , "Name" : "pytorch-ic:val-accuracy" }],
613
613
)
614
- mock_cache . get_hub_model .assert_called_once_with (
614
+ mock_get_hub_model .assert_called_once_with (
615
615
hub_model_arn = (
616
616
f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{ model_id } /{ version } "
617
617
)
618
618
)
619
619
620
- @patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache " )
621
- def test_retrieve_model_uri (self , mock_cache ):
622
- mock_cache . get_hub_model .return_value = JumpStartModelSpecs (spec = copy .deepcopy (BASE_SPEC ))
620
+ @patch ("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model " )
621
+ def test_retrieve_model_uri (self , mock_get_hub_model ):
622
+ mock_get_hub_model .return_value = JumpStartModelSpecs (spec = copy .deepcopy (BASE_SPEC ))
623
623
624
624
model_id , version = "pytorch-ic-mobilenet-v2" , "1.0.2"
625
625
hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub"
@@ -630,7 +630,7 @@ def test_retrieve_model_uri(self, mock_cache):
630
630
),
631
631
"s3://jumpstart-cache-prod-us-west-2/pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz" ,
632
632
)
633
- mock_cache . get_hub_model .assert_called_once_with (
633
+ mock_get_hub_model .assert_called_once_with (
634
634
hub_model_arn = (
635
635
f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{ model_id } /{ version } "
636
636
)
@@ -643,9 +643,9 @@ def test_retrieve_model_uri(self, mock_cache):
643
643
"s3://jumpstart-cache-prod-us-west-2/pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz" ,
644
644
)
645
645
646
- @patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache " )
647
- def test_retrieve_script_uri (self , mock_cache ):
648
- mock_cache . get_hub_model .return_value = JumpStartModelSpecs (spec = copy .deepcopy (BASE_SPEC ))
646
+ @patch ("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model " )
647
+ def test_retrieve_script_uri (self , mock_get_hub_model ):
648
+ mock_get_hub_model .return_value = JumpStartModelSpecs (spec = copy .deepcopy (BASE_SPEC ))
649
649
650
650
model_id , version = "pytorch-ic-mobilenet-v2" , "1.0.2"
651
651
hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub"
@@ -660,7 +660,7 @@ def test_retrieve_script_uri(self, mock_cache):
660
660
"s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/"
661
661
"transfer_learning/ic/v1.0.0/sourcedir.tar.gz" ,
662
662
)
663
- mock_cache . get_hub_model .assert_called_once_with (
663
+ mock_get_hub_model .assert_called_once_with (
664
664
hub_model_arn = (
665
665
f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{ model_id } /{ version } "
666
666
)
0 commit comments