
Commit ab14621

Feature: Support new distribution mechanism for PT-XLA (#241)
* Support new distribution mechanism for PT-XLA
* test: Adding test to check new PT-XLA distribution mechanism
1 parent 9535a12 commit ab14621

3 files changed: 31 additions, 1 deletion


src/sagemaker_pytorch_container/training.py

Lines changed: 8 additions & 0 deletions
@@ -23,6 +23,7 @@
 LAUNCH_SMDATAPARALLEL_ENV_NAME = 'sagemaker_distributed_dataparallel_enabled'
 LAUNCH_MPI_ENV_NAME = 'sagemaker_mpi_enabled'
 LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled"
+LAUNCH_PYTORCH_XLA_ENV_NAME = "sagemaker_pytorch_xla_multi_worker_enabled"

 logger = logging.getLogger(__name__)

@@ -60,6 +61,10 @@ def train(training_environment):
     smdataparallel_enabled = training_environment.additional_framework_parameters.get(
         LAUNCH_SMDATAPARALLEL_ENV_NAME, False
     )
+
+    pytorch_xla_enabled = training_environment.additional_framework_parameters.get(
+        LAUNCH_PYTORCH_XLA_ENV_NAME, False
+    )
     # default scenario
     runner_type = runner.ProcessRunnerType

@@ -72,6 +77,9 @@ def train(training_environment):
     elif smdataparallel_enabled:
         runner_type = runner.SMDataParallelRunnerType
         logger.info('Invoking SMDataParallel')
+    elif pytorch_xla_enabled:
+        runner_type = runner.PyTorchXLARunnerType
+        logger.info('Invoking PT-XLA Runner')
     logger.info('Invoking user training script.')
     try:
         entry_point.run(uri=training_environment.module_dir,
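
For orientation, here is a minimal, illustrative Python sketch (not part of this commit) of the runner-selection order that train() ends up with. It only models the branches visible in the hunks above; the MPI and PyTorch DDP branches that sit between these hunks are omitted. The helper name select_runner_type is hypothetical, and the sketch assumes runner is importable from sagemaker_training, matching how the container code uses it.

# Illustrative sketch only. It mirrors the branch order shown in the diff above;
# the real train() also handles the MPI and PyTorch DDP cases, which are not
# visible in these hunks and are left out here.
from sagemaker_training import runner  # assumed import, matching training.py's usage

LAUNCH_SMDATAPARALLEL_ENV_NAME = 'sagemaker_distributed_dataparallel_enabled'
LAUNCH_PYTORCH_XLA_ENV_NAME = "sagemaker_pytorch_xla_multi_worker_enabled"


def select_runner_type(additional_framework_parameters):
    """Simplified version of train()'s selection, covering only the branches shown above."""
    smdataparallel_enabled = additional_framework_parameters.get(
        LAUNCH_SMDATAPARALLEL_ENV_NAME, False)
    pytorch_xla_enabled = additional_framework_parameters.get(
        LAUNCH_PYTORCH_XLA_ENV_NAME, False)

    if smdataparallel_enabled:
        return runner.SMDataParallelRunnerType
    if pytorch_xla_enabled:
        return runner.PyTorchXLARunnerType
    return runner.ProcessRunnerType  # default scenario


# Enabling the new hyperparameter selects the PT-XLA runner.
assert select_runner_type({LAUNCH_PYTORCH_XLA_ENV_NAME: True}) is runner.PyTorchXLARunnerType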
Lines changed: 6 additions & 0 deletions (new file)

@@ -0,0 +1,6 @@
+ARG region
+FROM 763104351884.dkr.ecr.$region.amazonaws.com/huggingface-pytorch-trcomp-training:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04
+
+COPY dist/sagemaker_pytorch_training-*.tar.gz /sagemaker_pytorch_training.tar.gz
+RUN pip install --upgrade --no-cache-dir /sagemaker_pytorch_training.tar.gz && \
+    rm /sagemaker_pytorch_training.tar.gz

test/unit/test_train.py

Lines changed: 17 additions & 1 deletion
@@ -23,7 +23,7 @@
 from mock import MagicMock, PropertyMock
 from mock import patch

-from sagemaker_pytorch_container.training import main, train, _dns_lookup, MASTER_PORT
+from sagemaker_pytorch_container.training import main, train, _dns_lookup, LAUNCH_PYTORCH_XLA_ENV_NAME, MASTER_PORT


 @pytest.fixture(name='training_env')
@@ -106,6 +106,22 @@ def test_train_pytorch_ddp(run_module, training_env):
     )


+@patch("sagemaker_training.entry_point.run")
+@patch('socket.gethostbyname', MagicMock())
+def test_train_pytorch_xla_distributed(run_module, training_env):
+    training_env.additional_framework_parameters[LAUNCH_PYTORCH_XLA_ENV_NAME] = True
+
+    train(training_env)
+    run_module.assert_called_with(
+        uri=training_env.module_dir,
+        user_entry_point=training_env.user_entry_point,
+        args=training_env.to_cmd_args(),
+        env_vars=training_env.to_env_vars(),
+        capture_error=True,
+        runner_type=runner.PyTorchXLARunnerType,
+    )
+
+
 @patch('sagemaker_training.entry_point.run', MagicMock())
 @patch('socket.gethostbyname', MagicMock())
 def test_environment(training_env):
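
As a possible follow-up (not part of this commit), the same pattern could be parametrized over the distributed flags exercised in this diff. The sketch below is hypothetical: it reuses the training_env fixture defined in this file, assumes the same mock and pytest imports plus runner from sagemaker_training, and imports the flag constants from sagemaker_pytorch_container.training so it stays self-contained.

# Hypothetical extension of the pattern above; not part of this commit.
import pytest
from mock import MagicMock, patch
from sagemaker_training import runner

from sagemaker_pytorch_container.training import (
    train,
    LAUNCH_SMDATAPARALLEL_ENV_NAME,
    LAUNCH_PYTORCH_XLA_ENV_NAME,
)


@pytest.mark.parametrize('flag, expected_runner_type', [
    (LAUNCH_SMDATAPARALLEL_ENV_NAME, runner.SMDataParallelRunnerType),
    (LAUNCH_PYTORCH_XLA_ENV_NAME, runner.PyTorchXLARunnerType),
])
@patch("sagemaker_training.entry_point.run")
@patch('socket.gethostbyname', MagicMock())
def test_train_distributed_runner_selection(run_module, training_env, flag, expected_runner_type):
    # Turn on exactly one distribution flag and check which runner train() requests.
    training_env.additional_framework_parameters[flag] = True

    train(training_env)

    _, kwargs = run_module.call_args
    assert kwargs['runner_type'] is expected_runner_type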
