diff --git a/setup.py b/setup.py
index fe69866..38b705d 100644
--- a/setup.py
+++ b/setup.py
@@ -53,7 +53,7 @@ def read(fname):
         'Programming Language :: Python :: 3.9',
     ],
 
-    install_requires=['retrying', 'sagemaker-training>=4.2.0', 'six>=1.12.0'],
+    install_requires=['retrying', 'sagemaker-training>=4.3.0', 'six>=1.12.0'],
     extras_require={
         'test': test_dependencies
     },
diff --git a/src/sagemaker_pytorch_container/training.py b/src/sagemaker_pytorch_container/training.py
index 224506f..728c174 100644
--- a/src/sagemaker_pytorch_container/training.py
+++ b/src/sagemaker_pytorch_container/training.py
@@ -24,6 +24,7 @@
 LAUNCH_MPI_ENV_NAME = 'sagemaker_mpi_enabled'
 LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled"
 LAUNCH_PYTORCH_XLA_ENV_NAME = "sagemaker_pytorch_xla_multi_worker_enabled"
+LAUNCH_TORCH_DISTRIBUTED_ENV_NAME = "sagemaker_torch_distributed_enabled"
 
 logger = logging.getLogger(__name__)
 
@@ -65,6 +66,10 @@ def train(training_environment):
     pytorch_xla_enabled = training_environment.additional_framework_parameters.get(
         LAUNCH_PYTORCH_XLA_ENV_NAME, False
     )
+
+    torch_distributed_enabled = training_environment.additional_framework_parameters.get(
+        LAUNCH_TORCH_DISTRIBUTED_ENV_NAME, False
+    )
 
     # default scenario
     runner_type = runner.ProcessRunnerType
@@ -74,6 +79,9 @@
     elif pytorch_ddp_enabled:
         runner_type = runner.SMDataParallelRunnerType
         logger.info('Invoking SMDataParallel for native PT DDP job')
+    elif torch_distributed_enabled:
+        runner_type = runner.TorchDistributedRunnerType
+        logger.info('Invoking TorchDistributed...')
     elif smdataparallel_enabled:
         runner_type = runner.SMDataParallelRunnerType
         logger.info('Invoking SMDataParallel')
diff --git a/test/unit/test_train.py b/test/unit/test_train.py
index ac720cf..63af5a2 100644
--- a/test/unit/test_train.py
+++ b/test/unit/test_train.py
@@ -120,6 +120,22 @@ def test_train_pytorch_ddp(run_module, training_env):
     )
 
 
+@patch("sagemaker_training.entry_point.run")
+@patch('socket.gethostbyname', MagicMock())
+def test_train_torch_distributed(run_module, training_env):
+    training_env.additional_framework_parameters["sagemaker_torch_distributed_enabled"] = True
+
+    train(training_env)
+    run_module.assert_called_with(
+        uri=training_env.module_dir,
+        user_entry_point=training_env.user_entry_point,
+        args=training_env.to_cmd_args(),
+        env_vars=training_env.to_env_vars(),
+        capture_error=True,
+        runner_type=runner.TorchDistributedRunnerType,
+    )
+
+
 @patch("sagemaker_training.entry_point.run")
 @patch('socket.gethostbyname', MagicMock())
 def test_train_pytorch_xla_distributed(run_module, training_env):
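
Usage sketch (not part of this change): the `sagemaker_torch_distributed_enabled` flag read in `train()` above is expected to arrive through the estimator's `distribution` argument in the SageMaker Python SDK. The estimator values below (entry point, role, instances, versions, S3 path) are placeholders for illustration, not values taken from this patch.

# Illustrative sketch only: launch a job that should exercise the
# TorchDistributedRunnerType branch added in this change.
from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    entry_point="train.py",                              # hypothetical user script
    role="arn:aws:iam::123456789012:role/ExampleRole",   # placeholder IAM role
    framework_version="1.13.1",                          # placeholder framework version
    py_version="py39",
    instance_count=2,
    instance_type="ml.g5.12xlarge",                      # placeholder instance type
    # Assumed mapping: the SDK translates this distribution setting into the
    # 'sagemaker_torch_distributed_enabled' hyperparameter consumed by train().
    distribution={"torch_distributed": {"enabled": True}},
)

estimator.fit("s3://example-bucket/training-data")       # placeholder S3 input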