Skip to content

Commit c3708bd

Browse files
committed
fix: localmode subprocess parent process not sending SIGTERM to child
1 parent 7b07b90 commit c3708bd

File tree

3 files changed

+24
-6
lines changed

3 files changed

+24
-6
lines changed

setup.py

-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ def read_version():
5252
"urllib3>=1.21.1,!=1.25,!=1.25.1",
5353
"docker-compose>=1.25.2",
5454
"PyYAML>=5.3, <6", # PyYAML version has to match docker-compose requirements
55-
"psutil",
5655
],
5756
"scipy": ["scipy>=0.19.0"],
5857
}

src/sagemaker/local/image.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,7 @@
3131
import tempfile
3232

3333
from distutils.spawn import find_executable
34-
from signal import SIGTERM
3534
from threading import Thread
36-
37-
import psutil
3835
from six.moves.urllib.parse import urlparse
3936

4037
import sagemaker
@@ -843,8 +840,8 @@ def run(self):
843840

844841
def down(self):
845842
"""Placeholder docstring"""
846-
for process in psutil.Process(self.process.pid).children():
847-
process.send_signal(SIGTERM)
843+
if os.name != 'nt':
844+
sagemaker.local.utils.kill_child_processes(self.process.pid)
848845
self.process.terminate()
849846

850847

src/sagemaker/local/utils.py

+22
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import os
1717
import shutil
18+
import subprocess
1819

1920
from distutils.dir_util import copy_tree
2021
from six.moves.urllib.parse import urlparse
@@ -88,3 +89,24 @@ def recursive_copy(source, destination):
8889
"""
8990
if os.path.isdir(source):
9091
copy_tree(source, destination)
92+
93+
94+
def kill_child_processes(pid):
95+
child_pids = get_child_process_ids(pid)
96+
for pid in child_pids:
97+
os.kill(pid, 15)
98+
99+
100+
def get_child_process_ids(pid):
101+
cmd = f"pgrep -P {pid}".split()
102+
output, err = subprocess.Popen(
103+
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
104+
).communicate()
105+
if err:
106+
return []
107+
pids = [int(pid) for pid in output.decode('utf-8').split()]
108+
if pids:
109+
for pid in pids:
110+
return pids + get_child_process_ids(pid)
111+
else:
112+
return []

0 commit comments

Comments
 (0)