@@ -191,6 +191,16 @@ def sagemaker_session():
191
191
return sms
192
192
193
193
194
+ @pytest .fixture ()
195
+ def training_job_description (sagemaker_session ):
196
+ returned_job_description = RETURNED_JOB_DESCRIPTION .copy ()
197
+ mock_describe_training_job = Mock (
198
+ name = "describe_training_job" , return_value = returned_job_description
199
+ )
200
+ sagemaker_session .sagemaker_client .describe_training_job = mock_describe_training_job
201
+ return returned_job_description
202
+
203
+
194
204
def test_framework_all_init_args (sagemaker_session ):
195
205
f = DummyFramework (
196
206
"my_script.py" ,
@@ -651,13 +661,9 @@ def test_enable_cloudwatch_metrics(sagemaker_session):
651
661
assert train_kwargs ["hyperparameters" ]["sagemaker_enable_cloudwatch_metrics" ]
652
662
653
663
654
- def test_attach_framework (sagemaker_session ):
655
- returned_job_description = RETURNED_JOB_DESCRIPTION .copy ()
656
- returned_job_description ["VpcConfig" ] = {"Subnets" : ["foo" ], "SecurityGroupIds" : ["bar" ]}
657
- returned_job_description ["EnableNetworkIsolation" ] = True
658
- sagemaker_session .sagemaker_client .describe_training_job = Mock (
659
- name = "describe_training_job" , return_value = returned_job_description
660
- )
664
+ def test_attach_framework (sagemaker_session , training_job_description ):
665
+ training_job_description ["VpcConfig" ] = {"Subnets" : ["foo" ], "SecurityGroupIds" : ["bar" ]}
666
+ training_job_description ["EnableNetworkIsolation" ] = True
661
667
662
668
framework_estimator = DummyFramework .attach (
663
669
training_job_name = "neo" , sagemaker_session = sagemaker_session
@@ -681,50 +687,25 @@ def test_attach_framework(sagemaker_session):
681
687
assert framework_estimator .enable_network_isolation () is True
682
688
683
689
684
- def test_attach_no_logs (sagemaker_session ):
685
- returned_job_description = RETURNED_JOB_DESCRIPTION .copy ()
686
- mock_describe_training_job = Mock (
687
- name = "describe_training_job" , return_value = returned_job_description
688
- )
689
- sagemaker_session .sagemaker_client .describe_training_job = mock_describe_training_job
690
+ def test_attach_no_logs (sagemaker_session , training_job_description ):
690
691
Estimator .attach (training_job_name = "job" , sagemaker_session = sagemaker_session )
691
692
sagemaker_session .logs_for_job .assert_not_called ()
692
693
693
694
694
- def test_logs (sagemaker_session ):
695
- returned_job_description = RETURNED_JOB_DESCRIPTION .copy ()
696
- mock_describe_training_job = Mock (
697
- name = "describe_training_job" , return_value = returned_job_description
698
- )
699
- sagemaker_session .sagemaker_client .describe_training_job = mock_describe_training_job
695
+ def test_logs (sagemaker_session , training_job_description ):
700
696
estimator = Estimator .attach (training_job_name = "job" , sagemaker_session = sagemaker_session )
701
697
estimator .logs ()
702
698
sagemaker_session .logs_for_job .assert_called_with (estimator .latest_training_job , wait = True )
703
699
704
700
705
- def test_attach_without_hyperparameters (sagemaker_session ):
706
- returned_job_description = RETURNED_JOB_DESCRIPTION .copy ()
707
- del returned_job_description ["HyperParameters" ]
708
-
709
- mock_describe_training_job = Mock (
710
- name = "describe_training_job" , return_value = returned_job_description
711
- )
712
- sagemaker_session .sagemaker_client .describe_training_job = mock_describe_training_job
713
-
701
+ def test_attach_without_hyperparameters (sagemaker_session , training_job_description ):
702
+ del training_job_description ["HyperParameters" ]
714
703
estimator = Estimator .attach (training_job_name = "job" , sagemaker_session = sagemaker_session )
715
-
716
704
assert estimator .hyperparameters () == {}
717
705
718
706
719
- def test_attach_framework_with_tuning (sagemaker_session ):
720
- returned_job_description = RETURNED_JOB_DESCRIPTION .copy ()
721
- returned_job_description ["HyperParameters" ]["_tuning_objective_metric" ] = "Validation-accuracy"
722
-
723
- mock_describe_training_job = Mock (
724
- name = "describe_training_job" , return_value = returned_job_description
725
- )
726
- sagemaker_session .sagemaker_client .describe_training_job = mock_describe_training_job
727
-
707
+ def test_attach_framework_with_tuning (sagemaker_session , training_job_description ):
708
+ training_job_description ["HyperParameters" ]["_tuning_objective_metric" ] = "Validation-accuracy"
728
709
framework_estimator = DummyFramework .attach (
729
710
training_job_name = "neo" , sagemaker_session = sagemaker_session
730
711
)
@@ -744,48 +725,35 @@ def test_attach_framework_with_tuning(sagemaker_session):
744
725
assert framework_estimator .encrypt_inter_container_traffic is False
745
726
746
727
747
- def test_attach_framework_with_model_channel (sagemaker_session ):
728
+ def test_attach_framework_with_model_channel (sagemaker_session , training_job_description ):
748
729
s3_uri = "s3://some/s3/path/model.tar.gz"
749
- returned_job_description = RETURNED_JOB_DESCRIPTION .copy ()
750
- returned_job_description ["InputDataConfig" ] = [
730
+ training_job_description ["InputDataConfig" ] = [
751
731
{
752
732
"ChannelName" : "model" ,
753
733
"InputMode" : "File" ,
754
734
"DataSource" : {"S3DataSource" : {"S3Uri" : s3_uri }},
755
735
}
756
736
]
757
737
758
- sagemaker_session .sagemaker_client .describe_training_job = Mock (
759
- name = "describe_training_job" , return_value = returned_job_description
760
- )
761
-
762
738
framework_estimator = DummyFramework .attach (
763
739
training_job_name = "neo" , sagemaker_session = sagemaker_session
764
740
)
765
741
assert framework_estimator .model_uri is s3_uri
766
742
assert framework_estimator .encrypt_inter_container_traffic is False
767
743
768
744
769
- def test_attach_framework_with_inter_container_traffic_encryption_flag (sagemaker_session ):
770
- returned_job_description = RETURNED_JOB_DESCRIPTION .copy ()
771
- returned_job_description ["EnableInterContainerTrafficEncryption" ] = True
772
-
773
- sagemaker_session .sagemaker_client .describe_training_job = Mock (
774
- name = "describe_training_job" , return_value = returned_job_description
775
- )
776
-
745
+ def test_attach_framework_with_inter_container_traffic_encryption_flag (
746
+ sagemaker_session , training_job_description
747
+ ):
748
+ training_job_description ["EnableInterContainerTrafficEncryption" ] = True
777
749
framework_estimator = DummyFramework .attach (
778
750
training_job_name = "neo" , sagemaker_session = sagemaker_session
779
751
)
780
752
781
753
assert framework_estimator .encrypt_inter_container_traffic is True
782
754
783
755
784
- def test_attach_framework_base_from_generated_name (sagemaker_session ):
785
- sagemaker_session .sagemaker_client .describe_training_job = Mock (
786
- name = "describe_training_job" , return_value = RETURNED_JOB_DESCRIPTION
787
- )
788
-
756
+ def test_attach_framework_base_from_generated_name (sagemaker_session , training_job_description ):
789
757
base_job_name = "neo"
790
758
framework_estimator = DummyFramework .attach (
791
759
training_job_name = utils .name_from_base ("neo" ), sagemaker_session = sagemaker_session
0 commit comments