Skip to content

Commit 1b0cc39

Browse files
fix: local mode - support relative file structure
1 parent 9d8020b commit 1b0cc39

File tree

2 files changed

+46
-6
lines changed

2 files changed

+46
-6
lines changed

src/sagemaker/local/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ def move_to_destination(source, destination, job_name, sagemaker_session):
6464
"""
6565
parsed_uri = urlparse(destination)
6666
if parsed_uri.scheme == "file":
67-
recursive_copy(source, parsed_uri.path)
67+
dir_path = os.path.abspath(parsed_uri.netloc + parsed_uri.path)
68+
recursive_copy(source, dir_path)
6869
final_uri = destination
6970
elif parsed_uri.scheme == "s3":
7071
bucket = parsed_uri.netloc
@@ -116,9 +117,8 @@ def get_child_process_ids(pid):
116117
(List[int]): Child process ids
117118
"""
118119
cmd = f"pgrep -P {pid}".split()
119-
output, err = subprocess.Popen(
120-
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
121-
).communicate()
120+
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
121+
output, err = process.communicate()
122122
if err:
123123
return []
124124
pids = [int(pid) for pid in output.decode("utf-8").split()]

tests/unit/test_local_utils.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,31 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import os
1516
import pytest
1617
from mock import patch, Mock
1718

1819
import sagemaker.local.utils
1920

2021

22+
@patch("sagemaker.local.utils.os.path")
23+
@patch("sagemaker.local.utils.os")
24+
def test_copy_directory_structure(m_os, m_os_path):
25+
m_os_path.exists.return_value = False
26+
sagemaker.local.utils.copy_directory_structure("/tmp/", "code/")
27+
m_os.makedirs.assert_called_with("/tmp/", "code/")
28+
29+
2130
@patch("shutil.rmtree", Mock())
2231
@patch("sagemaker.local.utils.recursive_copy")
2332
def test_move_to_destination_local(recursive_copy):
2433
# local files will just be recursively copied
25-
sagemaker.local.utils.move_to_destination("/tmp/data", "file:///target/dir/", "job", None)
26-
recursive_copy.assert_called_with("/tmp/data", "/target/dir/")
34+
# given absolute path
35+
sagemaker.local.utils.move_to_destination("/tmp/data", "file:///target/dir", "job", None)
36+
recursive_copy.assert_called_with("/tmp/data", "/target/dir")
37+
# given relative path
38+
sagemaker.local.utils.move_to_destination("/tmp/data", "file://root/target/dir", "job", None)
39+
recursive_copy.assert_called_with("/tmp/data", os.path.abspath("./root/target/dir"))
2740

2841

2942
@patch("shutil.rmtree", Mock())
@@ -52,3 +65,30 @@ def test_move_to_destination_s3(recursive_copy):
5265
def test_move_to_destination_illegal_destination():
5366
with pytest.raises(ValueError):
5467
sagemaker.local.utils.move_to_destination("/tmp/data", "ftp://ftp/in/2018", "job", None)
68+
69+
70+
@patch("sagemaker.local.utils.os.path")
71+
@patch("sagemaker.local.utils.copy_tree")
72+
def test_recursive_copy(copy_tree, m_os_path):
73+
m_os_path.isdir.return_value = True
74+
sagemaker.local.utils.recursive_copy("source", "destination")
75+
copy_tree.assert_called_with("source", "destination")
76+
77+
78+
@patch("sagemaker.local.utils.os")
79+
@patch("sagemaker.local.utils.get_child_process_ids")
80+
def test_kill_child_processes(m_get_child_process_ids, m_os):
81+
m_get_child_process_ids.return_value = ["child_pids"]
82+
sagemaker.local.utils.kill_child_processes("pid")
83+
m_os.kill.assert_called_with("child_pids", 15)
84+
85+
86+
@patch("sagemaker.local.utils.subprocess")
87+
def test_get_child_process_ids(m_subprocess):
88+
cmd = "pgrep -P pid".split()
89+
process_mock = Mock()
90+
attrs = {"communicate.return_value": (b"\n", False), "returncode": 0}
91+
process_mock.configure_mock(**attrs)
92+
m_subprocess.Popen.return_value = process_mock
93+
sagemaker.local.utils.get_child_process_ids("pid")
94+
m_subprocess.Popen.assert_called_with(cmd, stdout=m_subprocess.PIPE, stderr=m_subprocess.PIPE)

0 commit comments

Comments
 (0)