Skip to content

Commit 17d7948

Browse files
committed
fix unit tests
1 parent 00563d1 commit 17d7948

File tree

2 files changed

+5
-78
lines changed

2 files changed

+5
-78
lines changed

tests/unit/test_fw_utils.py

+2-76
Original file line numberDiff line numberDiff line change
@@ -853,18 +853,12 @@ def test_validate_smdataparallel_args_raises():
853853
smdataparallel_enabled = {"smdistributed": {"dataparallel": {"enabled": True}}}
854854

855855
# Cases {PT|TF2}
856-
# 1. None instance type
857-
# 2. incorrect instance type
858-
# 3. incorrect python version
859-
# 4. incorrect framework version
856+
# 1. incorrect python version
857+
# 2. incorrect framework version
860858

861859
bad_args = [
862-
(None, "tensorflow", "2.3.1", "py3", smdataparallel_enabled),
863-
("ml.p3.2xlarge", "tensorflow", "2.3.1", "py3", smdataparallel_enabled),
864860
("ml.p3dn.24xlarge", "tensorflow", "2.3.1", "py2", smdataparallel_enabled),
865861
("ml.p3.16xlarge", "tensorflow", "1.3.1", "py3", smdataparallel_enabled),
866-
(None, "pytorch", "1.6.0", "py3", smdataparallel_enabled),
867-
("ml.p3.2xlarge", "pytorch", "1.6.0", "py3", smdataparallel_enabled),
868862
("ml.p3dn.24xlarge", "pytorch", "1.6.0", "py2", smdataparallel_enabled),
869863
("ml.p3.16xlarge", "pytorch", "1.5.0", "py3", smdataparallel_enabled),
870864
]
@@ -966,74 +960,6 @@ def test_validate_smdataparallel_args_not_raises():
966960
)
967961

968962

969-
def test_validate_pytorchddp_not_raises():
970-
# Case 1: Framework is not PyTorch
971-
fw_utils.validate_pytorch_distribution(
972-
distribution=None,
973-
framework_name="tensorflow",
974-
framework_version="2.9.1",
975-
py_version="py3",
976-
image_uri="custom-container",
977-
)
978-
# Case 2: Framework is PyTorch, but distribution is not PyTorchDDP
979-
pytorchddp_disabled = {"pytorchddp": {"enabled": False}}
980-
fw_utils.validate_pytorch_distribution(
981-
distribution=pytorchddp_disabled,
982-
framework_name="pytorch",
983-
framework_version="1.10",
984-
py_version="py3",
985-
image_uri="custom-container",
986-
)
987-
# Case 3: Framework is PyTorch, Distribution is PyTorchDDP enabled, supported framework and py versions
988-
pytorchddp_enabled = {"pytorchddp": {"enabled": True}}
989-
pytorchddp_supported_fw_versions = [
990-
"1.10",
991-
"1.10.0",
992-
"1.10.2",
993-
"1.11",
994-
"1.11.0",
995-
"1.12",
996-
"1.12.0",
997-
"1.12.1",
998-
"1.13.1",
999-
"2.0.0",
1000-
"2.0.1",
1001-
"2.1.0",
1002-
"2.2.0",
1003-
]
1004-
for framework_version in pytorchddp_supported_fw_versions:
1005-
fw_utils.validate_pytorch_distribution(
1006-
distribution=pytorchddp_enabled,
1007-
framework_name="pytorch",
1008-
framework_version=framework_version,
1009-
py_version="py3",
1010-
image_uri="custom-container",
1011-
)
1012-
1013-
1014-
def test_validate_pytorchddp_raises():
1015-
pytorchddp_enabled = {"pytorchddp": {"enabled": True}}
1016-
# Case 1: Unsupported framework version
1017-
with pytest.raises(ValueError):
1018-
fw_utils.validate_pytorch_distribution(
1019-
distribution=pytorchddp_enabled,
1020-
framework_name="pytorch",
1021-
framework_version="1.8",
1022-
py_version="py3",
1023-
image_uri=None,
1024-
)
1025-
1026-
# Case 2: Unsupported Py version
1027-
with pytest.raises(ValueError):
1028-
fw_utils.validate_pytorch_distribution(
1029-
distribution=pytorchddp_enabled,
1030-
framework_name="pytorch",
1031-
framework_version="1.10",
1032-
py_version="py2",
1033-
image_uri=None,
1034-
)
1035-
1036-
1037963
def test_validate_torch_distributed_not_raises():
1038964
# Case 1: Framework is PyTorch, but torch_distributed is not enabled
1039965
torch_distributed_disabled = {"torch_distributed": {"enabled": False}}

tests/unit/test_pytorch.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -801,14 +801,15 @@ def test_pytorch_ddp_distribution_configuration(
801801
distribution=pytorch.distribution
802802
)
803803
expected_torch_ddp = {
804-
"sagemaker_pytorch_ddp_enabled": True,
804+
"sagemaker_distributed_dataparallel_enabled": True,
805+
"sagemaker_distributed_dataparallel_custom_mpi_options": "",
805806
"sagemaker_instance_type": test_instance_type,
806807
}
807808
assert actual_pytorch_ddp == expected_torch_ddp
808809

809810

810811
def test_pytorch_ddp_distribution_configuration_unsupported(sagemaker_session):
811-
unsupported_framework_version = "1.9.1"
812+
unsupported_framework_version = "1.5.0"
812813
unsupported_py_version = "py2"
813814
with pytest.raises(ValueError) as error:
814815
_pytorch_estimator(

0 commit comments

Comments
 (0)