diff --git a/doc/frameworks/pytorch/using_pytorch.rst b/doc/frameworks/pytorch/using_pytorch.rst index 9d4a4de3de..52720fe12b 100644 --- a/doc/frameworks/pytorch/using_pytorch.rst +++ b/doc/frameworks/pytorch/using_pytorch.rst @@ -200,17 +200,32 @@ fit Optional Arguments Distributed PyTorch Training ============================ -You can run a multi-machine, distributed PyTorch training using the PyTorch Estimator. By default, PyTorch objects will -submit single-machine training jobs to SageMaker. If you set ``instance_count`` to be greater than one, multi-machine -training jobs will be launched when ``fit`` is called. When you run multi-machine training, SageMaker will import your -training script and run it on each host in the cluster. +SageMaker supports the `PyTorch DistributedDataParallel (DDP) +<https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html>`_ +package. You simply need to check the variables in your training script, +such as the world size and the rank of the current host, when initializing +process groups for distributed training. +Then, launch the training job using the +:class:`sagemaker.pytorch.estimator.PyTorch` estimator class +with the ``pytorchddp`` option as the distribution strategy. -To initialize distributed training in your script you would call ``dist.init_process_group`` providing desired backend -and rank and setting 'WORLD_SIZE' environment variable similar to how you would do it outside of SageMaker using -environment variable initialization: +.. note:: + + This PyTorch DDP support is available + in the SageMaker PyTorch Deep Learning Containers v1.12 and later. + +Adapt Your Training Script +-------------------------- + +To initialize distributed training in your script, call +`torch.distributed.init_process_group +<https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group>`_ +with the desired backend and the rank of the current host. .. code:: python + import torch.distributed as dist + if args.distributed: # Initialize the distributed environment. 
world_size = len(args.hosts) @@ -218,11 +233,65 @@ environment variable initialization: host_rank = args.hosts.index(args.current_host) dist.init_process_group(backend=args.backend, rank=host_rank) -SageMaker sets 'MASTER_ADDR' and 'MASTER_PORT' environment variables for you, but you can overwrite them. +SageMaker sets ``'MASTER_ADDR'`` and ``'MASTER_PORT'`` environment variables for you, +but you can also overwrite them. + +**Supported backends:** + +- ``gloo`` and ``tcp`` for CPU instances +- ``gloo`` and ``nccl`` for GPU instances + +Launching a Distributed Training Job +------------------------------------ + +You can run multi-node distributed PyTorch training jobs using the +:class:`sagemaker.pytorch.estimator.PyTorch` estimator class. +With ``instance_count=1``, the estimator submits a +single-node training job to SageMaker; with ``instance_count`` greater +than one, a multi-node training job is launched. + +To run a distributed training script that adopts +the `PyTorch DistributedDataParallel (DDP) package +<https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html>`_, +choose ``pytorchddp`` as the distributed training option in the ``PyTorch`` estimator. + +With the ``pytorchddp`` option, the SageMaker PyTorch estimator runs a SageMaker +training container for PyTorch, sets up the environment for MPI, and launches +the training job using the ``mpirun`` command on each worker with the given information +during the PyTorch DDP initialization. + +.. note:: + + The SageMaker PyTorch estimator doesn’t use ``torchrun`` for distributed training. + +For more information about setting up PyTorch DDP in your training script, +see `Getting Started with Distributed Data Parallel +<https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`_ in the +PyTorch documentation. + +The following example shows how to run a PyTorch DDP training job in SageMaker +using two ``ml.p4d.24xlarge`` instances: + +.. 
code:: python + + from sagemaker.pytorch import PyTorch + + pt_estimator = PyTorch( + entry_point="train_ptddp.py", + role="SageMakerRole", + framework_version="1.12.0", + py_version="py38", + instance_count=2, + instance_type="ml.p4d.24xlarge", + distribution={ + "pytorchddp": { + "enabled": True + } + } + ) + + pt_estimator.fit("s3://bucket/path/to/training/data") -Supported backends: -- `gloo` and `tcp` for cpu instances -- `gloo` and `nccl` for gpu instances ********************* diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 40787d4440..ef99454a45 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -103,6 +103,17 @@ "1.11.0", ], } + +PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS = [ + "1.10", + "1.10.0", + "1.10.2", + "1.11", + "1.11.0", + "1.12", + "1.12.0", +] + SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"] @@ -728,6 +739,13 @@ def validate_distribution( distribution=distribution, image_uri=image_uri, ) + validate_pytorch_distribution( + distribution=distribution, + framework_name=framework_name, + framework_version=framework_version, + py_version=py_version, + image_uri=image_uri, + ) warn_if_parameter_server_with_multi_gpu( training_instance_type=instance_type, distribution=distribution ) @@ -747,12 +765,76 @@ def validate_distribution( distribution=distribution, image_uri=image_uri, ) + validate_pytorch_distribution( + distribution=distribution, + framework_name=framework_name, + framework_version=framework_version, + py_version=py_version, + image_uri=image_uri, + ) warn_if_parameter_server_with_multi_gpu( training_instance_type=instance_type, distribution=distribution ) return distribution +def validate_pytorch_distribution( + distribution, framework_name, framework_version, py_version, image_uri +): + """Check if pytorch distribution strategy is correctly invoked by the user. + + Args: + distribution (dict): A dictionary with information to enable distributed training. 
+ (Defaults to None if distributed training is not enabled.) For example: + + .. code:: python + + { + "pytorchddp": { + "enabled": True + } + } + framework_name (str): A string representing the name of framework selected. + framework_version (str): A string representing the framework version selected. + py_version (str): A string representing the python version selected. + image_uri (str): A string representing a Docker image URI. + + Raises: + ValueError: if + `py_version` is not python3 or + `framework_version` is not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS + """ + if framework_name and framework_name != "pytorch": + # We need to validate only for PyTorch framework + return + + pytorch_ddp_enabled = False + if "pytorchddp" in distribution: + pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False) + if not pytorch_ddp_enabled: + # Distribution strategy other than pytorchddp is selected + return + + err_msg = "" + if not image_uri: + # ignore framework_version and py_version if image_uri is set + # in case image_uri is not set, then both are mandatory + if framework_version not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS: + err_msg += ( + f"Provided framework_version {framework_version} is not supported by" + " pytorchddp.\n" + "Please specify one of the supported framework versions:" + f" {PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS} \n" + ) + if "py3" not in py_version: + err_msg += ( + f"Provided py_version {py_version} is not supported by pytorchddp.\n" + "Please specify py_version>=py3" + ) + if err_msg: + raise ValueError(err_msg) + + def python_deprecation_warning(framework, latest_supported_version): """Placeholder docstring""" return PYTHON_2_DEPRECATION_WARNING.format( diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 07554ca798..153d4656d4 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -38,6 +38,8 @@ class PyTorch(Framework): """Handle end-to-end 
training and deployment of custom PyTorch code.""" _framework_name = "pytorch" + LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled" + INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type" def __init__( self, @@ -153,6 +155,19 @@ def __init__( To find a complete list of parameters for SageMaker model parallelism, see :ref:`sm-sdk-modelparallel-general`. + **To enable PyTorch DDP:** + + .. code:: python + + { + "pytorchddp": { + "enabled": True + } + } + + To learn more, see `Distributed PyTorch Training + `_. + **To enable MPI:** .. code:: python @@ -217,10 +232,32 @@ def __init__( self.distribution = distribution or {} + def _pytorch_distribution_configuration(self, distribution): + """Returns a dict of distribution config for PyTorch training + + Args: + distribution (dict): A dictionary with information on how to run distributed training. + Returns: + dict containing Pytorch DDP config + """ + distribution_config = {} + pytorch_ddp_enabled = False + if "pytorchddp" in distribution: + pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False) + + if pytorch_ddp_enabled: + distribution_config[self.LAUNCH_PYTORCH_DDP_ENV_NAME] = pytorch_ddp_enabled + if self.instance_type is not None: + distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type + else: + distribution_config = self._distribution_configuration(distribution=distribution) + + return distribution_config + def hyperparameters(self): """Return hyperparameters used by your custom PyTorch code during model training.""" hyperparameters = super(PyTorch, self).hyperparameters() - additional_hyperparameters = self._distribution_configuration( + additional_hyperparameters = self._pytorch_distribution_configuration( distribution=self.distribution ) hyperparameters.update( diff --git a/tests/conftest.py b/tests/conftest.py index 8ccf443133..25f594a74b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -411,6 +411,18 @@ def tf_full_py_version(tf_full_version): return "py39" 
+@pytest.fixture(scope="module") +def pytorch_ddp_py_version(): + return "py3" + + +@pytest.fixture( + scope="module", params=["1.10", "1.10.0", "1.10.2", "1.11", "1.11.0", "1.12", "1.12.0"] +) +def pytorch_ddp_framework_version(request): + return request.param + + @pytest.fixture(scope="session") def cpu_instance_type(sagemaker_session, request): region = sagemaker_session.boto_session.region_name diff --git a/tests/data/_repack_model.py b/tests/data/_repack_model.py new file mode 100644 index 0000000000..3cfa6760b3 --- /dev/null +++ b/tests/data/_repack_model.py @@ -0,0 +1,110 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Repack model script for training jobs to inject entry points""" +from __future__ import absolute_import + +import argparse +import os +import shutil +import tarfile +import tempfile + +# Repack Model +# The following script is run via a training job which takes an existing model and a custom +# entry point script as arguments. The script creates a new model archive with the custom +# entry point in the "code" directory along with the existing model. Subsequently, when the model +# is unpacked for inference, the custom entry point will be used. +# Reference: https://docs.aws.amazon.com/sagemaker/latest/dg/amazon-sagemaker-toolkits.html + +# distutils.dir_util.copy_tree works way better than the half-baked +# shutil.copytree which bombs on previously existing target dirs... +# alas ... 
https://bugs.python.org/issue10948 +# we'll go ahead and use the copy_tree function anyways because this +# repacking is some short-lived hackery, right?? +from distutils.dir_util import copy_tree + + +def repack(inference_script, model_archive, dependencies=None, source_dir=None): # pragma: no cover + """Repack custom dependencies and code into an existing model TAR archive + + Args: + inference_script (str): The path to the custom entry point. + model_archive (str): The name or path (e.g. s3 uri) of the model TAR archive. + dependencies (str): A space-delimited string of paths to custom dependencies. + source_dir (str): The path to a custom source directory. + """ + + # the data directory contains a model archive generated by a previous training job + data_directory = "/opt/ml/input/data/training" + model_path = os.path.join(data_directory, model_archive.split("/")[-1]) + + # create a temporary directory + with tempfile.TemporaryDirectory() as tmp: + local_path = os.path.join(tmp, "local.tar.gz") + # copy the previous training job's model archive to the temporary directory + shutil.copy2(model_path, local_path) + src_dir = os.path.join(tmp, "src") + # create the "code" directory which will contain the inference script + code_dir = os.path.join(src_dir, "code") + os.makedirs(code_dir) + # extract the contents of the previous training job's model archive to the "src" + # directory of this training job + with tarfile.open(name=local_path, mode="r:gz") as tf: + tf.extractall(path=src_dir) + + if source_dir: + # copy /opt/ml/code to code/ + if os.path.exists(code_dir): + shutil.rmtree(code_dir) + shutil.copytree("/opt/ml/code", code_dir) + else: + # copy the custom inference script to code/ + entry_point = os.path.join("/opt/ml/code", inference_script) + shutil.copy2(entry_point, os.path.join(code_dir, inference_script)) + + # copy any dependencies to code/lib/ + if dependencies: + for dependency in dependencies.split(" "): + actual_dependency_path = 
os.path.join("/opt/ml/code", dependency) + lib_dir = os.path.join(code_dir, "lib") + if not os.path.exists(lib_dir): + os.mkdir(lib_dir) + if os.path.isfile(actual_dependency_path): + shutil.copy2(actual_dependency_path, lib_dir) + else: + if os.path.exists(lib_dir): + shutil.rmtree(lib_dir) + # a directory is in the dependencies. we have to copy + # all of /opt/ml/code into the lib dir because the original directory + # was flattened by the SDK training job upload.. + shutil.copytree("/opt/ml/code", lib_dir) + break + + # copy the "src" dir, which includes the previous training job's model and the + # custom inference script, to the output of this training job + copy_tree(src_dir, "/opt/ml/model") + + +if __name__ == "__main__": # pragma: no cover + parser = argparse.ArgumentParser() + parser.add_argument("--inference_script", type=str, default="inference.py") + parser.add_argument("--dependencies", type=str, default=None) + parser.add_argument("--source_dir", type=str, default=None) + parser.add_argument("--model_archive", type=str, default="model.tar.gz") + args, extra = parser.parse_known_args() + repack( + inference_script=args.inference_script, + dependencies=args.dependencies, + source_dir=args.source_dir, + model_archive=args.model_archive, + ) diff --git a/tests/data/pytorch_ddp/mnist_pt.py b/tests/data/pytorch_ddp/mnist_pt.py new file mode 100644 index 0000000000..6c37f9102b --- /dev/null +++ b/tests/data/pytorch_ddp/mnist_pt.py @@ -0,0 +1,246 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import print_function + +import argparse +import os +import time +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms +from torch.optim.lr_scheduler import StepLR +from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP +import smdistributed.dataparallel.torch.distributed as dist + +dist.init_process_group(backend="nccl") + + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, 1) + self.conv2 = nn.Conv2d(32, 64, 3, 1) + self.dropout1 = nn.Dropout2d(0.25) + self.dropout2 = nn.Dropout2d(0.5) + self.fc1 = nn.Linear(9216, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + + +def train(args, model, device, train_loader, optimizer, epoch): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + if batch_idx % args.log_interval == 0 and args.rank == 0: + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(data) * args.world_size, + len(train_loader.dataset), + 100.0 * batch_idx / len(train_loader), + loss.item(), + ) + ) + if args.verbose: + print("Batch", batch_idx, "from rank", args.rank) + + +def test(model, device, test_loader): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = 
data.to(device), target.to(device) + output = model(data) + test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss + pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + correct += pred.eq(target.view_as(pred)).sum().item() + + test_loss /= len(test_loader.dataset) + + print( + "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( + test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset) + ) + ) + + +def main(): + # Training settings + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument( + "--batch-size", + type=int, + default=64, + metavar="N", + help="input batch size for training (default: 64)", + ) + parser.add_argument( + "--test-batch-size", + type=int, + default=1000, + metavar="N", + help="input batch size for testing (default: 1000)", + ) + parser.add_argument( + "--epochs", + type=int, + default=14, + metavar="N", + help="number of epochs to train (default: 14)", + ) + parser.add_argument( + "--lr", type=float, default=1.0, metavar="LR", help="learning rate (default: 1.0)" + ) + parser.add_argument( + "--gamma", + type=float, + default=0.7, + metavar="M", + help="Learning rate step gamma (default: 0.7)", + ) + parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + parser.add_argument( + "--log-interval", + type=int, + default=10, + metavar="N", + help="how many batches to wait before logging training status", + ) + parser.add_argument( + "--save-model", action="store_true", default=False, help="For Saving the current Model" + ) + parser.add_argument( + "--verbose", + action="store_true", + default=False, + help="For displaying SM Distributed Data Parallel-specific logs", + ) + parser.add_argument( + "--data-path", + type=str, + default=os.environ["SM_CHANNEL_TRAINING"], + help="Path for downloading the MNIST dataset", + ) + + args = parser.parse_args() + 
args.world_size = dist.get_world_size() + args.rank = rank = dist.get_rank() + args.local_rank = local_rank = dist.get_local_rank() + args.lr = 1.0 + args.batch_size //= args.world_size // 8 + args.batch_size = max(args.batch_size, 1) + data_path = args.data_path + + if args.verbose: + print( + "Hello from rank", + rank, + "of local_rank", + local_rank, + "in world size of", + args.world_size, + ) + + if not torch.cuda.is_available(): + raise Exception( + "Must run SM Distributed Data Parallel MNIST example on CUDA-capable devices." + ) + + torch.manual_seed(args.seed) + + device = torch.device("cuda") + + if local_rank == 0: + train_dataset = datasets.MNIST( + data_path, + train=True, + download=False, # True sets a dependency on an external site for our tests. + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ), + ) + else: + time.sleep(8) + train_dataset = datasets.MNIST( + data_path, + train=True, + download=False, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ), + ) + + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, num_replicas=args.world_size, rank=rank + ) + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=0, + pin_memory=True, + sampler=train_sampler, + ) + if rank == 0: + test_loader = torch.utils.data.DataLoader( + datasets.MNIST( + data_path, + train=False, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ), + ), + batch_size=args.test_batch_size, + shuffle=True, + ) + + model = DDP(Net().to(device)) + torch.cuda.set_device(local_rank) + model.cuda(local_rank) + optimizer = optim.Adadelta(model.parameters(), lr=args.lr) + scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) + for epoch in range(1, args.epochs + 1): + train(args, model, device, train_loader, optimizer, epoch) 
+ if rank == 0: + test(model, device, test_loader) + scheduler.step() + + if args.save_model: + torch.save(model.state_dict(), "mnist_cnn.pt") + + +if __name__ == "__main__": + main() diff --git a/tests/integ/test_pytorchddp.py b/tests/integ/test_pytorchddp.py new file mode 100644 index 0000000000..c580fdebc2 --- /dev/null +++ b/tests/integ/test_pytorchddp.py @@ -0,0 +1,53 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import os +import pytest +import sagemaker.utils +import tests.integ as integ +from sagemaker.pytorch import PyTorch +from tests.integ import timeout +from tests.integ.test_pytorch import _upload_training_data + +pytorchddp_dir = os.path.join(os.path.dirname(__file__), "..", "data", "pytorch_ddp") + + +@pytest.mark.skip( + reason="This test is skipped for now due ML capacity error." + "This test should be re-enabled later." 
+) +@pytest.mark.skipif( + integ.test_region() not in integ.DATA_PARALLEL_TESTING_REGIONS, + reason="Only allow this test to run in IAD and CMH to limit usage of p3.16xlarge", +) +def test_pytorchddp_pt_mnist( + sagemaker_session, + pytorch_ddp_framework_version, + pytorch_ddp_py_version, +): + job_name = sagemaker.utils.unique_name_from_base("pt-pytorch-ddp") + estimator = PyTorch( + entry_point="mnist_pt.py", + role="SageMakerRole", + source_dir=pytorchddp_dir, + instance_count=2, + instance_type="ml.p3.16xlarge", + sagemaker_session=sagemaker_session, + framework_version=pytorch_ddp_framework_version, + py_version=pytorch_ddp_py_version, + distribution={"pytorchddp": {"enabled": True}}, + ) + + with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES): + estimator.fit({"training": _upload_training_data(estimator)}, job_name=job_name) diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index 24bb7368a4..018255cf47 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -847,3 +847,65 @@ def test_validate_smdataparallel_args_not_raises(): fw_utils._validate_smdataparallel_args( instance_type, framework_name, framework_version, py_version, distribution ) + + +def test_validate_pytorchddp_not_raises(): + # Case 1: Framework is not PyTorch + fw_utils.validate_pytorch_distribution( + distribution=None, + framework_name="tensorflow", + framework_version="2.9.1", + py_version="py3", + image_uri="custom-container", + ) + # Case 2: Framework is PyTorch, but distribution is not PyTorchDDP + pytorchddp_disabled = {"pytorchddp": {"enabled": False}} + fw_utils.validate_pytorch_distribution( + distribution=pytorchddp_disabled, + framework_name="pytorch", + framework_version="1.10", + py_version="py3", + image_uri="custom-container", + ) + # Case 3: Framework is PyTorch, Distribution is PyTorchDDP enabled, supported framework and py versions + pytorchddp_enabled = {"pytorchddp": {"enabled": True}} + 
pytorchddp_supported_fw_versions = [ + "1.10", + "1.10.0", + "1.10.2", + "1.11", + "1.11.0", + "1.12", + "1.12.0", + ] + for framework_version in pytorchddp_supported_fw_versions: + fw_utils.validate_pytorch_distribution( + distribution=pytorchddp_enabled, + framework_name="pytorch", + framework_version=framework_version, + py_version="py3", + image_uri="custom-container", + ) + + +def test_validate_pytorchddp_raises(): + pytorchddp_enabled = {"pytorchddp": {"enabled": True}} + # Case 1: Unsupported framework version + with pytest.raises(ValueError): + fw_utils.validate_pytorch_distribution( + distribution=pytorchddp_enabled, + framework_name="pytorch", + framework_version="1.8", + py_version="py3", + image_uri=None, + ) + + # Case 2: Unsupported Py version + with pytest.raises(ValueError): + fw_utils.validate_pytorch_distribution( + distribution=pytorchddp_enabled, + framework_name="pytorch", + framework_version="1.10", + py_version="py2", + image_uri=None, + ) diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 8b8541e816..082f699d63 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -56,6 +56,8 @@ "TrialComponentDisplayName": "tc", } +DISTRIBUTION_PYTORCH_DDP_ENABLED = {"pytorchddp": {"enabled": True}} + @pytest.fixture(name="sagemaker_session") def fixture_sagemaker_session(): @@ -97,7 +99,7 @@ def _pytorch_estimator( py_version, instance_type=None, base_job_name=None, - **kwargs + **kwargs, ): return PyTorch( entry_point=SCRIPT_PATH, @@ -108,7 +110,7 @@ def _pytorch_estimator( instance_count=INSTANCE_COUNT, instance_type=instance_type if instance_type else INSTANCE_TYPE, base_job_name=base_job_name, - **kwargs + **kwargs, ) @@ -763,3 +765,38 @@ def test_register_pytorch_model_auto_infer_framework( sagemaker_session.create_model_package_from_containers.assert_called_with( **expected_create_model_package_request ) + + +def test_pytorch_ddp_distribution_configuration( + sagemaker_session, 
pytorch_ddp_framework_version, pytorch_ddp_py_version +): + test_instance_type = "ml.p4d.24xlarge" + pytorch = _pytorch_estimator( + sagemaker_session, + framework_version=pytorch_ddp_framework_version, + py_version=pytorch_ddp_py_version, + distribution=DISTRIBUTION_PYTORCH_DDP_ENABLED, + instance_type=test_instance_type, + ) + actual_pytorch_ddp = pytorch._pytorch_distribution_configuration( + distribution=pytorch.distribution + ) + expected_torch_ddp = { + "sagemaker_pytorch_ddp_enabled": True, + "sagemaker_instance_type": test_instance_type, + } + assert actual_pytorch_ddp == expected_torch_ddp + + +def test_pytorch_ddp_distribution_configuration_unsupported(sagemaker_session): + unsupported_framework_version = "1.9.1" + unsupported_py_version = "py2" + with pytest.raises(ValueError) as error: + _pytorch_estimator( + sagemaker_session, + framework_version=unsupported_framework_version, + py_version=unsupported_py_version, + distribution=DISTRIBUTION_PYTORCH_DDP_ENABLED, + ) + assert (f"framework_version {unsupported_framework_version} is not supported") in str(error) + assert (f"py_version {unsupported_py_version} is not supported") in str(error)