Skip to content

Commit ae2a2f8

Browse files
fix: local mode - support relative file structure (#2768)
1 parent 1ff8ae3 commit ae2a2f8

File tree

3 files changed

+52
-9
lines changed

3 files changed

+52
-9
lines changed

src/sagemaker/local/image.py

+6 -3
Original file line number | Diff line number | Diff line change
@@ -277,7 +277,8 @@ def serve(self, model_dir, environment):
277277
script_dir = environment[sagemaker.estimator.DIR_PARAM_NAME.upper()]
278278
parsed_uri = urlparse(script_dir)
279279
if parsed_uri.scheme == "file":
280-
volumes.append(_Volume(parsed_uri.path, "/opt/ml/code"))
280+
host_dir = os.path.abspath(parsed_uri.netloc + parsed_uri.path)
281+
volumes.append(_Volume(host_dir, "/opt/ml/code"))
281282
# Update path to mount location
282283
environment = environment.copy()
283284
environment[sagemaker.estimator.DIR_PARAM_NAME.upper()] = "/opt/ml/code"
@@ -495,7 +496,8 @@ def _prepare_training_volumes(
495496
training_dir = json.loads(hyperparameters[sagemaker.estimator.DIR_PARAM_NAME])
496497
parsed_uri = urlparse(training_dir)
497498
if parsed_uri.scheme == "file":
498-
volumes.append(_Volume(parsed_uri.path, "/opt/ml/code"))
499+
host_dir = os.path.abspath(parsed_uri.netloc + parsed_uri.path)
500+
volumes.append(_Volume(host_dir, "/opt/ml/code"))
499501
# Also mount a directory that all the containers can access.
500502
volumes.append(_Volume(shared_dir, "/opt/ml/shared"))
501503

@@ -504,7 +506,8 @@ def _prepare_training_volumes(
504506
parsed_uri.scheme == "file"
505507
and sagemaker.model.SAGEMAKER_OUTPUT_LOCATION in hyperparameters
506508
):
507-
intermediate_dir = os.path.join(parsed_uri.path, "output", "intermediate")
509+
dir_path = os.path.abspath(parsed_uri.netloc + parsed_uri.path)
510+
intermediate_dir = os.path.join(dir_path, "output", "intermediate")
508511
if not os.path.exists(intermediate_dir):
509512
os.makedirs(intermediate_dir)
510513
volumes.append(_Volume(intermediate_dir, "/opt/ml/output/intermediate"))

src/sagemaker/local/utils.py

+4 -4
Original file line number | Diff line number | Diff line change
@@ -64,7 +64,8 @@ def move_to_destination(source, destination, job_name, sagemaker_session):
6464
"""
6565
parsed_uri = urlparse(destination)
6666
if parsed_uri.scheme == "file":
67-
recursive_copy(source, parsed_uri.path)
67+
dir_path = os.path.abspath(parsed_uri.netloc + parsed_uri.path)
68+
recursive_copy(source, dir_path)
6869
final_uri = destination
6970
elif parsed_uri.scheme == "s3":
7071
bucket = parsed_uri.netloc
@@ -116,9 +117,8 @@ def get_child_process_ids(pid):
116117
(List[int]): Child process ids
117118
"""
118119
cmd = f"pgrep -P {pid}".split()
119-
output, err = subprocess.Popen(
120-
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
121-
).communicate()
120+
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
121+
output, err = process.communicate()
122122
if err:
123123
return []
124124
pids = [int(pid) for pid in output.decode("utf-8").split()]

tests/unit/test_local_utils.py

+42 -2
Original file line number | Diff line number | Diff line change
@@ -12,18 +12,31 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import os
1516
import pytest
1617
from mock import patch, Mock
1718

1819
import sagemaker.local.utils
1920

2021

22+
@patch("sagemaker.local.utils.os.path")
23+
@patch("sagemaker.local.utils.os")
24+
def test_copy_directory_structure(m_os, m_os_path):
25+
m_os_path.exists.return_value = False
26+
sagemaker.local.utils.copy_directory_structure("/tmp/", "code/")
27+
m_os.makedirs.assert_called_with("/tmp/", "code/")
28+
29+
2130
@patch("shutil.rmtree", Mock())
2231
@patch("sagemaker.local.utils.recursive_copy")
2332
def test_move_to_destination_local(recursive_copy):
2433
# local files will just be recursively copied
25-
sagemaker.local.utils.move_to_destination("/tmp/data", "file:///target/dir/", "job", None)
26-
recursive_copy.assert_called_with("/tmp/data", "/target/dir/")
34+
# given absolute path
35+
sagemaker.local.utils.move_to_destination("/tmp/data", "file:///target/dir", "job", None)
36+
recursive_copy.assert_called_with("/tmp/data", "/target/dir")
37+
# given relative path
38+
sagemaker.local.utils.move_to_destination("/tmp/data", "file://root/target/dir", "job", None)
39+
recursive_copy.assert_called_with("/tmp/data", os.path.abspath("./root/target/dir"))
2740

2841

2942
@patch("shutil.rmtree", Mock())
@@ -52,3 +65,30 @@ def test_move_to_destination_s3(recursive_copy):
5265
def test_move_to_destination_illegal_destination():
5366
with pytest.raises(ValueError):
5467
sagemaker.local.utils.move_to_destination("/tmp/data", "ftp://ftp/in/2018", "job", None)
68+
69+
70+
@patch("sagemaker.local.utils.os.path")
71+
@patch("sagemaker.local.utils.copy_tree")
72+
def test_recursive_copy(copy_tree, m_os_path):
73+
m_os_path.isdir.return_value = True
74+
sagemaker.local.utils.recursive_copy("source", "destination")
75+
copy_tree.assert_called_with("source", "destination")
76+
77+
78+
@patch("sagemaker.local.utils.os")
79+
@patch("sagemaker.local.utils.get_child_process_ids")
80+
def test_kill_child_processes(m_get_child_process_ids, m_os):
81+
m_get_child_process_ids.return_value = ["child_pids"]
82+
sagemaker.local.utils.kill_child_processes("pid")
83+
m_os.kill.assert_called_with("child_pids", 15)
84+
85+
86+
@patch("sagemaker.local.utils.subprocess")
87+
def test_get_child_process_ids(m_subprocess):
88+
cmd = "pgrep -P pid".split()
89+
process_mock = Mock()
90+
attrs = {"communicate.return_value": (b"\n", False), "returncode": 0}
91+
process_mock.configure_mock(**attrs)
92+
m_subprocess.Popen.return_value = process_mock
93+
sagemaker.local.utils.get_child_process_ids("pid")
94+
m_subprocess.Popen.assert_called_with(cmd, stdout=m_subprocess.PIPE, stderr=m_subprocess.PIPE)

0 commit comments

Comments
 (0)