diff --git a/src/sagemaker_pytorch_container/training.py b/src/sagemaker_pytorch_container/training.py index 9331117..7c671be 100644 --- a/src/sagemaker_pytorch_container/training.py +++ b/src/sagemaker_pytorch_container/training.py @@ -21,6 +21,8 @@ MASTER_PORT = '7777' LAUNCH_SMDATAPARALLEL_ENV_NAME = 'sagemaker_distributed_dataparallel_enabled' +LAUNCH_MPI_ENV_NAME = 'sagemaker_mpi_enabled' +LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled" logger = logging.getLogger(__name__) @@ -49,7 +51,11 @@ def train(training_environment): _set_distributed_environment(training_environment) - mpi_enabled = training_environment.additional_framework_parameters.get('sagemaker_mpi_enabled') + mpi_enabled = training_environment.additional_framework_parameters.get(LAUNCH_MPI_ENV_NAME) + + pytorch_ddp_enabled = training_environment.additional_framework_parameters.get( + LAUNCH_PYTORCH_DDP_ENV_NAME, False + ) smdataparallel_enabled = training_environment.additional_framework_parameters.get( LAUNCH_SMDATAPARALLEL_ENV_NAME, False @@ -60,6 +66,9 @@ def train(training_environment): if training_environment.current_instance_group in training_environment.distribution_instance_groups: if mpi_enabled: runner_type = runner.MPIRunnerType + elif pytorch_ddp_enabled: + runner_type = runner.SMDataParallelRunnerType + logger.info('Invoking SMDataParallel for native PT DDP job') elif smdataparallel_enabled: runner_type = runner.SMDataParallelRunnerType logger.info('Invoking SMDataParallel') diff --git a/test/unit/test_train.py b/test/unit/test_train.py index a82ef51..c9ce21b 100644 --- a/test/unit/test_train.py +++ b/test/unit/test_train.py @@ -90,6 +90,22 @@ def test_train_smdataparallel(run_module, training_env): ) +@patch("sagemaker_training.entry_point.run") +@patch('socket.gethostbyname', MagicMock()) +def test_train_pytorch_ddp(run_module, training_env): + training_env.additional_framework_parameters["sagemaker_pytorch_ddp_enabled"] = True + + train(training_env) + run_module.assert_called_with( + uri=training_env.module_dir, + user_entry_point=training_env.user_entry_point, + args=training_env.to_cmd_args(), + env_vars=training_env.to_env_vars(), + capture_error=True, + runner_type=runner.SMDataParallelRunnerType, + ) + + @patch('sagemaker_training.entry_point.run', MagicMock()) @patch('socket.gethostbyname', MagicMock()) def test_environment(training_env):