diff --git a/doc/frameworks/pytorch/using_pytorch.rst b/doc/frameworks/pytorch/using_pytorch.rst
index c1efdcbf20..6281de2635 100644
--- a/doc/frameworks/pytorch/using_pytorch.rst
+++ b/doc/frameworks/pytorch/using_pytorch.rst
@@ -262,7 +262,8 @@ during the PyTorch DDP initialization.
 
 .. note::
 
-    The SageMaker PyTorch estimator doesn’t use ``torchrun`` for distributed training.
+    The SageMaker PyTorch estimator uses ``mpirun`` under the hood to launch
+    distributed training; it doesn’t use ``torchrun``.
 
     For more information about setting up PyTorch DDP in your training script,
     see `Getting Started with Distributed Data Parallel
@@ -292,7 +293,135 @@ using two ``ml.p4d.24xlarge`` instances:
 
     pt_estimator.fit("s3://bucket/path/to/training/data")
 
+.. _distributed-pytorch-training-on-trainium:
+
+Distributed PyTorch Training on Trainium
+========================================
+
+SageMaker Training on Trainium instances now supports PyTorch's ``xla``
+backend through ``torchrun``. With this, you do not need to manually pass
+``RANK``, ``WORLD_SIZE``, ``MASTER_ADDR``, and ``MASTER_PORT``. You can launch the
+training job using the :class:`sagemaker.pytorch.estimator.PyTorch` estimator class
+with the ``torch_distributed`` option as the distribution strategy.
+
+.. note::
+
+    This ``torch_distributed`` support is available
+    in the SageMaker Trainium (trn1) PyTorch Deep Learning Containers starting with PyTorch v1.11.0.
+    To find a complete list of supported versions of PyTorch Neuron, see `Neuron Containers <https://github.com/aws/deep-learning-containers/blob/master/available_images.md#neuron-containers>`_ in the *AWS Deep Learning Containers GitHub repository*.
+
+    SageMaker Debugger and Profiler are currently not supported with Trainium instances.
+
+Adapt Your Training Script to Initialize with the XLA Backend
+-------------------------------------------------------------
+
+To initialize distributed training in your script, call
+`torch.distributed.init_process_group
+<https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group>`_
+with the ``xla`` backend as shown below.
+
+.. code:: python
+
+    import torch.distributed as dist
+
+    dist.init_process_group('xla')
+
+SageMaker sets ``MASTER_ADDR`` and ``MASTER_PORT`` for you via ``torchrun``; you do not need to set them in your training script.
+
+For detailed documentation about modifying your training script for Trainium, see the *Multi-worker data-parallel MLP training using torchrun* tutorial in the *AWS Neuron Documentation*.
+
+**Currently supported backends:**
+
+- ``xla`` for Trainium (Trn1) instances
+
+For up-to-date information on supported backends for Trainium instances, see the `AWS Neuron Documentation <https://awsdocs-neuron.readthedocs-hosted.com/en/latest/>`_.
+
+Launching a Distributed Training Job on Trainium
+------------------------------------------------
+
+You can run multi-node distributed PyTorch training jobs on Trainium instances using the
+:class:`sagemaker.pytorch.estimator.PyTorch` estimator class.
+With ``instance_count=1``, the estimator submits a
+single-node training job to SageMaker; with ``instance_count`` greater
+than one, a multi-node training job is launched.
+
+With the ``torch_distributed`` option, the SageMaker PyTorch estimator runs a SageMaker
+training container for PyTorch Neuron, sets up the environment, and launches
+the training job using the ``torchrun`` command on each worker with the given information.
+
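+Putting the pieces together, the following sketch shows the overall structure of
+a training script that ``torchrun`` can launch on a Trainium instance. It is an
+illustrative skeleton only: the model and data are placeholders, and it assumes
+the ``torch_xla`` package that ships with the PyTorch Neuron containers.
+
+.. code:: python
+
+    import torch
+    import torch.distributed as dist
+
+    # XLA imports: register the XLA device and the ``xla`` process-group backend
+    import torch_xla.core.xla_model as xm
+    import torch_xla.distributed.xla_backend  # noqa: F401
+
+    # torchrun (launched by SageMaker) supplies RANK, WORLD_SIZE, MASTER_ADDR,
+    # and MASTER_PORT, so only the backend name is needed here.
+    dist.init_process_group("xla")
+
+    device = xm.xla_device()
+    model = torch.nn.Linear(10, 1).to(device)  # placeholder model
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+
+    for step in range(10):
+        optimizer.zero_grad()
+        data = torch.randn(8, 10).to(device)  # placeholder batch
+        loss = model(data).sum()
+        loss.backward()
+        # All-reduce gradients, apply the optimizer step, and execute the graph
+        xm.optimizer_step(optimizer, barrier=True)
+
+**Examples**
+
+The following examples show how to run a PyTorch training job using ``torch_distributed`` in SageMaker
+on one ``ml.trn1.2xlarge`` instance and two ``ml.trn1.32xlarge`` instances:
+
+.. 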
code:: python + + from sagemaker.pytorch import PyTorch + + pt_estimator = PyTorch( + entry_point="train_ptddp.py", + role="SageMakerRole", + framework_version="1.11.0", + py_version="py38", + instance_count=1, + instance_type="ml.trn1.2xlarge", + distribution={ + "torch_distributed": { + "enabled": True + } + } + ) + + pt_estimator.fit("s3://bucket/path/to/training/data") + +.. code:: python + + from sagemaker.pytorch import PyTorch + + pt_estimator = PyTorch( + entry_point="train_ptddp.py", + role="SageMakerRole", + framework_version="1.11.0", + py_version="py38", + instance_count=2, + instance_type="ml.trn1.32xlarge", + distribution={ + "torch_distributed": { + "enabled": True + } + } + ) + + pt_estimator.fit("s3://bucket/path/to/training/data") ********************* Deploy PyTorch Models diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index c6414ffd62..dacd0a229c 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -134,6 +134,10 @@ "1.12.0", ] +TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS = ["1.11", "1.11.0"] + +TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"] + SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"] @@ -701,7 +705,13 @@ def _validate_smdataparallel_args( def validate_distribution( - distribution, instance_groups, framework_name, framework_version, py_version, image_uri, kwargs + distribution, + instance_groups, + framework_name, + framework_version, + py_version, + image_uri, + kwargs, ): """Check if distribution strategy is correctly invoked by the user. @@ -767,21 +777,35 @@ def validate_distribution( f"Invalid training instance group {train_instance_group.instance_group_name} !" ) instance_type = train_instance_group.instance_type - validate_smdistributed( + validate_distribution_for_instance_type( instance_type=instance_type, - framework_name=framework_name, - framework_version=framework_version, - py_version=py_version, distribution=distribution, - image_uri=image_uri, ) - validate_pytorch_distribution( - distribution=distribution, + validate_smdistributed( + instance_type=instance_type, framework_name=framework_name, framework_version=framework_version, py_version=py_version, + distribution=distribution, image_uri=image_uri, ) + if framework_name and framework_name == "pytorch": + # We need to validate only for PyTorch framework + validate_pytorch_distribution( + distribution=distribution, + framework_name=framework_name, + framework_version=framework_version, + py_version=py_version, + image_uri=image_uri, + ) + validate_torch_distributed_distribution( + instance_type=instance_type, + distribution=distribution, + framework_version=framework_version, + py_version=py_version, + image_uri=image_uri, + entry_point=kwargs["entry_point"], + ) warn_if_parameter_server_with_multi_gpu( training_instance_type=instance_type, distribution=distribution ) @@ -793,27 +817,75 @@ def validate_distribution( instance_type = renamed_kwargs( "train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs ) - validate_smdistributed( + validate_distribution_for_instance_type( instance_type=instance_type, - framework_name=framework_name, - framework_version=framework_version, - py_version=py_version, distribution=distribution, - image_uri=image_uri, ) - validate_pytorch_distribution( - distribution=distribution, + validate_smdistributed( + instance_type=instance_type, framework_name=framework_name, framework_version=framework_version, py_version=py_version, + distribution=distribution, 
image_uri=image_uri, ) + if framework_name and framework_name == "pytorch": + # We need to validate only for PyTorch framework + validate_pytorch_distribution( + distribution=distribution, + framework_name=framework_name, + framework_version=framework_version, + py_version=py_version, + image_uri=image_uri, + ) + validate_torch_distributed_distribution( + instance_type=instance_type, + distribution=distribution, + framework_version=framework_version, + py_version=py_version, + image_uri=image_uri, + entry_point=kwargs["entry_point"], + ) warn_if_parameter_server_with_multi_gpu( training_instance_type=instance_type, distribution=distribution ) return distribution +def validate_distribution_for_instance_type(instance_type, distribution): + """Check if the provided distribution strategy is supported for the instance_type + + Args: + instance_type (str): A string representing the type of training instance selected. + distribution (dict): A dictionary with information to enable distributed training. + """ + err_msg = "" + if isinstance(instance_type, str): + match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type) + if match and match[1].startswith("trn"): + keys = list(distribution.keys()) + if len(keys) == 0: + return + if len(keys) == 1: + distribution_strategy = keys[0] + if distribution_strategy != "torch_distributed": + err_msg += ( + f"Provided distribution strategy {distribution_strategy} is not supported" + " for Trainium instances.\n" + "Please specify one of the following supported distribution strategies:" + f" {TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES} \n" + ) + elif len(keys) > 1: + err_msg += ( + "Multiple distribution strategies are not supported for Trainium instances.\n" + "Please specify one of the following supported distribution strategies:" + f" {TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES} " + ) + + if err_msg: + raise ValueError(err_msg) + + def validate_pytorch_distribution( distribution, framework_name, framework_version, py_version, image_uri ): @@ -871,6 +943,86 @@ def validate_pytorch_distribution( raise ValueError(err_msg) +def validate_torch_distributed_distribution( + instance_type, + distribution, + framework_version, + py_version, + image_uri, + entry_point, +): + """Check if torch_distributed distribution strategy is correctly invoked by the user. + + Args: + instance_type (str): A string representing the type of training instance selected. + distribution (dict): A dictionary with information to enable distributed training. + (Defaults to None if distributed training is not enabled.) For example: + + .. code:: python + + { + "torch_distributed": { + "enabled": True + } + } + framework_version (str): A string representing the framework version selected. + py_version (str): A string representing the python version selected. + image_uri (str): A string representing a Docker image URI. + entry_point (str or PipelineVariable): The absolute or relative path to the local Python + source file that should be executed as the entry point to + training. 
+
+    Raises:
+        ValueError: if `py_version` is not python3,
+            `framework_version` is not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS,
+            the instance type is not Trainium (trn1), or `entry_point` is not a *.py file
+    """
+
+    torch_distributed_enabled = False
+    if "torch_distributed" in distribution:
+        torch_distributed_enabled = distribution.get("torch_distributed").get("enabled", False)
+    if not torch_distributed_enabled:
+        # Distribution strategy other than torch_distributed is selected
+        return
+
+    err_msg = ""
+    if not image_uri:
+        # ignore framework_version and py_version if image_uri is set
+        # in case image_uri is not set, then both are mandatory
+        if framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS:
+            err_msg += (
+                f"Provided framework_version {framework_version} is not supported by"
+                " torch_distributed.\n"
+                "Please specify one of the supported framework versions:"
+                f" {TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS} \n"
+            )
+        if "py3" not in py_version:
+            err_msg += (
+                f"Provided py_version {py_version} is not supported by torch_distributed.\n"
+                "Please specify a py_version of py3 or newer."
+            )
+
+    # Check instance compatibility
+    match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
+    if match:
+        if not match[1].startswith("trn"):
+            err_msg += (
+                "torch_distributed is currently supported only for Trainium instances.\n"
+                " Please refer to https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training \n"  # noqa E501 # pylint: disable=c0301
+                "for information regarding distributed training on non-Trainium instances."
+            )
+
+    # Check entry point type
+    if not entry_point.endswith(".py"):
+        err_msg += (
+            "Unsupported entry point type for the distribution torch_distributed.\n"
+            "Only Python programs (*.py) are supported."
+        )
+
+    if err_msg:
+        raise ValueError(err_msg)
+
+
 def python_deprecation_warning(framework, latest_supported_version):
     """Placeholder docstring"""
     return PYTHON_2_DEPRECATION_WARNING.format(
diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py
index 8f2a73b6e8..686de4a78c 100644
--- a/src/sagemaker/pytorch/estimator.py
+++ b/src/sagemaker/pytorch/estimator.py
@@ -39,6 +39,7 @@ class PyTorch(Framework):
 
     _framework_name = "pytorch"
     LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled"
+    LAUNCH_TORCH_DISTRIBUTED_ENV_NAME = "sagemaker_torch_distributed_enabled"
     INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type"
 
     def __init__(
@@ -168,6 +169,19 @@ def __init__(
                 To learn more, see `Distributed PyTorch Training
                 <https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training>`_.
 
+                **To enable Torch Distributed (for Trainium instances only):**
+
+                    .. code:: python
+
+                        {
+                            "torch_distributed": {
+                                "enabled": True
+                            }
+                        }
+
+                To learn more, see `Distributed PyTorch Training on Trainium
+                <https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training-on-trainium>`_.
+
                 **To enable MPI:**
 
                     .. 
code:: python @@ -219,6 +233,10 @@ def __init__( super(PyTorch, self).__init__( entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs ) + + if "entry_point" not in kwargs: + kwargs["entry_point"] = entry_point + if distribution is not None: distribution = validate_distribution( distribution, @@ -242,13 +260,21 @@ def _pytorch_distribution_configuration(self, distribution): """ distribution_config = {} pytorch_ddp_enabled = False + torch_distributed_enabled = False + if "pytorchddp" in distribution: pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False) + elif "torch_distributed" in distribution: + torch_distributed_enabled = distribution.get("torch_distributed").get("enabled", False) if pytorch_ddp_enabled: distribution_config[self.LAUNCH_PYTORCH_DDP_ENV_NAME] = pytorch_ddp_enabled if self.instance_type is not None: distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type + elif torch_distributed_enabled: + distribution_config[self.LAUNCH_TORCH_DISTRIBUTED_ENV_NAME] = torch_distributed_enabled + if self.instance_type is not None: + distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type else: distribution_config = self._distribution_configuration(distribution=distribution) diff --git a/tests/conftest.py b/tests/conftest.py index 59397ec9af..17a5c1db9c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -447,6 +447,16 @@ def pytorch_ddp_framework_version(request): return request.param +@pytest.fixture(scope="module") +def torch_distributed_py_version(): + return "py3" + + +@pytest.fixture(scope="module", params=["1.11.0"]) +def torch_distributed_framework_version(request): + return request.param + + @pytest.fixture(scope="session") def cpu_instance_type(sagemaker_session, request): region = sagemaker_session.boto_session.region_name diff --git a/tests/data/torch_distributed/mnist_mlp_trainium.py b/tests/data/torch_distributed/mnist_mlp_trainium.py new file mode 100644 index 0000000000..42dcf4aed7 --- /dev/null +++ b/tests/data/torch_distributed/mnist_mlp_trainium.py @@ -0,0 +1,95 @@ +import os +import time +import torch +from model import MLP + +from torchvision.datasets import mnist +from torch.utils.data import DataLoader +from torchvision.transforms import ToTensor + +# XLA imports +import torch_xla.core.xla_model as xm + +# XLA imports for parallel loader and multi-processing +import torch_xla.distributed.parallel_loader as pl +from torch.utils.data.distributed import DistributedSampler + +# Initialize XLA process group for torchrun +import torch_xla.distributed.xla_backend + +torch.distributed.init_process_group("xla") + +# Global constants +EPOCHS = 4 +WARMUP_STEPS = 2 +BATCH_SIZE = 32 + +# Load MNIST train dataset +train_dataset = mnist.MNIST( + root=os.path.join("./MNIST_DATA_train", str(xm.get_ordinal())), + train=True, + download=True, + transform=ToTensor(), +) + + +def main(): + # XLA MP: get world size + world_size = xm.xrt_world_size() + # multi-processing: ensure each worker has same initial weights + torch.manual_seed(0) + + # Move model to device and declare optimizer and loss function + device = "xla" + model = MLP().to(device) + # For multiprocessing, scale up learning rate + optimizer = torch.optim.SGD(model.parameters(), lr=0.01 * world_size) + loss_fn = torch.nn.NLLLoss() + + # Prepare data loader + train_sampler = None + if world_size > 1: + train_sampler = DistributedSampler( + train_dataset, num_replicas=world_size, rank=xm.get_ordinal(), shuffle=True + ) + train_loader = DataLoader( + 
train_dataset,
+        batch_size=BATCH_SIZE,
+        sampler=train_sampler,
+        shuffle=train_sampler is None,
+    )
+    # XLA MP: use MpDeviceLoader from torch_xla.distributed
+    train_device_loader = pl.MpDeviceLoader(train_loader, device)
+
+    # Run the training loop
+    print("----------Training ---------------")
+    model.train()
+    for epoch in range(EPOCHS):
+        start = time.time()
+        print(f"Epoch: {epoch}")
+        for idx, (train_x, train_label) in enumerate(train_device_loader):
+            optimizer.zero_grad()
+            train_x = train_x.view(train_x.size(0), -1)
+            output = model(train_x)
+            loss = loss_fn(output, train_label)
+            loss.backward()
+            xm.optimizer_step(optimizer)  # XLA MP: performs grad allreduce and optimizer step
+            if idx < WARMUP_STEPS:  # skip warmup iterations
+                start = time.time()
+
+    # Compute statistics for the last epoch
+    interval = idx - WARMUP_STEPS  # exclude warmup iterations from the count
+    throughput = interval / (time.time() - start)
+    print("Train throughput (iter/sec): {}".format(throughput))
+    print("Final loss is {:0.4f}".format(loss.detach().to("cpu")))
+
+    # Save checkpoint for evaluation (xm.save ensures only one process saves)
+    os.makedirs("checkpoints", exist_ok=True)
+    checkpoint = {"state_dict": model.state_dict()}
+    xm.save(checkpoint, "checkpoints/checkpoint.pt")
+
+    print("----------End Training ---------------")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/integ/test_torch_distributed.py b/tests/integ/test_torch_distributed.py
new file mode 100644
index 0000000000..860836b7d1
--- /dev/null
+++ b/tests/integ/test_torch_distributed.py
@@ -0,0 +1,49 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import os
+import pytest
+import sagemaker.utils
+import tests.integ as integ
+from sagemaker.pytorch import PyTorch
+from tests.integ import timeout
+from tests.integ.test_pytorch import _upload_training_data
+
+torch_distributed_dir = os.path.join(os.path.dirname(__file__), "..", "data", "torch_distributed")
+
+
+@pytest.mark.skip(
+    reason="Disabling until the launch of SM Trainium containers. "
+    "This test should be re-enabled later."
+)
+def test_torch_distributed_trn1_pt_mnist(
+    sagemaker_session,
+    torch_distributed_framework_version,
+    torch_distributed_py_version,
+):
+    job_name = sagemaker.utils.unique_name_from_base("pt-torch-distributed")
+    estimator = PyTorch(
+        entry_point="mnist_mlp_trainium.py",
+        role="SageMakerRole",
+        source_dir=torch_distributed_dir,
+        instance_count=1,
+        instance_type="ml.trn1.2xlarge",
+        sagemaker_session=sagemaker_session,
+        framework_version=torch_distributed_framework_version,
+        py_version=torch_distributed_py_version,
+        distribution={"torch_distributed": {"enabled": True}},
+    )
+
+    with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
+        estimator.fit({"training": _upload_training_data(estimator)}, job_name=job_name)
diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py
index 0278725a61..1badd1be0c 100644
--- a/tests/unit/test_fw_utils.py
+++ b/tests/unit/test_fw_utils.py
@@ -655,7 +655,7 @@ def test_validate_distribution_not_raises():
         None,  # framework_version
         None,  # py_version
         "custom-container",
-        {"instance_type": instance_type},  # kwargs
+        {"instance_type": instance_type, "entry_point": "train.py"},  # kwargs
     )
 
     for framework in frameworks:
@@ -683,7 +683,7 @@ def test_validate_distribution_not_raises():
         None,  # framework_version
         None,  # py_version
         "custom-container",
-        {},  # kwargs
+        {"entry_point": "train.py"},  # kwargs
     )
 
 
@@ -723,7 +723,7 @@ def test_validate_distribution_raises():
         None,  # framework_version
         None,  # py_version
         "custom-container",
-        {"instance_type": instance_type},  # kwargs
+        {"instance_type": instance_type, "entry_point": "train.py"},  # kwargs
     )
 
     for framework in frameworks:
@@ -946,3 +946,97 @@ def test_validate_pytorchddp_raises():
             py_version="py2",
             image_uri=None,
         )
+
+
+def test_validate_torch_distributed_not_raises():
+
+    # Case 1: Framework is PyTorch, but distribution is not torch_distributed
+    torch_distributed_disabled = {"torch_distributed": {"enabled": False}}
+    fw_utils.validate_torch_distributed_distribution(
+        instance_type="ml.trn1.2xlarge",
+        distribution=torch_distributed_disabled,
+        framework_version="1.11.0",
+        py_version="py3",
+        image_uri="custom-container",
+        entry_point="train.py",
+    )
+    # Case 2: Distribution is torch_distributed enabled, supported framework and py versions
+    torch_distributed_enabled = {"torch_distributed": {"enabled": True}}
+    torch_distributed_supported_fw_versions = [
+        "1.11.0",
+    ]
+    for framework_version in torch_distributed_supported_fw_versions:
+        fw_utils.validate_torch_distributed_distribution(
+            instance_type="ml.trn1.2xlarge",
+            distribution=torch_distributed_enabled,
+            framework_version=framework_version,
+            py_version="py3",
+            image_uri="custom-container",
+            entry_point="train.py",
+        )
+
+
+def test_validate_torch_distributed_raises():
+    torch_distributed_enabled = {"torch_distributed": {"enabled": True}}
+    # Case 1: Unsupported framework version
+    with pytest.raises(ValueError):
+        fw_utils.validate_torch_distributed_distribution(
+            instance_type="ml.trn1.2xlarge",
+            distribution=torch_distributed_enabled,
+            framework_version="1.10.0",
+            py_version="py3",
+            image_uri=None,
+            entry_point="train.py",
+        )
+
+    # Case 2: Unsupported Py version
+    with pytest.raises(ValueError):
+        fw_utils.validate_torch_distributed_distribution(
+            instance_type="ml.trn1.2xlarge",
+            distribution=torch_distributed_enabled,
+            framework_version="1.11.0",
+            py_version="py2",
+            image_uri=None,
+            entry_point="train.py",
+        )
+
+    # Case 3: Unsupported entry point type
+    with pytest.raises(ValueError):
+        fw_utils.validate_torch_distributed_distribution(
+            instance_type="ml.trn1.2xlarge",
+            distribution=torch_distributed_enabled,
+            framework_version="1.11.0",
+            py_version="py3",
+            image_uri=None,
+            entry_point="train.sh",
+        )
+
+
+def test_validate_unsupported_distributions_trainium_raises():
+    with pytest.raises(ValueError):
+        mpi_enabled = {"mpi": {"enabled": True}}
+        fw_utils.validate_distribution_for_instance_type(
+            distribution=mpi_enabled,
+            instance_type="ml.trn1.2xlarge",
+        )
+
+    with pytest.raises(ValueError):
+        mpi_enabled = {"mpi": {"enabled": True}}
+        fw_utils.validate_distribution_for_instance_type(
+            distribution=mpi_enabled,
+            instance_type="ml.trn1.32xlarge",
+        )
+
+    with pytest.raises(ValueError):
+        pytorch_ddp_enabled = {"pytorchddp": {"enabled": True}}
+        fw_utils.validate_distribution_for_instance_type(
+            distribution=pytorch_ddp_enabled,
+            instance_type="ml.trn1.32xlarge",
+        )
+
+    with pytest.raises(ValueError):
+        smdataparallel_enabled = {"smdataparallel": {"enabled": True}}
+        fw_utils.validate_distribution_for_instance_type(
+            distribution=smdataparallel_enabled,
+            instance_type="ml.trn1.32xlarge",
+        )
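
The validation flow added above can be sanity-checked end to end. The following
minimal sketch (assuming an SDK build that includes these changes) exercises
``validate_distribution_for_instance_type`` directly:

.. code:: python

    from sagemaker import fw_utils

    # torch_distributed on a Trainium instance passes validation silently
    fw_utils.validate_distribution_for_instance_type(
        instance_type="ml.trn1.2xlarge",
        distribution={"torch_distributed": {"enabled": True}},
    )

    # Any other strategy on a Trainium instance raises ValueError
    try:
        fw_utils.validate_distribution_for_instance_type(
            instance_type="ml.trn1.2xlarge",
            distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
        )
    except ValueError as err:
        print(err)  # names the strategies supported on Trainium

Note that ``validate_torch_distributed_distribution`` accumulates all failures
into a single ``err_msg`` before raising, so a user sees every configuration
problem in one pass rather than one at a time.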