File tree Expand file tree Collapse file tree 2 files changed +22
-1
lines changed
src/sagemaker_pytorch_container Expand file tree Collapse file tree 2 files changed +22
-1
lines changed Original file line number Diff line number Diff line change @@ -81,12 +81,19 @@ def train(training_environment):
81
81
runner_type = runner .PyTorchXLARunnerType
82
82
logger .info ('Invoking PT-XLA Runner' )
83
83
logger .info ('Invoking user training script.' )
84
+
85
+ # get capture_error from framework parameters
86
+ capture_error = True
87
+ if training_environment .additional_framework_parameters .get ("sagemaker_toolkit_native_launcher_enabled" ):
88
+ capture_error = False
89
+ logger .info (f'capture_error is { capture_error } . Default is True' )
90
+
84
91
try :
85
92
entry_point .run (uri = training_environment .module_dir ,
86
93
user_entry_point = training_environment .user_entry_point ,
87
94
args = training_environment .to_cmd_args (),
88
95
env_vars = training_environment .to_env_vars (),
89
- capture_error = True ,
96
+ capture_error = capture_error ,
90
97
runner_type = runner_type )
91
98
except errors .ExecuteUserScriptError as err :
92
99
message = str (err )
Original file line number Diff line number Diff line change @@ -74,6 +74,20 @@ def test_train(run_entry_point, training_env):
74
74
runner_type = runner .ProcessRunnerType )
75
75
76
76
77
@patch('sagemaker_training.entry_point.run')
@patch('socket.gethostbyname', MagicMock())
def test_train_no_capture_error(run_entry_point, training_env):
    """Verify that the native-launcher flag disables error capture.

    With ``sagemaker_toolkit_native_launcher_enabled`` set to a truthy value
    in the additional framework parameters, ``train`` is expected to forward
    ``capture_error=False`` to the patched ``entry_point.run``.
    """
    framework_params = training_env.additional_framework_parameters
    framework_params["sagemaker_toolkit_native_launcher_enabled"] = True

    train(training_env)

    # Assemble the expected keyword arguments first, then assert in one call.
    expected_call = dict(
        uri=training_env.module_dir,
        user_entry_point=training_env.user_entry_point,
        args=training_env.to_cmd_args(),
        env_vars=training_env.to_env_vars(),
        capture_error=False,
        runner_type=runner.ProcessRunnerType,
    )
    run_entry_point.assert_called_with(**expected_call)
89
+
90
+
77
91
@patch ("sagemaker_training.entry_point.run" )
78
92
@patch ('socket.gethostbyname' , MagicMock ())
79
93
def test_train_smdataparallel (run_module , training_env ):
You can’t perform that action at this time.
0 commit comments