Skip to content

Commit 05fad97

Browse files
authored
feature: add support for native PyTorch DDP distribution (#236)
1 parent 37d880f commit 05fad97

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

src/sagemaker_pytorch_container/training.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
MASTER_PORT = '7777'
2323
LAUNCH_SMDATAPARALLEL_ENV_NAME = 'sagemaker_distributed_dataparallel_enabled'
24+
LAUNCH_MPI_ENV_NAME = 'sagemaker_mpi_enabled'
25+
LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled"
2426

2527
logger = logging.getLogger(__name__)
2628

@@ -49,7 +51,11 @@ def train(training_environment):
4951

5052
_set_distributed_environment(training_environment)
5153

52-
mpi_enabled = training_environment.additional_framework_parameters.get('sagemaker_mpi_enabled')
54+
mpi_enabled = training_environment.additional_framework_parameters.get(LAUNCH_MPI_ENV_NAME)
55+
56+
pytorch_ddp_enabled = training_environment.additional_framework_parameters.get(
57+
LAUNCH_PYTORCH_DDP_ENV_NAME, False
58+
)
5359

5460
smdataparallel_enabled = training_environment.additional_framework_parameters.get(
5561
LAUNCH_SMDATAPARALLEL_ENV_NAME, False
@@ -60,6 +66,9 @@ def train(training_environment):
6066
if training_environment.current_instance_group in training_environment.distribution_instance_groups:
6167
if mpi_enabled:
6268
runner_type = runner.MPIRunnerType
69+
elif pytorch_ddp_enabled:
70+
runner_type = runner.SMDataParallelRunnerType
71+
logger.info('Invoking SMDataParallel for native PT DDP job')
6372
elif smdataparallel_enabled:
6473
runner_type = runner.SMDataParallelRunnerType
6574
logger.info('Invoking SMDataParallel')

test/unit/test_train.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,22 @@ def test_train_smdataparallel(run_module, training_env):
9090
)
9191

9292

93+
@patch("sagemaker_training.entry_point.run")
94+
@patch('socket.gethostbyname', MagicMock())
95+
def test_train_pytorch_ddp(run_module, training_env):
96+
training_env.additional_framework_parameters["sagemaker_pytorch_ddp_enabled"] = True
97+
98+
train(training_env)
99+
run_module.assert_called_with(
100+
uri=training_env.module_dir,
101+
user_entry_point=training_env.user_entry_point,
102+
args=training_env.to_cmd_args(),
103+
env_vars=training_env.to_env_vars(),
104+
capture_error=True,
105+
runner_type=runner.SMDataParallelRunnerType,
106+
)
107+
108+
93109
@patch('sagemaker_training.entry_point.run', MagicMock())
94110
@patch('socket.gethostbyname', MagicMock())
95111
def test_environment(training_env):

0 commit comments

Comments
 (0)