File tree 2 files changed +26
-1
lines changed
src/sagemaker_pytorch_container
2 files changed +26
-1
lines changed Original file line number Diff line number Diff line change 21
21
22
22
MASTER_PORT = '7777'
23
23
LAUNCH_SMDATAPARALLEL_ENV_NAME = 'sagemaker_distributed_dataparallel_enabled'
24
+ LAUNCH_MPI_ENV_NAME = 'sagemaker_mpi_enabled'
25
+ LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled"
24
26
25
27
logger = logging .getLogger (__name__ )
26
28
@@ -49,7 +51,11 @@ def train(training_environment):
49
51
50
52
_set_distributed_environment (training_environment )
51
53
52
- mpi_enabled = training_environment .additional_framework_parameters .get ('sagemaker_mpi_enabled' )
54
+ mpi_enabled = training_environment .additional_framework_parameters .get (LAUNCH_MPI_ENV_NAME )
55
+
56
+ pytorch_ddp_enabled = training_environment .additional_framework_parameters .get (
57
+ LAUNCH_PYTORCH_DDP_ENV_NAME , False
58
+ )
53
59
54
60
smdataparallel_enabled = training_environment .additional_framework_parameters .get (
55
61
LAUNCH_SMDATAPARALLEL_ENV_NAME , False
@@ -60,6 +66,9 @@ def train(training_environment):
60
66
if training_environment .current_instance_group in training_environment .distribution_instance_groups :
61
67
if mpi_enabled :
62
68
runner_type = runner .MPIRunnerType
69
+ elif pytorch_ddp_enabled :
70
+ runner_type = runner .SMDataParallelRunnerType
71
+ logger .info ('Invoking SMDataParallel for native PT DDP job' )
63
72
elif smdataparallel_enabled :
64
73
runner_type = runner .SMDataParallelRunnerType
65
74
logger .info ('Invoking SMDataParallel' )
Original file line number Diff line number Diff line change @@ -90,6 +90,22 @@ def test_train_smdataparallel(run_module, training_env):
90
90
)
91
91
92
92
93
+ @patch ("sagemaker_training.entry_point.run" )
94
+ @patch ('socket.gethostbyname' , MagicMock ())
95
+ def test_train_pytorch_ddp (run_module , training_env ):
96
+ training_env .additional_framework_parameters ["sagemaker_pytorch_ddp_enabled" ] = True
97
+
98
+ train (training_env )
99
+ run_module .assert_called_with (
100
+ uri = training_env .module_dir ,
101
+ user_entry_point = training_env .user_entry_point ,
102
+ args = training_env .to_cmd_args (),
103
+ env_vars = training_env .to_env_vars (),
104
+ capture_error = True ,
105
+ runner_type = runner .SMDataParallelRunnerType ,
106
+ )
107
+
108
+
93
109
@patch ('sagemaker_training.entry_point.run' , MagicMock ())
94
110
@patch ('socket.gethostbyname' , MagicMock ())
95
111
def test_environment (training_env ):
You can’t perform that action at this time.
0 commit comments