
Commit 305c6eb

feature: Add torch_distributed support for Trainium instances
1 parent 7f2a8c2 commit 305c6eb

File tree

3 files changed: +25 -1 lines changed


setup.py (+1 -1)

@@ -53,7 +53,7 @@ def read(fname):
         'Programming Language :: Python :: 3.9',
     ],
 
-    install_requires=['retrying', 'sagemaker-training>=4.2.0', 'six>=1.12.0'],
+    install_requires=['retrying', 'sagemaker-training>=4.2.10', 'six>=1.12.0'],
     extras_require={
         'test': test_dependencies
     },
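The dependency floor moves from sagemaker-training>=4.2.0 to >=4.2.10, presumably the first release that ships the runner.TorchDistributedRunnerType referenced in the training.py changes below. A minimal sketch, assuming that attribute is the relevant new surface, to sanity-check a local install:

# Hedged sanity check: the attribute name comes from this commit's
# training.py diff; the 4.2.10 floor is the pin introduced above.
from sagemaker_training import runner

assert hasattr(runner, "TorchDistributedRunnerType"), \
    "upgrade sagemaker-training to >=4.2.10 for torch_distributed support"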

src/sagemaker_pytorch_container/training.py (+8)

@@ -24,6 +24,7 @@
 LAUNCH_MPI_ENV_NAME = 'sagemaker_mpi_enabled'
 LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled"
 LAUNCH_PYTORCH_XLA_ENV_NAME = "sagemaker_pytorch_xla_multi_worker_enabled"
+LAUNCH_TORCH_DISTRIBUTED_ENV_NAME = "sagemaker_torch_distributed_enabled"
 
 logger = logging.getLogger(__name__)
 
@@ -65,6 +66,10 @@ def train(training_environment):
     pytorch_xla_enabled = training_environment.additional_framework_parameters.get(
         LAUNCH_PYTORCH_XLA_ENV_NAME, False
     )
+
+    torch_distributed_enabled = training_environment.additional_framework_parameters.get(
+        LAUNCH_TORCH_DISTRIBUTED_ENV_NAME, False
+    )
     # default scenario
     runner_type = runner.ProcessRunnerType
 
@@ -74,6 +79,9 @@ def train(training_environment):
     elif pytorch_ddp_enabled:
         runner_type = runner.SMDataParallelRunnerType
         logger.info('Invoking SMDataParallel for native PT DDP job')
+    elif torch_distributed_enabled:
+        runner_type = runner.TorchDistributedRunnerType
+        logger.info('Invoking TorchDistributed...')
     elif smdataparallel_enabled:
         runner_type = runner.SMDataParallelRunnerType
         logger.info('Invoking SMDataParallel')
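For context on where sagemaker_torch_distributed_enabled comes from: when a job is launched through the SageMaker Python SDK, a torch_distributed distribution option on the PyTorch estimator is what typically sets this framework parameter. A sketch, assuming SDK support for trn1 instances; the role ARN, entry point, and framework/Python versions are illustrative placeholders, not taken from this commit:

from sagemaker.pytorch import PyTorch

# Hypothetical launch of a Trainium job; the distribution option below is
# assumed to surface in the container as
# additional_framework_parameters['sagemaker_torch_distributed_enabled'].
estimator = PyTorch(
    entry_point="train.py",  # placeholder training script
    role="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder role
    framework_version="1.11.0",  # assumed PyTorch version with trn1 support
    py_version="py38",
    instance_type="ml.trn1.2xlarge",
    instance_count=1,
    distribution={"torch_distributed": {"enabled": True}},
)
estimator.fit()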

test/unit/test_train.py (+16)

@@ -120,6 +120,22 @@ def test_train_pytorch_ddp(run_module, training_env):
     )
 
 
+@patch("sagemaker_training.entry_point.run")
+@patch('socket.gethostbyname', MagicMock())
+def test_train_torch_distributed(run_module, training_env):
+    training_env.additional_framework_parameters["sagemaker_torch_distributed_enabled"] = True
+
+    train(training_env)
+    run_module.assert_called_with(
+        uri=training_env.module_dir,
+        user_entry_point=training_env.user_entry_point,
+        args=training_env.to_cmd_args(),
+        env_vars=training_env.to_env_vars(),
+        capture_error=True,
+        runner_type=runner.TorchDistributedRunnerType,
+    )
+
+
 @patch("sagemaker_training.entry_point.run")
 @patch('socket.gethostbyname', MagicMock())
 def test_train_pytorch_xla_distributed(run_module, training_env):
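To exercise only the new test locally, an invocation along these lines should work from the repo root (the file path comes from this diff; the -k expression is just the new test's name):

# Hedged usage sketch: run the new unit test in isolation via pytest's API.
import sys
import pytest

sys.exit(pytest.main(["test/unit/test_train.py", "-k", "test_train_torch_distributed"]))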
