Skip to content

Commit 13f9c45

Browse files
fix: provide option to use native process launcher (#244)
fix: add test cases
1 parent ad1724d commit 13f9c45

File tree

2 files changed: +22 -1 lines changed

2 files changed: +22 -1 lines changed

src/sagemaker_pytorch_container/training.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -81,12 +81,19 @@ def train(training_environment):
8181
runner_type = runner.PyTorchXLARunnerType
8282
logger.info('Invoking PT-XLA Runner')
8383
logger.info('Invoking user training script.')
84+
85+
# get capture_error from framework parameters
86+
capture_error = True
87+
if training_environment.additional_framework_parameters.get("sagemaker_toolkit_native_launcher_enabled"):
88+
capture_error = False
89+
logger.info(f'capture_error is {capture_error}. Default is True')
90+
8491
try:
8592
entry_point.run(uri=training_environment.module_dir,
8693
user_entry_point=training_environment.user_entry_point,
8794
args=training_environment.to_cmd_args(),
8895
env_vars=training_environment.to_env_vars(),
89-
capture_error=True,
96+
capture_error=capture_error,
9097
runner_type=runner_type)
9198
except errors.ExecuteUserScriptError as err:
9299
message = str(err)

test/unit/test_train.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -74,6 +74,20 @@ def test_train(run_entry_point, training_env):
7474
runner_type=runner.ProcessRunnerType)
7575

7676

77+
@patch('sagemaker_training.entry_point.run')
78+
@patch('socket.gethostbyname', MagicMock())
79+
def test_train_no_capture_error(run_entry_point, training_env):
80+
training_env.additional_framework_parameters["sagemaker_toolkit_native_launcher_enabled"] = True
81+
train(training_env)
82+
83+
run_entry_point.assert_called_with(uri=training_env.module_dir,
84+
user_entry_point=training_env.user_entry_point,
85+
args=training_env.to_cmd_args(),
86+
env_vars=training_env.to_env_vars(),
87+
capture_error=False,
88+
runner_type=runner.ProcessRunnerType)
89+
90+
7791
@patch("sagemaker_training.entry_point.run")
7892
@patch('socket.gethostbyname', MagicMock())
7993
def test_train_smdataparallel(run_module, training_env):

0 commit comments

Comments
 (0)