fix: localmode subprocess parent process not sending SIGTERM to child #2613

Merged
merged 12 commits into from
Nov 1, 2021
3 changes: 2 additions & 1 deletion src/sagemaker/local/image.py
@@ -32,7 +32,6 @@

from distutils.spawn import find_executable
from threading import Thread

from six.moves.urllib.parse import urlparse

import sagemaker
@@ -841,6 +840,8 @@ def run(self):

    def down(self):
        """Placeholder docstring"""
        if os.name != "nt":

Contributor:
What is this check for?

Contributor Author:
It checks whether the OS is Windows, since I'm using Unix commands to kill the processes.

            sagemaker.local.utils.kill_child_processes(self.process.pid)
        self.process.terminate()


39 changes: 39 additions & 0 deletions src/sagemaker/local/utils.py
@@ -15,6 +15,7 @@

import os
import shutil
import subprocess

from distutils.dir_util import copy_tree
from six.moves.urllib.parse import urlparse
@@ -88,3 +89,41 @@ def recursive_copy(source, destination):
    """
    if os.path.isdir(source):
        copy_tree(source, destination)


def kill_child_processes(pid):
    """Kill child processes

    Kills all nested child process ids for a specific pid

    Args:
        pid (int): process id
    """
    child_pids = get_child_process_ids(pid)
    for child_pid in child_pids:
        os.kill(child_pid, 15)


def get_child_process_ids(pid):
    """Retrieve all child pids for a certain pid

    Recursively scan each childs process tree and add it to the output

    Args:
        pid (int): process id

    Returns:
        (List[int]): Child process ids
    """
    cmd = f"pgrep -P {pid}".split()
    output, err = subprocess.Popen(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    ).communicate()
    if err:
        return []
    pids = [int(pid) for pid in output.decode("utf-8").split()]
    if pids:
        for child_pid in pids:
            return pids + get_child_process_ids(child_pid)
    else:
        return []
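
Note on the recursion above: as written in the diff, the `return` inside the `for` loop exits on the first iteration, so only the first child's subtree is followed. The following is a minimal sketch of one way to visit every child, not the code merged in this PR. The names `list_direct_children`, `collect_descendants`, and `terminate_pids` are hypothetical, and the sketch assumes a Unix system where `pgrep` is on the PATH. It also uses `signal.SIGTERM` in place of the literal `15`.

```python
import os
import signal
import subprocess


def list_direct_children(pid):
    """Return the direct child pids of `pid` using `pgrep -P` (Unix only)."""
    result = subprocess.run(
        ["pgrep", "-P", str(pid)],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        check=False,
    )
    # pgrep exits non-zero when nothing matches; treat that as "no children".
    if result.returncode != 0:
        return []
    return [int(p) for p in result.stdout.decode("utf-8").split()]


def collect_descendants(pid, list_children=list_direct_children):
    """Walk the whole process tree below `pid`, visiting every child.

    `list_children` is injectable (an assumption of this sketch) so the
    traversal can be exercised without spawning real processes.
    """
    descendants = []
    for child_pid in list_children(pid):
        descendants.append(child_pid)
        # Recurse into each child instead of returning after the first one.
        descendants.extend(collect_descendants(child_pid, list_children))
    return descendants


def terminate_pids(pids):
    """Send SIGTERM to each pid, skipping processes that already exited."""
    signalled = []
    for pid in pids:
        try:
            os.kill(pid, signal.SIGTERM)
        except ProcessLookupError:
            # The process exited between listing and signalling.
            continue
        signalled.append(pid)
    return signalled
```

Under these assumptions, `terminate_pids(collect_descendants(parent_pid))` would play the role of `kill_child_processes(parent_pid)`; the parent itself is still stopped separately, as `down()` does with `self.process.terminate()`.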