Skip to content

Commit 32f01f6

Browse files
rohangujarathiRohan Gujarathi
authored and
Namrata Madan
committed
feature: support runtime customization with req.txt file (aws#814)
Co-authored-by: Rohan Gujarathi <[email protected]>
1 parent bff2f8c commit 32f01f6

File tree

9 files changed

+477
-22
lines changed

9 files changed

+477
-22
lines changed

src/sagemaker/remote_function/client.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def remote(
5656
image_uri: str = None,
5757
instance_count: int = 1,
5858
instance_type: str = None,
59+
job_conda_env: str = None,
5960
keep_alive_period_in_seconds: int = 0,
6061
max_retry_attempts: int = 1,
6162
max_runtime_in_seconds: int = 24 * 60 * 60,
@@ -74,11 +75,14 @@ def remote(
7475
7576
Args:
7677
_func (Optional): Python function to be executed on the SageMaker job runtime environment.
77-
dependencies (str): Path to dependencies file or a reserved keyword ``AUTO_DETECT``.
78+
dependencies (str): Path to dependencies file or a reserved keyword
79+
``from_active_conda_env``. Defaults to None.
7880
environment_variables (Dict): environment variables
7981
image_uri (str): Docker image URI on ECR.
8082
instance_count (int): Number of instances to use. Default is 1.
8183
instance_type (str): EC2 instance type.
84+
job_conda_env (str): Name of the conda environment to activate during execution of the job.
85+
Default is None.
8286
keep_alive_period_in_seconds (int): The duration of time in seconds to retain configured
8387
resources in a warm pool for subsequent training jobs. Default is 0.
8488
max_retry_attempts (int): Max number of times the job is retried on InternalServerFailure.
@@ -115,6 +119,7 @@ def wrapper(*args, **kwargs):
115119
image_uri=image_uri,
116120
instance_count=instance_count,
117121
instance_type=instance_type,
122+
job_conda_env=job_conda_env,
118123
keep_alive_period_in_seconds=keep_alive_period_in_seconds,
119124
max_retry_attempts=max_retry_attempts,
120125
max_runtime_in_seconds=max_runtime_in_seconds,
@@ -243,6 +248,7 @@ def __init__(
243248
image_uri: str = None,
244249
instance_count: int = 1,
245250
instance_type: str = None,
251+
job_conda_env: str = None,
246252
keep_alive_period_in_seconds: int = 0,
247253
max_parallel_job: int = 1,
248254
max_retry_attempts: int = 1,
@@ -261,13 +267,15 @@ def __init__(
261267
"""Initiates a ``RemoteExecutor`` instance.
262268
263269
Args:
264-
dependencies (str): Path to dependencies file or a reserved keyword ``AUTO_DETECT``.
265-
Defaults to None.
270+
dependencies (str): Path to dependencies file or a reserved keyword
271+
``from_active_conda_env``. Defaults to None.
266272
environment_variables (Dict): Environment variables passed to the underlying sagemaker
267273
job. Defaults to None
268274
image_uri (str): Docker image URI on ECR. Defaults to base Python image.
269275
instance_count (int): Number of instances to use. Defaults to 1.
270276
instance_type (str): EC2 instance type.
277+
job_conda_env (str): Name of the conda environment to activate during execution
278+
of the job. Default is None.
271279
keep_alive_period_in_seconds (int): The duration of time in seconds to retain configured
272280
resources in a warm pool for subsequent training jobs. Defaults to 0.
273281
max_parallel_job (int): Maximum number of jobs that run in parallel. Defaults to 1.
@@ -306,6 +314,7 @@ def __init__(
306314
image_uri=image_uri,
307315
instance_count=instance_count,
308316
instance_type=instance_type,
317+
job_conda_env=job_conda_env,
309318
keep_alive_period_in_seconds=keep_alive_period_in_seconds,
310319
max_retry_attempts=max_retry_attempts,
311320
max_runtime_in_seconds=max_runtime_in_seconds,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""SageMaker runtime environment module"""
14+
15+
from __future__ import absolute_import
16+
17+
import logging
18+
import sys
19+
import os
20+
import shlex
21+
import subprocess
22+
from sagemaker.s3 import s3_path_join, S3Uploader, S3Downloader
23+
from sagemaker.session import Session
24+
from sagemaker.remote_function.errors import RuntimeEnvironmentError
25+
26+
logging.basicConfig(level=logging.INFO)
27+
logger = logging.getLogger(__name__)
28+
29+
30+
class RuntimeEnvironmentManager:
    """Manages the runtime environment of a remote function job.

    ``snapshot`` uploads a user-provided dependencies file to S3, and
    ``bootstrap`` downloads such a file and installs the dependencies it lists.
    """

    def __init__(
        self, s3_base_uri: str = None, s3_kms_key: str = None, sagemaker_session: Session = None
    ):
        """Initializes a ``RuntimeEnvironmentManager`` instance.

        Args:
            s3_base_uri (str): Base S3 URI under which the dependencies file is uploaded.
            s3_kms_key (str): KMS key passed to the S3 upload and download calls.
            sagemaker_session (Session): SageMaker session passed to the S3 upload
                and download calls.
        """
        self.s3_base_uri = s3_base_uri
        self.s3_kms_key = s3_kms_key
        self.sagemaker_session = sagemaker_session

    def snapshot(self, dependencies: str = None):
        """Creates snapshot of the user's environment

        If a req.txt or conda.yml file is provided, this method uploads it to
        S3 to be used at job's side.
        If ``from_active_conda_env`` is set, this method is meant to take a snapshot
        of the user's active conda env and upload the yml file to S3 (not supported yet).

        Args:
            dependencies (str): Path to a ``.txt`` or ``.yml`` dependencies file, or the
                reserved keyword ``from_active_conda_env``. Defaults to None.

        Returns:
            S3 URI where the dependencies file is uploaded or None

        Raises:
            ValueError: If ``from_active_conda_env`` is requested, if the dependencies
                file does not exist, or if ``dependencies`` is neither a ``.txt`` nor
                a ``.yml`` path.
        """

        # No additional dependencies specified
        if dependencies is None:
            return None

        if dependencies == "from_active_conda_env":
            # TODO:
            # 1. verify if conda is active
            # 2. take snapshot of active conda env
            # 3. upload yml to S3
            raise ValueError("from_active_conda_env keyword is not supported yet")

        # Dependencies specified as either req.txt or conda_env.yml
        if dependencies.endswith((".txt", ".yml")):
            self._is_file_exists(dependencies)
            return S3Uploader.upload(
                dependencies,
                s3_path_join(self.s3_base_uri, "additional_dependencies"),
                self.s3_kms_key,
                self.sagemaker_session,
            )

        raise ValueError('Invalid dependencies provided: "{}"'.format(dependencies))

    def bootstrap(self, dependencies_s3_uri: str, job_conda_env: str = None):
        """Bootstraps the runtime environment by installing the additional dependencies if any.

        Only ``.txt`` requirements files without a ``job_conda_env`` are installed
        today; the conda code paths are placeholders that log a warning and return.

        Args:
            dependencies_s3_uri (str): S3 URI where dependencies file exists.
            job_conda_env (str): conda environment to be activated. Default is None.

        Returns: None
        """

        # Nothing to install when no dependencies file was uploaded.
        if not dependencies_s3_uri:
            return

        if dependencies_s3_uri.endswith(".txt"):
            if job_conda_env:
                # TODO:
                # 1. verify if conda exists in the image
                # 2. activate the given conda env
                # 3. update the conda env with req.txt file
                logger.warning(
                    "Installing dependencies into conda environment '%s' is not supported "
                    "yet. Skipping installation of '%s'.",
                    job_conda_env,
                    dependencies_s3_uri,
                )
                return

            local_path = os.getcwd()
            S3Downloader.download(
                dependencies_s3_uri, local_path, self.s3_kms_key, self.sagemaker_session
            )

            # The file is downloaded into the working directory under its S3 object name.
            local_dependencies_file = os.path.join(local_path, dependencies_s3_uri.split("/")[-1])
            self._install_requirements_txt(local_dependencies_file, _python_executable())

        elif dependencies_s3_uri.endswith(".yml"):
            # TODO: implement
            # 1. verify if conda exists in the image
            # 2. if job_conda_env: activate and update the conda env with yml
            # 3. if not, create and activate conda env from conda yml file
            logger.warning(
                "Conda yml dependencies are not supported yet. Skipping installation of '%s'.",
                dependencies_s3_uri,
            )
            return

    def _is_file_exists(self, dependencies):
        """Check whether the dependencies file exists at the given location.

        Args:
            dependencies (str): Local path of the dependencies file.

        Returns: True

        Raises:
            ValueError: If no regular file exists at ``dependencies``.
        """
        if not os.path.isfile(dependencies):
            raise ValueError('No dependencies file named "{}" was found.'.format(dependencies))
        return True

    def _install_requirements_txt(self, local_path, python_executable):
        """Install the packages listed in a requirements.txt file using pip.

        Args:
            local_path (str): Local path of the requirements file.
            python_executable (str): Path of the Python interpreter used to run pip.
        """
        # The command string is tokenized again with ``shlex.split`` by ``_run_shell_cmd``,
        # so both paths are quoted to survive spaces and other shell metacharacters.
        cmd = "{} -m pip install -r {}".format(
            shlex.quote(python_executable), shlex.quote(local_path)
        )
        logger.info("Running command %s", cmd)
        _run_shell_cmd(cmd)
        logger.info("Command %s ran successfully", cmd)

    def _create_conda_env(self):
        """Create conda env using conda yml file"""
        # TODO: implement
        pass  # pylint: disable=W0107

    def _activate_conda_env(self):
        """Activate conda environment"""
        # TODO: implement
        pass  # pylint: disable=W0107

    def _update_conda_env(self):
        """Update conda env using conda yml file"""
        # TODO: implement
        pass  # pylint: disable=W0107
147+
148+
149+
def _run_shell_cmd(cmd: str):
    """Run a shell command in a subprocess, logging its stdout and stderr.

    The command string is tokenized with ``shlex.split`` and executed without a shell.

    Args:
        cmd (str): Command line to execute.

    Raises:
        RuntimeEnvironmentError: If the command exits with a non-zero return code.
            The message carries the stderr lines collected by ``_log_error``.
    """
    # Local import keeps this change self-contained within the function.
    from concurrent.futures import ThreadPoolExecutor

    process = subprocess.Popen(shlex.split(cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    # Drain stderr on a worker thread while stdout is drained here. Reading the two
    # pipes one after the other can deadlock: a child that fills the stderr pipe buffer
    # blocks on write while this process is still blocked reading stdout.
    with ThreadPoolExecutor(max_workers=1) as pool:
        stderr_future = pool.submit(_log_error, process)
        _log_output(process)
        # ``result()`` re-raises any exception hit while reading stderr.
        error_logs = stderr_future.result()

    return_code = process.wait()
    if return_code:
        error_message = "Encountered error while installing dependencies. Reason: {}".format(
            error_logs
        )
        raise RuntimeEnvironmentError(error_message)
165+
166+
167+
def _log_output(process):
    """Log every stdout line of a Popen process at INFO level.

    Reads ``process.stdout`` line by line until EOF; the pipe is closed on exit.

    Args:
        process: Popen process whose stdout was opened as a binary pipe.
    """
    with process.stdout as stream:
        while True:
            raw_line = stream.readline()
            # An empty bytes object signals EOF.
            if raw_line == b"":
                break
            logger.info(raw_line.decode("UTF-8"))
172+
173+
174+
def _log_error(process):
    """Log the stderr of a Popen process and collect its error lines.

    Lines containing ``ERROR:`` are logged at ERROR level and collected; every other
    line is logged at WARNING level. The pipe is closed when reading finishes.

    Args:
        process: Popen process whose stderr was opened as a binary pipe.

    Returns:
        str: Concatenation of all stderr lines that contain ``ERROR:``.
    """
    # Same logger object as the module-level ``logger``. The original logged through
    # the root logger and the deprecated ``logging.warn`` alias.
    log = logging.getLogger(__name__)

    error_lines = []
    with process.stderr as pipe:
        for line in iter(pipe.readline, b""):
            error_str = str(line, "UTF-8")
            if "ERROR:" in error_str:
                log.error(error_str)
                error_lines.append(error_str)
            else:
                log.warning(error_str)

    return "".join(error_lines)
191+
192+
193+
def _python_executable():
    """Return the path of the running Python interpreter.

    Returns:
        (str): The value of ``sys.executable``.

    Raises:
        RuntimeEnvironmentError: If ``sys.executable`` is empty or None.
    """
    executable = sys.executable
    if executable:
        return executable
    raise RuntimeEnvironmentError("Failed to retrieve the path for the Python executable binary")

src/sagemaker/remote_function/job.py

+14
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from sagemaker.s3 import s3_path_join
2222
from sagemaker import vpc_utils
2323
from sagemaker.remote_function.core.stored_function import StoredFunction
24+
from sagemaker.remote_function.core.runtime_environment import RuntimeEnvironmentManager
2425

2526

2627
JOBS_CONTAINER_ENTRYPOINT = ["invoke-remote-function"]
@@ -42,6 +43,7 @@ def __init__(
4243
image_uri: str = None,
4344
instance_count: int = 1,
4445
instance_type: str = None,
46+
job_conda_env: str = None,
4547
keep_alive_period_in_seconds: int = 0,
4648
max_retry_attempts: int = 1,
4749
max_runtime_in_seconds: int = 24 * 60 * 60,
@@ -70,6 +72,7 @@ def __init__(
7072
self.max_retry_attempts = max_retry_attempts
7173
self.keep_alive_period_in_seconds = keep_alive_period_in_seconds
7274
self.source_dir = source_dir
75+
self.job_conda_env = job_conda_env
7376

7477
if role is not None:
7578
self.role = self.sagemaker_session.expand_role(role)
@@ -116,6 +119,13 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs):
116119

117120
s3_base_uri = s3_path_join(job_settings.s3_root_uri, job_name)
118121

122+
runtime_environment_manager = RuntimeEnvironmentManager(
123+
s3_base_uri=s3_base_uri,
124+
s3_kms_key=job_settings.s3_kms_key,
125+
sagemaker_session=job_settings.sagemaker_session,
126+
)
127+
uploaded_dependencies_path = runtime_environment_manager.snapshot(job_settings.dependencies)
128+
119129
stored_function = StoredFunction(
120130
sagemaker_session=job_settings.sagemaker_session,
121131
s3_base_uri=s3_base_uri,
@@ -143,6 +153,10 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs):
143153
container_args.extend(["--region", job_settings.sagemaker_session.boto_region_name])
144154
if job_settings.s3_kms_key:
145155
container_args.extend(["--s3_kms_key", job_settings.s3_kms_key])
156+
if uploaded_dependencies_path:
157+
container_args.extend(["--dependencies", uploaded_dependencies_path])
158+
if job_settings.job_conda_env:
159+
container_args.extend(["--job_conda_env", job_settings.job_conda_env])
146160

147161
algorithm_spec = dict(
148162
TrainingImage=job_settings.image_uri,

src/sagemaker/remote_function/job_driver.py

+24-9
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from sagemaker.session import Session
2121
from sagemaker.remote_function.errors import handle_error
22+
from sagemaker.remote_function.core.runtime_environment import RuntimeEnvironmentManager
2223

2324

2425
SUCCESS_EXIT_CODE = 0
@@ -30,6 +31,8 @@ def _parse_agrs():
3031
parser.add_argument("--region", type=str, required=True)
3132
parser.add_argument("--s3_base_uri", type=str, required=True)
3233
parser.add_argument("--s3_kms_key", type=str)
34+
parser.add_argument("--dependencies", type=str)
35+
parser.add_argument("--job_conda_env", type=str)
3336

3437
return parser.parse_args()
3538

@@ -46,23 +49,22 @@ def _execute_pre_exec_cmds():
4649
pass # pylint: disable=W0107
4750

4851

49-
def _install_required_dependencies():
50-
"""Install dependencies required by remote function invocation"""
51-
# TODO: complete me
52-
pass # pylint: disable=W0107
53-
54-
5552
def _uncompress_src_dir():
5653
"""Uncompress src directory for remote function invocation"""
5754
# TODO: complete me
5855
pass # pylint: disable=W0107
5956

6057

61-
def _bootstrap_runtime_environment():
58+
def _bootstrap_runtime_environment(
59+
runtime_manager: RuntimeEnvironmentManager,
60+
dependencies: str,
61+
job_conda_env: str = None,
62+
):
6263
"""Bootstrap runtime environment for remote function invocation"""
6364
_execute_pre_exec_cmds()
6465

65-
_install_required_dependencies()
66+
if dependencies:
67+
runtime_manager.bootstrap(dependencies_s3_uri=dependencies, job_conda_env=job_conda_env)
6668

6769
_uncompress_src_dir()
6870

@@ -83,10 +85,23 @@ def main():
8385
region = args.region
8486
s3_base_uri = args.s3_base_uri
8587
s3_kms_key = args.s3_kms_key
88+
dependencies = args.dependencies
89+
job_conda_env = args.job_conda_env
8690

8791
sagemaker_session = _get_sagemaker_session(region)
8892

89-
_bootstrap_runtime_environment()
93+
runtime_environment_manager = RuntimeEnvironmentManager(
94+
s3_base_uri=s3_base_uri,
95+
s3_kms_key=s3_kms_key,
96+
sagemaker_session=sagemaker_session,
97+
)
98+
99+
_bootstrap_runtime_environment(
100+
runtime_manager=runtime_environment_manager,
101+
dependencies=dependencies,
102+
job_conda_env=job_conda_env,
103+
)
104+
90105
_execute_remote_function(sagemaker_session, s3_base_uri, s3_kms_key)
91106

92107
except Exception as e: # pylint: disable=broad-except
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
does_not_exist
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
scipy==1.10.0

0 commit comments

Comments
 (0)