Skip to content

Commit f51188d

Browse files
tomboussojiapinw
authored and committed
Convert pytorchddp distribution to smdistributed distribution (aws#4698)
* rewrite pytorchddp to smdistributed
* remove instance type check
* Update estimator.py
* remove validate_pytorch_distribution
* fix
* fix unit tests
* fix formatting
* check instance type not None
1 parent 0cf3191 commit f51188d

File tree

4 files changed

+21
-172
lines changed

4 files changed

+21
-172
lines changed

src/sagemaker/fw_utils.py

+2-97
Original file line number | Diff line number | Diff line change
@@ -145,22 +145,6 @@
145145
],
146146
}
147147

148-
PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS = [
149-
"1.10",
150-
"1.10.0",
151-
"1.10.2",
152-
"1.11",
153-
"1.11.0",
154-
"1.12",
155-
"1.12.0",
156-
"1.12.1",
157-
"1.13.1",
158-
"2.0.0",
159-
"2.0.1",
160-
"2.1.0",
161-
"2.2.0",
162-
]
163-
164148
TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = [
165149
"1.13.1",
166150
"2.0.0",
@@ -795,7 +779,6 @@ def _validate_smdataparallel_args(
795779
796780
Raises:
797781
ValueError: if
798-
(`instance_type` is not in SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES or
799782
`py_version` is not python3 or
800783
`framework_version` is not in SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSION
801784
"""
@@ -806,17 +789,10 @@ def _validate_smdataparallel_args(
806789
if not smdataparallel_enabled:
807790
return
808791

809-
is_instance_type_supported = instance_type in SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES
810-
811792
err_msg = ""
812793

813-
if not is_instance_type_supported:
814-
# instance_type is required
815-
err_msg += (
816-
f"Provided instance_type {instance_type} is not supported by smdataparallel.\n"
817-
"Please specify one of the supported instance types:"
818-
f"{SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES}\n"
819-
)
794+
if not instance_type:
795+
err_msg += "Please specify an instance_type for smdataparallel.\n"
820796

821797
if not image_uri:
822798
# ignore framework_version & py_version if image_uri is set
@@ -928,13 +904,6 @@ def validate_distribution(
928904
)
929905
if framework_name and framework_name == "pytorch":
930906
# We need to validate only for PyTorch framework
931-
validate_pytorch_distribution(
932-
distribution=validated_distribution,
933-
framework_name=framework_name,
934-
framework_version=framework_version,
935-
py_version=py_version,
936-
image_uri=image_uri,
937-
)
938907
validate_torch_distributed_distribution(
939908
instance_type=instance_type,
940909
distribution=validated_distribution,
@@ -968,13 +937,6 @@ def validate_distribution(
968937
)
969938
if framework_name and framework_name == "pytorch":
970939
# We need to validate only for PyTorch framework
971-
validate_pytorch_distribution(
972-
distribution=validated_distribution,
973-
framework_name=framework_name,
974-
framework_version=framework_version,
975-
py_version=py_version,
976-
image_uri=image_uri,
977-
)
978940
validate_torch_distributed_distribution(
979941
instance_type=instance_type,
980942
distribution=validated_distribution,
@@ -1023,63 +985,6 @@ def validate_distribution_for_instance_type(instance_type, distribution):
1023985
raise ValueError(err_msg)
1024986

1025987

1026-
def validate_pytorch_distribution(
1027-
distribution, framework_name, framework_version, py_version, image_uri
1028-
):
1029-
"""Check if pytorch distribution strategy is correctly invoked by the user.
1030-
1031-
Args:
1032-
distribution (dict): A dictionary with information to enable distributed training.
1033-
(Defaults to None if distributed training is not enabled.) For example:
1034-
1035-
.. code:: python
1036-
1037-
{
1038-
"pytorchddp": {
1039-
"enabled": True
1040-
}
1041-
}
1042-
framework_name (str): A string representing the name of framework selected.
1043-
framework_version (str): A string representing the framework version selected.
1044-
py_version (str): A string representing the python version selected.
1045-
image_uri (str): A string representing a Docker image URI.
1046-
1047-
Raises:
1048-
ValueError: if
1049-
`py_version` is not python3 or
1050-
`framework_version` is not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS
1051-
"""
1052-
if framework_name and framework_name != "pytorch":
1053-
# We need to validate only for PyTorch framework
1054-
return
1055-
1056-
pytorch_ddp_enabled = False
1057-
if "pytorchddp" in distribution:
1058-
pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False)
1059-
if not pytorch_ddp_enabled:
1060-
# Distribution strategy other than pytorchddp is selected
1061-
return
1062-
1063-
err_msg = ""
1064-
if not image_uri:
1065-
# ignore framework_version and py_version if image_uri is set
1066-
# in case image_uri is not set, then both are mandatory
1067-
if framework_version not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS:
1068-
err_msg += (
1069-
f"Provided framework_version {framework_version} is not supported by"
1070-
" pytorchddp.\n"
1071-
"Please specify one of the supported framework versions:"
1072-
f" {PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS} \n"
1073-
)
1074-
if "py3" not in py_version:
1075-
err_msg += (
1076-
f"Provided py_version {py_version} is not supported by pytorchddp.\n"
1077-
"Please specify py_version>=py3"
1078-
)
1079-
if err_msg:
1080-
raise ValueError(err_msg)
1081-
1082-
1083988
def validate_torch_distributed_distribution(
1084989
instance_type,
1085990
distribution,

src/sagemaker/pytorch/estimator.py

+14
Original file line number | Diff line number | Diff line change
@@ -276,6 +276,20 @@ def __init__(
276276
kwargs["entry_point"] = entry_point
277277

278278
if distribution is not None:
279+
# rewrite pytorchddp to smdistributed
280+
if "pytorchddp" in distribution:
281+
if "smdistributed" in distribution:
282+
raise ValueError(
283+
"Cannot use both pytorchddp and smdistributed "
284+
"distribution options together.",
285+
distribution,
286+
)
287+
288+
# convert pytorchddp distribution into smdistributed distribution
289+
distribution = distribution.copy()
290+
distribution["smdistributed"] = {"dataparallel": distribution["pytorchddp"]}
291+
del distribution["pytorchddp"]
292+
279293
distribution = validate_distribution(
280294
distribution,
281295
self.instance_groups,

tests/unit/test_fw_utils.py

+2-73
Original file line number | Diff line number | Diff line change
@@ -854,17 +854,14 @@ def test_validate_smdataparallel_args_raises():
854854

855855
# Cases {PT|TF2}
856856
# 1. None instance type
857-
# 2. incorrect instance type
858-
# 3. incorrect python version
859-
# 4. incorrect framework version
857+
# 2. incorrect python version
858+
# 3. incorrect framework version
860859

861860
bad_args = [
862861
(None, "tensorflow", "2.3.1", "py3", smdataparallel_enabled),
863-
("ml.p3.2xlarge", "tensorflow", "2.3.1", "py3", smdataparallel_enabled),
864862
("ml.p3dn.24xlarge", "tensorflow", "2.3.1", "py2", smdataparallel_enabled),
865863
("ml.p3.16xlarge", "tensorflow", "1.3.1", "py3", smdataparallel_enabled),
866864
(None, "pytorch", "1.6.0", "py3", smdataparallel_enabled),
867-
("ml.p3.2xlarge", "pytorch", "1.6.0", "py3", smdataparallel_enabled),
868865
("ml.p3dn.24xlarge", "pytorch", "1.6.0", "py2", smdataparallel_enabled),
869866
("ml.p3.16xlarge", "pytorch", "1.5.0", "py3", smdataparallel_enabled),
870867
]
@@ -966,74 +963,6 @@ def test_validate_smdataparallel_args_not_raises():
966963
)
967964

968965

969-
def test_validate_pytorchddp_not_raises():
970-
# Case 1: Framework is not PyTorch
971-
fw_utils.validate_pytorch_distribution(
972-
distribution=None,
973-
framework_name="tensorflow",
974-
framework_version="2.9.1",
975-
py_version="py3",
976-
image_uri="custom-container",
977-
)
978-
# Case 2: Framework is PyTorch, but distribution is not PyTorchDDP
979-
pytorchddp_disabled = {"pytorchddp": {"enabled": False}}
980-
fw_utils.validate_pytorch_distribution(
981-
distribution=pytorchddp_disabled,
982-
framework_name="pytorch",
983-
framework_version="1.10",
984-
py_version="py3",
985-
image_uri="custom-container",
986-
)
987-
# Case 3: Framework is PyTorch, Distribution is PyTorchDDP enabled, supported framework and py versions
988-
pytorchddp_enabled = {"pytorchddp": {"enabled": True}}
989-
pytorchddp_supported_fw_versions = [
990-
"1.10",
991-
"1.10.0",
992-
"1.10.2",
993-
"1.11",
994-
"1.11.0",
995-
"1.12",
996-
"1.12.0",
997-
"1.12.1",
998-
"1.13.1",
999-
"2.0.0",
1000-
"2.0.1",
1001-
"2.1.0",
1002-
"2.2.0",
1003-
]
1004-
for framework_version in pytorchddp_supported_fw_versions:
1005-
fw_utils.validate_pytorch_distribution(
1006-
distribution=pytorchddp_enabled,
1007-
framework_name="pytorch",
1008-
framework_version=framework_version,
1009-
py_version="py3",
1010-
image_uri="custom-container",
1011-
)
1012-
1013-
1014-
def test_validate_pytorchddp_raises():
1015-
pytorchddp_enabled = {"pytorchddp": {"enabled": True}}
1016-
# Case 1: Unsupported framework version
1017-
with pytest.raises(ValueError):
1018-
fw_utils.validate_pytorch_distribution(
1019-
distribution=pytorchddp_enabled,
1020-
framework_name="pytorch",
1021-
framework_version="1.8",
1022-
py_version="py3",
1023-
image_uri=None,
1024-
)
1025-
1026-
# Case 2: Unsupported Py version
1027-
with pytest.raises(ValueError):
1028-
fw_utils.validate_pytorch_distribution(
1029-
distribution=pytorchddp_enabled,
1030-
framework_name="pytorch",
1031-
framework_version="1.10",
1032-
py_version="py2",
1033-
image_uri=None,
1034-
)
1035-
1036-
1037966
def test_validate_torch_distributed_not_raises():
1038967
# Case 1: Framework is PyTorch, but torch_distributed is not enabled
1039968
torch_distributed_disabled = {"torch_distributed": {"enabled": False}}

tests/unit/test_pytorch.py

+3-2
Original file line number | Diff line number | Diff line change
@@ -801,14 +801,15 @@ def test_pytorch_ddp_distribution_configuration(
801801
distribution=pytorch.distribution
802802
)
803803
expected_torch_ddp = {
804-
"sagemaker_pytorch_ddp_enabled": True,
804+
"sagemaker_distributed_dataparallel_enabled": True,
805+
"sagemaker_distributed_dataparallel_custom_mpi_options": "",
805806
"sagemaker_instance_type": test_instance_type,
806807
}
807808
assert actual_pytorch_ddp == expected_torch_ddp
808809

809810

810811
def test_pytorch_ddp_distribution_configuration_unsupported(sagemaker_session):
811-
unsupported_framework_version = "1.9.1"
812+
unsupported_framework_version = "1.5.0"
812813
unsupported_py_version = "py2"
813814
with pytest.raises(ValueError) as error:
814815
_pytorch_estimator(

0 commit comments

Comments
 (0)