|
| 1 | +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). You |
| 4 | +# may not use this file except in compliance with the License. A copy of |
| 5 | +# the License is located at |
| 6 | +# |
| 7 | +# http://aws.amazon.com/apache2.0/ |
| 8 | +# |
| 9 | +# or in the "license" file accompanying this file. This file is |
| 10 | +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
| 11 | +# ANY KIND, either express or implied. See the License for the specific |
| 12 | +# language governing permissions and limitations under the License. |
| 13 | +"""SageMaker runtime environment module""" |
| 14 | + |
| 15 | +from __future__ import absolute_import |
| 16 | + |
| 17 | +import logging |
| 18 | +import sys |
| 19 | +import os |
| 20 | +import shlex |
| 21 | +import subprocess |
| 22 | +from sagemaker.s3 import s3_path_join, S3Uploader, S3Downloader |
| 23 | +from sagemaker.session import Session |
| 24 | +from sagemaker.remote_function.errors import RuntimeEnvironmentError |
| 25 | + |
| 26 | +logging.basicConfig(level=logging.INFO) |
| 27 | +logger = logging.getLogger(__name__) |
| 28 | + |
| 29 | + |
class RuntimeEnvironmentManager:
    """Manage the additional dependencies of a job's runtime environment.

    ``snapshot`` uploads a local dependencies file to S3, and ``bootstrap``
    downloads that file again and installs what it lists.
    """

    def __init__(
        self, s3_base_uri: str = None, s3_kms_key: str = None, sagemaker_session: Session = None
    ):
        """Store the S3 settings used by ``snapshot`` and ``bootstrap``.

        Args:
            s3_base_uri (str): Base S3 URI under which the dependencies file is uploaded.
            s3_kms_key (str): KMS key passed to the S3 upload and download calls.
            sagemaker_session (Session): Session passed to the S3 upload and download calls.
        """
        self.s3_base_uri = s3_base_uri
        self.s3_kms_key = s3_kms_key
        self.sagemaker_session = sagemaker_session

    def snapshot(self, dependencies: str = None):
        """Creates snapshot of the user's environment

        If a req.txt or conda.yml file is provided, this method uploads it to
        S3 (under ``<s3_base_uri>/additional_dependencies``) to be used at job's side.
        The ``from_active_conda_env`` keyword is recognised but not supported yet.

        Args:
            dependencies (str): Path to a local ``.txt`` or ``.yml`` dependencies file,
                or the keyword ``"from_active_conda_env"``. Default is None.

        Returns:
            S3 URI where the dependencies file is uploaded, as returned by
            ``S3Uploader.upload``, or None when no dependencies are given.

        Raises:
            ValueError: If ``from_active_conda_env`` is requested, if the file does not
                exist, or if ``dependencies`` is neither a ``.txt`` nor a ``.yml`` path.
        """
        # No additional dependencies specified
        if dependencies is None:
            return None

        if dependencies == "from_active_conda_env":
            # TODO:
            # 1. verify if conda is active
            # 2. take snapshot of active conda env
            # 3. upload yml to S3
            raise ValueError("from_active_conda_environment keyword is not supported yet")

        # Dependencies specified as either req.txt or conda_env.yml
        if dependencies.endswith((".txt", ".yml")):
            self._is_file_exists(dependencies)
            return S3Uploader.upload(
                dependencies,
                s3_path_join(self.s3_base_uri, "additional_dependencies"),
                self.s3_kms_key,
                self.sagemaker_session,
            )

        raise ValueError('Invalid dependencies provided: "{}"'.format(dependencies))

    def bootstrap(self, dependencies_s3_uri: str, job_conda_env: str = None):
        """Bootstraps the runtime environment by installing the additional dependencies if any.

        A ``.txt`` file is downloaded into the current working directory and installed
        with pip using the current Python executable. The conda code paths (``.yml``
        files, or a ``.txt`` file combined with ``job_conda_env``) are not implemented
        yet and return without doing anything.

        Args:
            dependencies_s3_uri (str): S3 URI where dependencies file exists. None or an
                empty string means there is nothing to install.
            job_conda_env (str): conda environment to be activated. Default is None.

        Returns: None
        """
        # snapshot() returns None when no dependencies were specified; treat that as a
        # no-op instead of failing on ``None.endswith``.
        if not dependencies_s3_uri:
            return

        if dependencies_s3_uri.endswith(".txt"):
            if job_conda_env:
                # TODO:
                # 1. verify if conda exists in the image
                # 2. activate the given conda env
                # 3. update the conda env with req.txt file
                return

            local_path = os.getcwd()
            S3Downloader.download(
                dependencies_s3_uri, local_path, self.s3_kms_key, self.sagemaker_session
            )

            # The downloaded file keeps the last path component of the S3 URI as its name.
            local_dependencies_file = os.path.join(local_path, dependencies_s3_uri.split("/")[-1])
            self._install_requirements_txt(local_dependencies_file, _python_executable())

        elif dependencies_s3_uri.endswith(".yml"):
            # TODO: implement
            # 1. verify is conda exists in the image
            # 2. if job_conda_env: activate and update the conda env with yml
            # 3. if not, create and activate conda env from conda yml file
            return

    def _is_file_exists(self, dependencies):
        """Check whether the dependencies file exists at the given location.

        Args:
            dependencies (str): Local path of the dependencies file.

        Returns: True

        Raises:
            ValueError: If no regular file exists at ``dependencies``.
        """
        if not os.path.isfile(dependencies):
            raise ValueError('No dependencies file named "{}" was found.'.format(dependencies))
        return True

    def _install_requirements_txt(self, local_path, python_executable):
        """Install a requirements.txt file with ``<python_executable> -m pip install -r``.

        Args:
            local_path (str): Local path of the requirements file.
            python_executable (str): Path of the Python interpreter used to run pip.
        """
        # The command string is later tokenized with shlex.split, so both paths are
        # quoted to keep a path containing spaces as a single argument.
        cmd = "{} -m pip install -r {}".format(
            shlex.quote(python_executable), shlex.quote(local_path)
        )
        logger.info("Running command %s", cmd)
        _run_shell_cmd(cmd)
        logger.info("Command %s ran successfully", cmd)

    def _create_conda_env(self):
        """Create conda env using conda yml file"""
        # TODO: implement
        pass  # pylint: disable=W0107

    def _activate_conda_env(self):
        """Activate conda environment"""
        # TODO: implement
        pass  # pylint: disable=W0107

    def _update_conda_env(self):
        """Update conda env using conda yml file"""
        # TODO: implement
        pass  # pylint: disable=W0107
| 147 | + |
| 148 | + |
def _run_shell_cmd(cmd: str):
    """Run a command using subprocess and log its stdout and stderr.

    The command string is tokenized with ``shlex.split`` and executed without a
    shell. Both output streams are piped and drained concurrently: stderr is read
    on a helper thread while stdout is read on the calling thread. Reading one pipe
    to EOF before touching the other can block forever once the child fills the
    unread pipe's buffer.

    Args:
        cmd (str): Command line to run.

    Raises:
        RuntimeEnvironmentError: If the command exits with a non-zero return code.
            The message carries the error lines returned by ``_log_error``.
    """
    import threading  # local import: only this function needs it

    process = subprocess.Popen(shlex.split(cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    collected_errors = []

    def _drain_stderr():
        collected_errors.append(_log_error(process))

    stderr_reader = threading.Thread(target=_drain_stderr, daemon=True)
    stderr_reader.start()

    _log_output(process)
    stderr_reader.join()

    return_code = process.wait()
    if return_code:
        error_message = "Encountered error while installing dependencies. Reason: {}".format(
            "".join(collected_errors)
        )
        raise RuntimeEnvironmentError(error_message)
| 165 | + |
| 166 | + |
def _log_output(process):
    """Log every line the process writes to stdout at INFO level.

    Reads ``process.stdout`` line by line until EOF; the pipe is closed when the
    ``with`` block exits. Each line is logged as read, including its line ending.

    Args:
        process: Object exposing a binary ``stdout`` pipe, e.g. a ``subprocess.Popen``
            created with ``stdout=subprocess.PIPE``.
    """
    with process.stdout as pipe:
        for line in iter(pipe.readline, b""):
            # Undecodable bytes are replaced instead of raising, so unexpected output
            # encoding cannot abort the running command.
            logger.info(str(line, "UTF-8", errors="replace"))
| 172 | + |
| 173 | + |
def _log_error(process):
    """Log every line the process writes to stderr and collect the error lines.

    Lines containing ``"ERROR:"`` are logged at ERROR level and collected; all other
    lines are logged at WARNING level. Logging goes through the module ``logger``,
    like ``_log_output``. ``logging.warn`` is avoided because it is a deprecated
    alias that newer Python versions no longer provide.

    Args:
        process: Object exposing a binary ``stderr`` pipe, e.g. a ``subprocess.Popen``
            created with ``stderr=subprocess.PIPE``.

    Returns:
        str: The collected ``"ERROR:"`` lines concatenated in order (line endings
        kept), or an empty string when there were none.
    """
    error_lines = []
    with process.stderr as pipe:
        for line in iter(pipe.readline, b""):
            # Undecodable bytes are replaced instead of raising mid-stream.
            error_str = str(line, "UTF-8", errors="replace")
            if "ERROR:" in error_str:
                logger.error(error_str)
                error_lines.append(error_str)
            else:
                logger.warning(error_str)

    return "".join(error_lines)
| 191 | + |
| 192 | + |
def _python_executable():
    """Return the path of the running Python interpreter.

    The value is taken from ``sys.executable`` and returned unchanged.

    Returns:
        (str): The path of the current Python executable.

    Raises:
        RuntimeEnvironmentError: If ``sys.executable`` is empty or None.
    """
    interpreter_path = sys.executable
    if interpreter_path:
        return interpreter_path

    raise RuntimeEnvironmentError("Failed to retrieve the path for the Python executable binary")
0 commit comments