Skip to content

Commit e3639e7

Browse files
bmouryakr (Mourya Baddam)
authored and
Namrata Madan
committed
Fix: Use prefix for env auto-capture (aws#909)
Co-authored-by: Mourya Baddam <[email protected]>
1 parent f50859f commit e3639e7

File tree

6 files changed

+47
-144
lines changed

6 files changed

+47
-144
lines changed

src/sagemaker/remote_function/job.py

+19-4
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
# training channel names
6262
RUNTIME_SCRIPTS_CHANNEL_NAME = "sagemaker_remote_function_bootstrap"
6363
REMOTE_FUNCTION_WORKSPACE = "sm_rf_user_ws"
64+
JOB_REMOTE_FUNCTION_WORKSPACE = "sagemaker_remote_function_workspace"
6465

6566
# run context dictionary keys
6667
KEY_EXPERIMENT_NAME = "experiment_name"
@@ -88,6 +89,16 @@
8889
printf "INFO: Bootstraping runtime environment.\\n"
8990
python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
9091
92+
if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
93+
then
94+
if [ -f "remote_function_conda_env.txt" ]
95+
then
96+
cp remote_function_conda_env.txt {JOB_REMOTE_FUNCTION_WORKSPACE}/remote_function_conda_env.txt
97+
fi
98+
printf "INFO: Changing workspace to {JOB_REMOTE_FUNCTION_WORKSPACE}.\\n"
99+
cd {JOB_REMOTE_FUNCTION_WORKSPACE}
100+
fi
101+
91102
if [ -f "remote_function_conda_env.txt" ]
92103
then
93104
conda_env=$(cat remote_function_conda_env.txt)
@@ -553,9 +564,11 @@ def _prepare_and_upload_dependencies(
553564
):
554565
return None
555566

556-
with _tmpdir() as tmp_workspace:
567+
with _tmpdir() as tmp_dir:
568+
tmp_workspace_dir = os.path.join(tmp_dir, "temp_workspace/")
569+
os.mkdir(tmp_workspace_dir)
557570
# TODO Remove the following hack to avoid dir_exists error in the copy_tree call below.
558-
tmp_workspace = os.path.join(tmp_workspace, "remote_function/")
571+
tmp_workspace = os.path.join(tmp_workspace_dir, JOB_REMOTE_FUNCTION_WORKSPACE)
559572

560573
if include_local_workdir:
561574
shutil.copytree(
@@ -593,8 +606,10 @@ def _prepare_and_upload_dependencies(
593606
pre_execution_script,
594607
)
595608

596-
workspace_archive_path = os.path.join(tmp_workspace, "workspace")
597-
workspace_archive_path = shutil.make_archive(workspace_archive_path, "zip", tmp_workspace)
609+
workspace_archive_path = os.path.join(tmp_dir, "workspace")
610+
workspace_archive_path = shutil.make_archive(
611+
workspace_archive_path, "zip", tmp_workspace_dir
612+
)
598613
logger.info("Successfully created workdir archive at '%s'", workspace_archive_path)
599614

600615
upload_path = S3Uploader.upload(

src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py

+2
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
BASE_CHANNEL_PATH = "/opt/ml/input/data"
3535
FAILURE_REASON_PATH = "/opt/ml/output/failure"
3636
PRE_EXECUTION_SCRIPT_NAME = "pre_exec.sh"
37+
JOB_REMOTE_FUNCTION_WORKSPACE = "sagemaker_remote_function_workspace"
3738

3839

3940
logger = get_logger()
@@ -94,6 +95,7 @@ def _bootstrap_runtime_environment(
9495
workspace_unpack_dir = pathlib.Path(os.getcwd()).absolute()
9596
shutil.unpack_archive(filename=workspace_archive_path, extract_dir=workspace_unpack_dir)
9697
logger.info("Successfully unpacked workspace archive at '%s'.", workspace_unpack_dir)
98+
workspace_unpack_dir = pathlib.Path(workspace_unpack_dir, JOB_REMOTE_FUNCTION_WORKSPACE)
9799

98100
# Handle pre-execution commands
99101
path_to_pre_exec_script = os.path.join(workspace_unpack_dir, PRE_EXECUTION_SCRIPT_NAME)

src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py

+11-46
Original file line numberDiff line numberDiff line change
@@ -95,28 +95,12 @@ def _capture_from_local_runtime(self) -> str:
9595

9696
# Try to capture dependencies from the conda environment, if any.
9797
conda_env_name = self._get_active_conda_env_name()
98-
logger.info("Found conda_env_name: '%s'", conda_env_name)
99-
conda_env_prefix = None
100-
101-
if conda_env_name is None:
102-
conda_env_prefix = self._get_active_conda_env_prefix()
98+
conda_env_prefix = self._get_active_conda_env_prefix()
99+
if conda_env_name:
100+
logger.info("Found conda_env_name: '%s'", conda_env_name)
101+
elif conda_env_prefix:
103102
logger.info("Found conda_env_prefix: '%s'", conda_env_prefix)
104-
if conda_env_prefix is None and self._get_studio_image_uri() is not None:
105-
logger.info(
106-
"Neither conda env name or prefix is set. Running Studio fallback logic"
107-
)
108-
# Fallback for Studio Notebooks since conda env is not activated to use as a
109-
# Jupyter kernel from images.
110-
# TODO: Remove after fixing the behavior for Studio Notebooks.
111-
which_python = self._get_which_python()
112-
prefix_candidate = which_python.replace("/bin/python", "")
113-
conda_env_list = self._get_conda_envs_list()
114-
if (
115-
conda_env_list.find(prefix_candidate + "\n") > 0
116-
): # need "\n" to match exact prefix; -1 for not found.
117-
conda_env_prefix = prefix_candidate
118-
119-
if conda_env_name is None and conda_env_prefix is None:
103+
else:
120104
raise ValueError("No conda environment seems to be active.")
121105

122106
if conda_env_name == "base":
@@ -126,25 +110,10 @@ def _capture_from_local_runtime(self) -> str:
126110
)
127111

128112
local_dependencies_path = os.path.join(os.getcwd(), "env_snapshot.yml")
129-
if conda_env_name is not None:
130-
self._export_conda_env_from_env_name(conda_env_name, local_dependencies_path)
131-
else:
132-
self._export_conda_env_from_prefix(conda_env_prefix, local_dependencies_path)
113+
self._export_conda_env_from_prefix(conda_env_prefix, local_dependencies_path)
133114

134115
return local_dependencies_path
135116

136-
def _get_conda_envs_list(self) -> str:
137-
"""Returns the registered list of conda environments on the system."""
138-
return _run_and_get_output_shell_cmd(f"{self._get_conda_exe()} env list")
139-
140-
def _get_which_python(self) -> str:
141-
"""Return the location of the current Python interpreter."""
142-
return _python_executable()
143-
144-
def _get_studio_image_uri(self) -> str:
145-
"""Returns the Sagemaker Image URI from the set environment variable. None otherwise."""
146-
return os.getenv("SAGEMAKER_INTERNAL_IMAGE_URI")
147-
148117
def _get_active_conda_env_prefix(self) -> str:
149118
"""Returns the conda prefix from the set environment variable. None otherwise."""
150119
return os.getenv("CONDA_PREFIX")
@@ -242,14 +211,6 @@ def _update_conda_env(self, env_name, local_path):
242211
_run_shell_cmd(cmd)
243212
logger.info("Conda env %s updated succesfully", env_name)
244213

245-
def _export_conda_env_from_env_name(self, env_name, local_path):
246-
"""Export the conda env to a conda yml file"""
247-
248-
cmd = f"{self._get_conda_exe()} env export -n {env_name} --no-builds > {local_path}"
249-
logger.info("Exporting conda environment: %s", cmd)
250-
_run_shell_cmd(cmd)
251-
logger.info("Conda environment %s exported successfully", env_name)
252-
253214
def _export_conda_env_from_prefix(self, prefix, local_path):
254215
"""Export the conda env to a conda yml file"""
255216

@@ -324,9 +285,13 @@ def _run_pre_execution_command_script(script_path: str):
324285
325286
Raises RuntimeEnvironmentError if the shell script fails
326287
"""
288+
current_dir = os.path.dirname(script_path)
327289

328290
process = subprocess.Popen(
329-
["/bin/bash", "-eu", script_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE
291+
["/bin/bash", "-eu", script_path],
292+
stdout=subprocess.PIPE,
293+
stderr=subprocess.PIPE,
294+
cwd=current_dir,
330295
)
331296

332297
_log_output(process)

tests/integ/sagemaker/remote_function/conftest.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727

2828
DOCKERFILE_TEMPLATE = (
2929
"FROM public.ecr.aws/docker/library/python:{py_version}-slim\n\n"
30-
"WORKDIR /opt/ml/remote_function/\n"
3130
"RUN apt-get update -y \
3231
&& apt-get install -y unzip curl\n\n"
3332
"RUN curl 'https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip' -o 'awscliv2.zip' \
@@ -40,7 +39,6 @@
4039

4140
DOCKERFILE_TEMPLATE_WITH_CONDA = (
4241
"FROM public.ecr.aws/docker/library/python:{py_version}-slim\n\n"
43-
"WORKDIR /opt/ml/remote_function/\n"
4442
'SHELL ["/bin/bash", "-c"]\n'
4543
"RUN apt-get update -y \
4644
&& apt-get install -y unzip curl\n\n"
@@ -66,7 +64,7 @@
6664
"dependencies:\n"
6765
" - scipy=1.7.3\n"
6866
" - pip:\n"
69-
" - /opt/ml/remote_function/sagemaker-{sagemaker_version}.tar.gz\n"
67+
" - /sagemaker-{sagemaker_version}.tar.gz\n"
7068
"prefix: /opt/conda/bin/conda\n"
7169
)
7270

tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
import pathlib
2222

2323
TEST_JOB_CONDA_ENV = "conda_env"
24-
TEST_DEPENDENCIES_PATH = "/user/set/workdir"
24+
CURR_WORKING_DIR = "/user/set/workdir"
25+
TEST_DEPENDENCIES_PATH = "/user/set/workdir/sagemaker_remote_function_workspace"
2526
TEST_PYTHON_VERSION = "3.10"
2627
TEST_WORKSPACE_ARCHIVE_DIR_PATH = "/opt/ml/input/data/sm_rf_user_ws"
2728
TEST_WORKSPACE_ARCHIVE_PATH = "/opt/ml/input/data/sm_rf_user_ws/workspace.zip"
@@ -44,7 +45,7 @@ def mock_args():
4445
)
4546
@patch("sys.exit")
4647
@patch("shutil.unpack_archive", Mock())
47-
@patch("os.getcwd", return_value=TEST_DEPENDENCIES_PATH)
48+
@patch("os.getcwd", return_value=CURR_WORKING_DIR)
4849
@patch("os.path.exists", return_value=True)
4950
@patch("os.path.isfile", return_value=True)
5051
@patch("os.listdir", return_value=["fileA.py", "fileB.sh", "requirements.txt"])
@@ -175,7 +176,7 @@ def test_main_no_workspace_archive(
175176
@patch("shutil.unpack_archive", Mock())
176177
@patch("os.path.exists", return_value=True)
177178
@patch("os.path.isfile", return_value=True)
178-
@patch("os.getcwd", return_value=TEST_DEPENDENCIES_PATH)
179+
@patch("os.getcwd", return_value=CURR_WORKING_DIR)
179180
@patch("os.listdir", return_value=["fileA.py", "fileB.sh"])
180181
@patch(
181182
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."

tests/unit/sagemaker/remote_function/runtime_environment/test_runtime_environment_manager.py

+10-88
Original file line numberDiff line numberDiff line change
@@ -90,17 +90,6 @@ def test__get_active_conda_env_prefix():
9090
assert call_arg == "CONDA_PREFIX"
9191

9292

93-
def test__get_studio_image_uri():
94-
with patch("os.getenv") as getenv_patch:
95-
getenv_patch.return_value = "some-sagemaker-studio-image-ecr-uri"
96-
97-
result = RuntimeEnvironmentManager()._get_studio_image_uri()
98-
99-
assert result == "some-sagemaker-studio-image-ecr-uri"
100-
call_arg = getenv_patch.call_args[0][0]
101-
assert call_arg == "SAGEMAKER_INTERNAL_IMAGE_URI"
102-
103-
10493
@patch(
10594
"sagemaker.remote_function.runtime_environment.runtime_environment_manager._log_error", Mock()
10695
)
@@ -117,7 +106,14 @@ def test__get_studio_image_uri():
117106
".RuntimeEnvironmentManager._get_active_conda_env_name",
118107
return_value="test_env",
119108
)
120-
def test_snapshot_from_active_conda_env_when_name_available(conda_default_env, stub_conda_exe):
109+
@patch(
110+
"sagemaker.remote_function.runtime_environment.runtime_environment_manager"
111+
".RuntimeEnvironmentManager._get_active_conda_env_prefix",
112+
return_value="/some/conda/env/prefix",
113+
)
114+
def test_snapshot_from_active_conda_env_when_name_available(
115+
conda_env_prefix, conda_default_env, stub_conda_exe
116+
):
121117
expected_result = os.path.join(os.getcwd(), "env_snapshot.yml")
122118
with patch("subprocess.Popen") as popen:
123119
popen.return_value.wait.return_value = 0
@@ -128,7 +124,7 @@ def test_snapshot_from_active_conda_env_when_name_available(conda_default_env, s
128124
call_args = popen.call_args[0][0]
129125
assert call_args is not None
130126
expected_cmd = (
131-
f"{stub_conda_exe.return_value} env export -n {conda_default_env.return_value} "
127+
f"{stub_conda_exe.return_value} env export -p {conda_env_prefix.return_value} "
132128
f"--no-builds > {expected_result}"
133129
)
134130
assert call_args == expected_cmd
@@ -189,81 +185,7 @@ def test_snapshot_from_active_conda_env_when_prefix_available(
189185
".RuntimeEnvironmentManager._get_active_conda_env_prefix",
190186
return_value=None,
191187
)
192-
@patch(
193-
"sagemaker.remote_function.runtime_environment.runtime_environment_manager"
194-
".RuntimeEnvironmentManager._get_studio_image_uri",
195-
return_value=None,
196-
)
197-
def test_snapshot_auto_capture_non_studio_no_active_conda_env(
198-
no_studio_image_uri, no_conda_env_prefix, no_conda_env_name
199-
):
200-
with pytest.raises(ValueError):
201-
RuntimeEnvironmentManager().snapshot("auto_capture")
202-
203-
204-
@patch(
205-
"sagemaker.remote_function.runtime_environment.runtime_environment_manager._log_error", Mock()
206-
)
207-
@patch(
208-
"sagemaker.remote_function.runtime_environment.runtime_environment_manager._log_output", Mock()
209-
)
210-
@patch(
211-
"sagemaker.remote_function.runtime_environment.runtime_environment_manager"
212-
".RuntimeEnvironmentManager._get_conda_exe",
213-
return_value="some-exe",
214-
)
215-
@patch(
216-
"sagemaker.remote_function.runtime_environment.runtime_environment_manager"
217-
".RuntimeEnvironmentManager._get_active_conda_env_name",
218-
return_value=None,
219-
)
220-
@patch(
221-
"sagemaker.remote_function.runtime_environment.runtime_environment_manager"
222-
".RuntimeEnvironmentManager._get_active_conda_env_prefix",
223-
return_value=None,
224-
)
225-
@patch(
226-
"sagemaker.remote_function.runtime_environment.runtime_environment_manager"
227-
".RuntimeEnvironmentManager._get_studio_image_uri",
228-
return_value="some-image-uri",
229-
)
230-
@patch(
231-
"sagemaker.remote_function.runtime_environment.runtime_environment_manager"
232-
".RuntimeEnvironmentManager._get_which_python",
233-
return_value="/some/prefix/bin/python",
234-
)
235-
@patch(
236-
"sagemaker.remote_function.runtime_environment.runtime_environment_manager"
237-
".RuntimeEnvironmentManager._get_conda_envs_list",
238-
return_value="something something /some/prefix\n something /someother/prefix\n\n",
239-
)
240-
def test_snapshot_auto_capture_in_studio_no_active_conda_env(
241-
ouptut_conda_env_list,
242-
output_which_python,
243-
some_studio_image_uri,
244-
no_conda_env_prefix,
245-
no_conda_env_name,
246-
stub_conda_exe,
247-
):
248-
expected_result = os.path.join(os.getcwd(), "env_snapshot.yml")
249-
with patch("subprocess.Popen") as popen:
250-
popen.return_value.wait.return_value = 0
251-
252-
result = RuntimeEnvironmentManager().snapshot("auto_capture")
253-
assert result == expected_result
254-
255-
call_args = popen.call_args[0][0]
256-
assert call_args is not None
257-
expected_cmd = f"{stub_conda_exe.return_value} env export -p /some/prefix --no-builds > {expected_result}"
258-
assert call_args == expected_cmd
259-
260-
261-
@patch(
262-
"sagemaker.remote_function.runtime_environment.runtime_environment_manager"
263-
".RuntimeEnvironmentManager._get_active_conda_env_name",
264-
return_value=None,
265-
)
266-
def test_snapshot_from_active_conda_env_error(conda_default_env):
188+
def test_snapshot_auto_capture_no_active_conda_env(no_conda_env_prefix, no_conda_env_name):
267189
with pytest.raises(ValueError):
268190
RuntimeEnvironmentManager().snapshot("auto_capture")
269191

0 commit comments

Comments
 (0)