11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
13
from __future__ import absolute_import
14
- from importlib import reload
15
14
import unittest
16
15
from unittest .mock import Mock
17
16
21
20
22
21
import copy
23
22
from sagemaker .jumpstart import artifacts
24
- from sagemaker .jumpstart .artifacts .environment_variables import _retrieve_default_environment_variables
23
+ from sagemaker .jumpstart .artifacts .environment_variables import (
24
+ _retrieve_default_environment_variables ,
25
+ )
25
26
from sagemaker .jumpstart .artifacts .hyperparameters import _retrieve_default_hyperparameters
26
27
from sagemaker .jumpstart .artifacts .image_uris import _retrieve_image_uri
27
28
from sagemaker .jumpstart .artifacts .incremental_training import _model_supports_incremental_training
28
29
from sagemaker .jumpstart .artifacts .instance_types import _retrieve_default_instance_type
29
- from sagemaker .jumpstart .artifacts .metric_definitions import _retrieve_default_training_metric_definitions
30
+ from sagemaker .jumpstart .artifacts .metric_definitions import (
31
+ _retrieve_default_training_metric_definitions ,
32
+ )
30
33
from sagemaker .jumpstart .artifacts .model_uris import (
31
34
_retrieve_hosting_prepacked_artifact_key ,
32
35
_retrieve_hosting_artifact_key ,
@@ -467,7 +470,6 @@ def test_retrieve_uri_from_gated_bucket(self, patched_get_model_specs):
467
470
468
471
469
472
class HubModelTest (unittest .TestCase ):
470
-
471
473
@patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache" )
472
474
def test_retrieve_default_environment_variables (self , mock_cache ):
473
475
mock_cache .get_hub_model .return_value = JumpStartModelSpecs (spec = copy .deepcopy (BASE_SPEC ))
@@ -477,7 +479,10 @@ def test_retrieve_default_environment_variables(self, mock_cache):
477
479
478
480
self .assertEqual (
479
481
_retrieve_default_environment_variables (
480
- model_id = model_id , model_version = version , hub_arn = hub_arn , script = JumpStartScriptScope .INFERENCE
482
+ model_id = model_id ,
483
+ model_version = version ,
484
+ hub_arn = hub_arn ,
485
+ script = JumpStartScriptScope .INFERENCE ,
481
486
),
482
487
{
483
488
"SAGEMAKER_PROGRAM" : "inference.py" ,
@@ -487,16 +492,15 @@ def test_retrieve_default_environment_variables(self, mock_cache):
487
492
"ENDPOINT_SERVER_TIMEOUT" : "3600" ,
488
493
"MODEL_CACHE_ROOT" : "/opt/ml/model" ,
489
494
"SAGEMAKER_ENV" : "1" ,
490
- "SAGEMAKER_MODEL_SERVER_WORKERS" : "1"
491
- }
495
+ "SAGEMAKER_MODEL_SERVER_WORKERS" : "1" ,
496
+ },
492
497
)
493
498
mock_cache .get_hub_model .assert_called_once_with (
494
499
hub_model_arn = (
495
500
f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{ model_id } /{ version } "
496
501
)
497
502
)
498
503
499
-
500
504
@patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache" )
501
505
def test_retrieve_image_uri (self , mock_cache ):
502
506
mock_cache .get_hub_model .return_value = JumpStartModelSpecs (spec = copy .deepcopy (BASE_SPEC ))
@@ -506,9 +510,13 @@ def test_retrieve_image_uri(self, mock_cache):
506
510
507
511
self .assertEqual (
508
512
_retrieve_image_uri (
509
- model_id = model_id , model_version = version , hub_arn = hub_arn , instance_type = "ml.p3.2xlarge" , image_scope = JumpStartScriptScope .TRAINING
513
+ model_id = model_id ,
514
+ model_version = version ,
515
+ hub_arn = hub_arn ,
516
+ instance_type = "ml.p3.2xlarge" ,
517
+ image_scope = JumpStartScriptScope .TRAINING ,
510
518
),
511
- "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3"
519
+ "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3" ,
512
520
)
513
521
mock_cache .get_hub_model .assert_called_once_with (
514
522
hub_model_arn = (
@@ -531,7 +539,7 @@ def test_retrieve_default_hyperparameters(self, mock_cache):
531
539
"epochs" : "3" ,
532
540
"adam-learning-rate" : "0.05" ,
533
541
"batch-size" : "4" ,
534
- }
542
+ },
535
543
)
536
544
mock_cache .get_hub_model .assert_called_once_with (
537
545
hub_model_arn = (
@@ -550,7 +558,7 @@ def test_model_supports_incremental_training(self, mock_cache):
550
558
_model_supports_incremental_training (
551
559
model_id = model_id , model_version = version , hub_arn = hub_arn , region = "us-west-2"
552
560
),
553
- True
561
+ True ,
554
562
)
555
563
mock_cache .get_hub_model .assert_called_once_with (
556
564
hub_model_arn = (
@@ -567,9 +575,12 @@ def test_retrieve_default_instance_type(self, mock_cache):
567
575
568
576
self .assertEqual (
569
577
_retrieve_default_instance_type (
570
- model_id = model_id , model_version = version , hub_arn = hub_arn , scope = JumpStartScriptScope .TRAINING
578
+ model_id = model_id ,
579
+ model_version = version ,
580
+ hub_arn = hub_arn ,
581
+ scope = JumpStartScriptScope .TRAINING ,
571
582
),
572
- "ml.p3.2xlarge"
583
+ "ml.p3.2xlarge" ,
573
584
)
574
585
mock_cache .get_hub_model .assert_called_once_with (
575
586
hub_model_arn = (
@@ -579,9 +590,12 @@ def test_retrieve_default_instance_type(self, mock_cache):
579
590
580
591
self .assertEqual (
581
592
_retrieve_default_instance_type (
582
- model_id = model_id , model_version = version , hub_arn = hub_arn , scope = JumpStartScriptScope .INFERENCE
593
+ model_id = model_id ,
594
+ model_version = version ,
595
+ hub_arn = hub_arn ,
596
+ scope = JumpStartScriptScope .INFERENCE ,
583
597
),
584
- "ml.p2.xlarge"
598
+ "ml.p2.xlarge" ,
585
599
)
586
600
587
601
@patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache" )
@@ -595,15 +609,14 @@ def test_retrieve_default_training_metric_definitions(self, mock_cache):
595
609
_retrieve_default_training_metric_definitions (
596
610
model_id = model_id , model_version = version , hub_arn = hub_arn , region = "us-west-2"
597
611
),
598
- [{"Regex" : "val_accuracy: ([0-9\\ .]+)" , "Name" : "pytorch-ic:val-accuracy" }]
612
+ [{"Regex" : "val_accuracy: ([0-9\\ .]+)" , "Name" : "pytorch-ic:val-accuracy" }],
599
613
)
600
614
mock_cache .get_hub_model .assert_called_once_with (
601
615
hub_model_arn = (
602
616
f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{ model_id } /{ version } "
603
617
)
604
618
)
605
619
606
-
607
620
@patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache" )
608
621
def test_retrieve_model_uri (self , mock_cache ):
609
622
mock_cache .get_hub_model .return_value = JumpStartModelSpecs (spec = copy .deepcopy (BASE_SPEC ))
@@ -615,7 +628,7 @@ def test_retrieve_model_uri(self, mock_cache):
615
628
_retrieve_model_uri (
616
629
model_id = model_id , model_version = version , hub_arn = hub_arn , model_scope = "training"
617
630
),
618
- "s3://jumpstart-cache-prod-us-west-2/pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz"
631
+ "s3://jumpstart-cache-prod-us-west-2/pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz" ,
619
632
)
620
633
mock_cache .get_hub_model .assert_called_once_with (
621
634
hub_model_arn = (
@@ -627,7 +640,7 @@ def test_retrieve_model_uri(self, mock_cache):
627
640
_retrieve_model_uri (
628
641
model_id = model_id , model_version = version , hub_arn = hub_arn , model_scope = "inference"
629
642
),
630
- "s3://jumpstart-cache-prod-us-west-2/pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz"
643
+ "s3://jumpstart-cache-prod-us-west-2/pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz" ,
631
644
)
632
645
633
646
@patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache" )
@@ -639,9 +652,13 @@ def test_retrieve_script_uri(self, mock_cache):
639
652
640
653
self .assertEqual (
641
654
_retrieve_script_uri (
642
- model_id = model_id , model_version = version , hub_arn = hub_arn , script_scope = JumpStartScriptScope .TRAINING
655
+ model_id = model_id ,
656
+ model_version = version ,
657
+ hub_arn = hub_arn ,
658
+ script_scope = JumpStartScriptScope .TRAINING ,
643
659
),
644
- "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz"
660
+ "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/"
661
+ "transfer_learning/ic/v1.0.0/sourcedir.tar.gz" ,
645
662
)
646
663
mock_cache .get_hub_model .assert_called_once_with (
647
664
hub_model_arn = (
@@ -651,7 +668,11 @@ def test_retrieve_script_uri(self, mock_cache):
651
668
652
669
self .assertEqual (
653
670
_retrieve_script_uri (
654
- model_id = model_id , model_version = version , hub_arn = hub_arn , script_scope = JumpStartScriptScope .INFERENCE
671
+ model_id = model_id ,
672
+ model_version = version ,
673
+ hub_arn = hub_arn ,
674
+ script_scope = JumpStartScriptScope .INFERENCE ,
655
675
),
656
- "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz"
676
+ "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/"
677
+ "inference/ic/v1.0.0/sourcedir.tar.gz" ,
657
678
)
0 commit comments