Skip to content

Commit c4075bb

Browse files
ranzvishreyapanditjeniyat
authored
fix: localmode subprocess parent process not sending SIGTERM to child (#2613)
Co-authored-by: Shreya Pandit <[email protected]> Co-authored-by: Jeniya Tabassum <[email protected]>
1 parent 9be4c8a commit c4075bb

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

src/sagemaker/local/image.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232

3333
from distutils.spawn import find_executable
3434
from threading import Thread
35-
3635
from six.moves.urllib.parse import urlparse
3736

3837
import sagemaker
@@ -841,6 +840,8 @@ def run(self):
841840

842841
def down(self):
843842
"""Placeholder docstring"""
843+
if os.name != "nt":
844+
sagemaker.local.utils.kill_child_processes(self.process.pid)
844845
self.process.terminate()
845846

846847

src/sagemaker/local/utils.py

+39
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import os
1717
import shutil
18+
import subprocess
1819

1920
from distutils.dir_util import copy_tree
2021
from six.moves.urllib.parse import urlparse
@@ -88,3 +89,41 @@ def recursive_copy(source, destination):
8889
"""
8990
if os.path.isdir(source):
9091
copy_tree(source, destination)
92+
93+
94+
def kill_child_processes(pid):
95+
"""Kill child processes
96+
97+
Kills all nested child process ids for a specific pid
98+
99+
Args:
100+
pid (int): process id
101+
"""
102+
child_pids = get_child_process_ids(pid)
103+
for child_pid in child_pids:
104+
os.kill(child_pid, 15)
105+
106+
107+
def get_child_process_ids(pid):
108+
"""Retrieve all child pids for a certain pid
109+
110+
Recursively scan each childs process tree and add it to the output
111+
112+
Args:
113+
pid (int): process id
114+
115+
Returns:
116+
(List[int]): Child process ids
117+
"""
118+
cmd = f"pgrep -P {pid}".split()
119+
output, err = subprocess.Popen(
120+
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
121+
).communicate()
122+
if err:
123+
return []
124+
pids = [int(pid) for pid in output.decode("utf-8").split()]
125+
if pids:
126+
for child_pid in pids:
127+
return pids + get_child_process_ids(child_pid)
128+
else:
129+
return []

0 commit comments

Comments
 (0)