diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index 1c73833a4f..f0a3ed8579 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -32,7 +32,6 @@ from distutils.spawn import find_executable from threading import Thread - from six.moves.urllib.parse import urlparse import sagemaker @@ -841,6 +840,8 @@ def run(self): def down(self): """Placeholder docstring""" + if os.name != "nt": + sagemaker.local.utils.kill_child_processes(self.process.pid) self.process.terminate() diff --git a/src/sagemaker/local/utils.py b/src/sagemaker/local/utils.py index d9455041ab..5a8ce03282 100644 --- a/src/sagemaker/local/utils.py +++ b/src/sagemaker/local/utils.py @@ -15,6 +15,7 @@ import os import shutil +import subprocess from distutils.dir_util import copy_tree from six.moves.urllib.parse import urlparse @@ -88,3 +89,41 @@ def recursive_copy(source, destination): """ if os.path.isdir(source): copy_tree(source, destination) + + +def kill_child_processes(pid): + """Kill child processes + + Kills all nested child process ids for a specific pid + + Args: + pid (int): process id + """ + child_pids = get_child_process_ids(pid) + for child_pid in child_pids: + os.kill(child_pid, 15) + + +def get_child_process_ids(pid): + """Retrieve all child pids for a certain pid + + Recursively scan each childs process tree and add it to the output + + Args: + pid (int): process id + + Returns: + (List[int]): Child process ids + """ + cmd = f"pgrep -P {pid}".split() + output, err = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ).communicate() + if err: + return [] + pids = [int(pid) for pid in output.decode("utf-8").split()] + if pids: + for child_pid in pids: + return pids + get_child_process_ids(child_pid) + else: + return []