@@ -49,6 +49,15 @@ def sagemaker_session():
49
49
session_mock .sagemaker_client .describe_training_job = Mock (
50
50
return_value = {"ModelArtifacts" : {"S3ModelArtifacts" : "s3://m/m.tar.gz" }}
51
51
)
52
+ session_mock .boto_session .client ("ec2" ).describe_instance_types = Mock (
53
+ return_value = {
54
+ "InstanceTypes" : [
55
+ {
56
+ "CpuInfo" : {},
57
+ },
58
+ ],
59
+ }
60
+ )
52
61
return session_mock
53
62
54
63
@@ -733,6 +742,31 @@ def test_validate_smdistributed_not_raises():
733
742
)
734
743
735
744
745
+ def test_validate_distribution_instance_no_smdistributed (sagemaker_session ):
746
+ distribution = {}
747
+ instance_type = "mock_type"
748
+ fw_utils .validate_distribution_instance (sagemaker_session , distribution , instance_type )
749
+
750
+
751
+ def test_validate_distribution_instance_no_modelparallel (sagemaker_session ):
752
+ distribution = {"smdistributed" : {}}
753
+ instance_type = "mock_type"
754
+ fw_utils .validate_distribution_instance (sagemaker_session , distribution , instance_type )
755
+
756
+
757
+ def test_validate_distribution_instance_disabled_modelparallel (sagemaker_session ):
758
+ distribution = {"smdistributed" : {"modelparallel" : {"enabled" : False }}}
759
+ instance_type = "mock_type"
760
+ fw_utils .validate_distribution_instance (sagemaker_session , distribution , instance_type )
761
+
762
+
763
+ def test_validate_distribution_instance_raise (sagemaker_session ):
764
+ distribution = {"smdistributed" : {"modelparallel" : {"enabled" : True }}}
765
+ instance_type = "mock_type"
766
+ with pytest .raises (ValueError ):
767
+ fw_utils .validate_distribution_instance (sagemaker_session , distribution , instance_type )
768
+
769
+
736
770
def test_validate_smdistributed_raises ():
737
771
bad_args = [
738
772
{"smdistributed" : "dummy" },
0 commit comments