-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feature: Add support for native PT DDP distribution #3223
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7985703
66cc822
2af1e60
a99b54e
810a36a
3f98b33
b2cf2b6
9dcbb02
458031d
5d0bef2
86f9667
72bb2d8
46f882c
a21719e
6c28d98
a8016c0
77848f2
a013b71
905b6b2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -103,6 +103,17 @@ | |
"1.11.0", | ||
], | ||
} | ||
|
||
# PyTorch framework versions that pass the `pytorchddp` validation when no
# custom image URI is supplied. Grouped by minor release for readability.
PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS = [
    # 1.10.x
    "1.10",
    "1.10.0",
    "1.10.2",
    # 1.11.x
    "1.11",
    "1.11.0",
    # 1.12.x
    "1.12",
    "1.12.0",
]
|
||
SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"] | ||
|
||
|
||
|
@@ -728,6 +739,12 @@ def validate_distribution( | |
distribution=distribution, | ||
image_uri=image_uri, | ||
) | ||
validate_pytorch_distribution( | ||
framework_name=framework_name, | ||
framework_version=framework_version, | ||
py_version=py_version, | ||
image_uri=image_uri, | ||
) | ||
warn_if_parameter_server_with_multi_gpu( | ||
training_instance_type=instance_type, distribution=distribution | ||
) | ||
|
@@ -747,12 +764,56 @@ def validate_distribution( | |
distribution=distribution, | ||
image_uri=image_uri, | ||
) | ||
validate_pytorch_distribution( | ||
framework_name=framework_name, | ||
framework_version=framework_version, | ||
py_version=py_version, | ||
image_uri=image_uri, | ||
) | ||
warn_if_parameter_server_with_multi_gpu( | ||
training_instance_type=instance_type, distribution=distribution | ||
) | ||
return distribution | ||
|
||
|
||
def validate_pytorch_distribution(framework_name, framework_version, py_version, image_uri):
    """Check if pytorch distribution strategy is correctly invoked by the user.

    The check is a no-op for every framework other than PyTorch. It is also
    skipped when ``image_uri`` is set, because ``framework_version`` and
    ``py_version`` are ignored in that case.

    Args:
        framework_name (str): A string representing the name of framework selected.
        framework_version (str): A string representing the framework version selected.
        py_version (str): A string representing the python version selected.
        image_uri (str): A string representing a Docker image URI.

    Raises:
        ValueError: if ``image_uri`` is not set and
            `py_version` is missing or is not python3 or
            `framework_version` is not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS
    """
    if framework_name != "pytorch":
        # We need to validate only for PyTorch framework
        return

    if image_uri:
        # ignore framework_version and py_version if image_uri is set
        # NOTE(review): a custom image therefore bypasses every check below,
        # even if it ships an old toolkit. Per the review discussion this
        # mirrors the existing smdistributed validation - confirm this is
        # the intended trade-off.
        return

    # in case image_uri is not set, then both are mandatory
    err_msg = ""
    if framework_version not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS:
        err_msg += (
            f"Provided framework_version {framework_version} is not supported by"
            " pytorchddp.\n"
            "Please specify one of the supported framework versions:"
            f" {PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS} \n"
        )
    # Guard against a missing py_version (None or empty): the substring test
    # on None would raise TypeError instead of the documented ValueError.
    if not py_version or "py3" not in py_version:
        err_msg += (
            f"Provided py_version {py_version} is not supported by pytorchddp.\n"
            "Please specify py_version>=py3"
        )
    if err_msg:
        raise ValueError(err_msg)
|
||
|
||
def python_deprecation_warning(framework, latest_supported_version): | ||
"""Placeholder docstring""" | ||
return PYTHON_2_DEPRECATION_WARNING.format( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,6 +38,8 @@ class PyTorch(Framework): | |
"""Handle end-to-end training and deployment of custom PyTorch code.""" | ||
|
||
_framework_name = "pytorch" | ||
LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled" | ||
INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type" | ||
|
||
def __init__( | ||
self, | ||
|
@@ -153,6 +155,19 @@ def __init__( | |
To find a complete list of parameters for SageMaker model parallelism, | ||
see :ref:`sm-sdk-modelparallel-general`. | ||
|
||
**To enable PyTorch DDP:** | ||
|
||
.. code:: python | ||
|
||
{ | ||
"pytorchddp": { | ||
"enabled": True | ||
} | ||
} | ||
|
||
To learn more, see `Distributed PyTorch Training | ||
<https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training>`_. | ||
|
||
**To enable MPI:** | ||
|
||
.. code:: python | ||
|
@@ -217,10 +232,32 @@ def __init__( | |
|
||
self.distribution = distribution or {} | ||
|
||
def _pytorch_distribution_configuration(self, distribution):
    """Returns a dict of distribution config for PyTorch training

    When ``distribution["pytorchddp"]["enabled"]`` is truthy, the result holds
    the PyTorch DDP launch flag and, if ``self.instance_type`` is set, the
    instance type. Otherwise the call is delegated to
    ``self._distribution_configuration``.

    Args:
        distribution (dict): A dictionary with information on how to run distributed training.
    Returns:
        dict containing Pytorch DDP config
    """
    # Tolerate a missing distribution or a ``None`` "pytorchddp" entry rather
    # than failing on the chained lookup.
    pytorch_ddp_options = (distribution or {}).get("pytorchddp") or {}
    pytorch_ddp_enabled = pytorch_ddp_options.get("enabled", False)

    if not pytorch_ddp_enabled:
        # DDP not requested: fall back to the generic configuration.
        return self._distribution_configuration(distribution=distribution)

    distribution_config = {self.LAUNCH_PYTORCH_DDP_ENV_NAME: pytorch_ddp_enabled}
    if self.instance_type is not None:
        distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type
    return distribution_config
|
||
def hyperparameters(self): | ||
"""Return hyperparameters used by your custom PyTorch code during model training.""" | ||
hyperparameters = super(PyTorch, self).hyperparameters() | ||
additional_hyperparameters = self._distribution_configuration( | ||
additional_hyperparameters = self._pytorch_distribution_configuration( | ||
distribution=self.distribution | ||
) | ||
hyperparameters.update( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't see 1.10.2 here. Also, do these versions all essentially use the same DLC Docker image (the 1.10.2 one)? We only have one DLC Dockerfile for each minor release.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding 1.10.2.
Yeah, seems like there is only one DLC docker file for 1.10 (which is for 1.10.2): https://github.com/aws/deep-learning-containers/blob/master/available_images.md#prior-sagemaker-framework-container-versions