Skip to content

Commit a166b3c

Browse files
beniericpintaoz-aws
authored andcommitted
Add enviornment variable bootstrapping script (#1530)
* Add enviornment variables scripts * format * fix comment * add docstrings * fix comment
1 parent 4841929 commit a166b3c

File tree

2 files changed

+236
-0
lines changed

2 files changed

+236
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module is used to define the environment variables for the training job container."""
14+
15+
from __future__ import absolute_import
16+
17+
from typing import Dict, Any
18+
import multiprocessing
19+
import subprocess
20+
import json
21+
import os
22+
import sys
23+
import logging
24+
25+
# Initialize logger
26+
SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20)
27+
logger = logging.getLogger(__name__)
28+
console_handler = logging.StreamHandler(sys.stdout)
29+
logger.addHandler(console_handler)
30+
logger.setLevel(SM_LOG_LEVEL)
31+
32+
SM_MODEL_DIR = "/opt/ml/model"
33+
34+
SM_INPUT_DIR = "/opt/ml/input"
35+
SM_INPUT_DATA_DIR = "/opt/ml/input/data"
36+
SM_INPUT_CONFIG_DIR = "/opt/ml/input/config"
37+
38+
SM_OUTPUT_DIR = "/opt/ml/output"
39+
SM_OUTPUT_FAILURE = "/opt/ml/output/failure"
40+
SM_OUTPUT_DATA_DIR = "/opt/ml/output/data"
41+
42+
SM_MASTER_ADDR = "algo-1"
43+
SM_MASTER_PORT = 7777
44+
45+
RESOURCE_CONFIG = f"{SM_INPUT_CONFIG_DIR}/resourceconfig.json"
46+
INPUT_DATA_CONFIG = f"{SM_INPUT_CONFIG_DIR}/inputdataconfig.json"
47+
HYPERPARAMETERS_CONFIG = f"{SM_INPUT_CONFIG_DIR}/hyperparameters.json"
48+
49+
ENV_OUTPUT_FILE = "sm_training.env"
50+
51+
52+
def num_cpus():
53+
"""Return the number of CPUs available in the current container.
54+
55+
Returns:
56+
int: Number of CPUs available in the current container.
57+
"""
58+
return multiprocessing.cpu_count()
59+
60+
61+
def num_gpus():
62+
"""Return the number of GPUs available in the current container.
63+
64+
Returns:
65+
int: Number of GPUs available in the current container.
66+
"""
67+
try:
68+
cmd = ["nvidia-smi", "--list-gpus"]
69+
output = subprocess.check_output(cmd).decode("utf-8")
70+
return sum(1 for line in output.splitlines() if line.startswith("GPU "))
71+
except (OSError, subprocess.CalledProcessError):
72+
logger.info("No GPUs detected (normal if no gpus installed)")
73+
return 0
74+
75+
76+
def num_neurons():
77+
"""Return the number of neuron cores available in the current container.
78+
79+
Returns:
80+
int: Number of Neuron Cores available in the current container.
81+
"""
82+
try:
83+
cmd = ["neuron-ls", "-j"]
84+
output = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8")
85+
j = json.loads(output)
86+
neuron_cores = 0
87+
for item in j:
88+
neuron_cores += item.get("nc_count", 0)
89+
logger.info(f"Found {neuron_cores} neurons on this instance")
90+
return neuron_cores
91+
except OSError:
92+
logger.info("No Neurons detected (normal if no neurons installed)")
93+
return 0
94+
except subprocess.CalledProcessError as e:
95+
if e.output is not None:
96+
try:
97+
msg = e.output.decode("utf-8").partition("error=")[2]
98+
logger.info(
99+
"No Neurons detected (normal if no neurons installed). \
100+
If neuron installed then {}".format(
101+
msg
102+
)
103+
)
104+
except AttributeError:
105+
logger.info("No Neurons detected (normal if no neurons installed)")
106+
else:
107+
logger.info("No Neurons detected (normal if no neurons installed)")
108+
109+
return 0
110+
111+
112+
def set_env(
113+
resource_config: Dict[str, Any] = {},
114+
input_data_config: Dict[str, Any] = {},
115+
hyperparameters_config: Dict[str, Any] = {},
116+
output_file: str = "sm_training.env",
117+
write_to_etc: bool = False,
118+
):
119+
"""Set environment variables for the training job container.
120+
121+
Args:
122+
resource_config (Dict[str, Any]): Resource configuration for the training job.
123+
input_data_config (Dict[str, Any]): Input data configuration for the training job.
124+
hyperparameters_config (Dict[str, Any]): Hyperparameters configuration for the training job.
125+
output_file (str): Output file to write the environment variables.
126+
write_to_etc (bool): Whether to write the environment variables to /etc/environment.
127+
"""
128+
# Constants
129+
env_vars = {
130+
"SM_MODEL_DIR": SM_MODEL_DIR,
131+
"SM_INPUT_DIR": SM_INPUT_DIR,
132+
"SM_INPUT_DATA_DIR": SM_INPUT_DATA_DIR,
133+
"SM_INPUT_CONFIG_DIR": SM_INPUT_CONFIG_DIR,
134+
"SM_OUTPUT_DIR": SM_OUTPUT_DIR,
135+
"SM_OUTPUT_FAILURE": SM_OUTPUT_FAILURE,
136+
"SM_OUTPUT_DATA_DIR": SM_OUTPUT_DATA_DIR,
137+
"SM_LOG_LEVEL": SM_LOG_LEVEL,
138+
"SM_MASTER_ADDR": SM_MASTER_ADDR,
139+
"SM_MASTER_PORT": SM_MASTER_PORT,
140+
}
141+
142+
# Data Channels
143+
channels = list(input_data_config.keys())
144+
for channel in channels:
145+
env_vars[f"SM_CHANNEL_{channel.upper()}"] = f"{SM_INPUT_DATA_DIR}/{channel}"
146+
env_vars["SM_CHANNELS"] = channels
147+
148+
# Hyperparameters
149+
env_vars["SM_HPS"] = hyperparameters_config
150+
for key, value in hyperparameters_config.items():
151+
env_vars[f"SM_HP_{key.upper()}"] = value
152+
153+
# Host Variables
154+
current_host = resource_config["current_host"]
155+
hosts = resource_config["hosts"]
156+
sorted_hosts = sorted(hosts)
157+
158+
env_vars["SM_CURRENT_HOST"] = current_host
159+
env_vars["SM_HOSTS"] = sorted_hosts
160+
env_vars["SM_NETWORK_INTERFACE_NAME"] = resource_config["network_interface_name"]
161+
env_vars["SM_HOST_COUNT"] = len(sorted_hosts)
162+
env_vars["SM_CURRENT_HOST_RANK"] = sorted_hosts.index(current_host)
163+
164+
env_vars["SM_NUM_CPUS"] = num_cpus()
165+
env_vars["SM_NUM_GPUS"] = num_gpus()
166+
env_vars["SM_NUM_NEURONS"] = num_neurons()
167+
168+
# Misc.
169+
env_vars["SM_RESOURCE_CONFIG"] = resource_config
170+
env_vars["SM_INPUT_DATA_CONFIG"] = input_data_config
171+
172+
# All Training Environment Variables
173+
env_vars["SM_TRAINING_ENV"] = {
174+
"channel_input_dirs": {
175+
channel: env_vars[f"SM_CHANNEL_{channel.upper()}"] for channel in channels
176+
},
177+
"current_host": env_vars["SM_CURRENT_HOST"],
178+
"hosts": env_vars["SM_HOSTS"],
179+
"master_addr": env_vars["SM_MASTER_ADDR"],
180+
"master_port": env_vars["SM_MASTER_PORT"],
181+
"hyperparameters": env_vars["SM_HPS"],
182+
"input_data_config": input_data_config,
183+
"input_config_dir": env_vars["SM_INPUT_CONFIG_DIR"],
184+
"input_data_dir": env_vars["SM_INPUT_DATA_DIR"],
185+
"input_dir": env_vars["SM_INPUT_DIR"],
186+
"job_name": os.environ["TRAINING_JOB_NAME"],
187+
"log_level": env_vars["SM_LOG_LEVEL"],
188+
"model_dir": env_vars["SM_MODEL_DIR"],
189+
"network_interface_name": env_vars["SM_NETWORK_INTERFACE_NAME"],
190+
"num_cpus": env_vars["SM_NUM_CPUS"],
191+
"num_gpus": env_vars["SM_NUM_GPUS"],
192+
"num_neurons": env_vars["SM_NUM_NEURONS"],
193+
"output_data_dir": env_vars["SM_OUTPUT_DATA_DIR"],
194+
"resource_config": env_vars["SM_RESOURCE_CONFIG"],
195+
}
196+
try:
197+
cur_dir = os.path.dirname(os.path.abspath(__file__))
198+
except NameError:
199+
# Fallback to current working directory
200+
cur_dir = os.getcwd()
201+
with open(os.path.join(cur_dir, output_file), "w") as f:
202+
for key, value in env_vars.items():
203+
if isinstance(value, (list, dict)):
204+
f.write(f"export {key}='{json.dumps(value)}'\n")
205+
else:
206+
f.write(f"export {key}='{value}'\n")
207+
208+
# Need to write to /etc/environment for MPI to work
209+
if write_to_etc:
210+
with open("/etc/environment", "a") as f:
211+
for key, value in env_vars.items():
212+
if isinstance(value, (list, dict)):
213+
f.write(f"{key}='{json.dumps(value)}'\n")
214+
else:
215+
f.write(f"{key}='{value}'\n")
216+
217+
218+
if __name__ == "__main__":
219+
with open(RESOURCE_CONFIG, "r") as f:
220+
resource_config = json.load(f)
221+
with open(INPUT_DATA_CONFIG, "r") as f:
222+
input_data_config = json.load(f)
223+
with open(HYPERPARAMETERS_CONFIG, "r") as f:
224+
hyperparameters_config = json.load(f)
225+
226+
set_env(
227+
resource_config=resource_config,
228+
input_data_config=input_data_config,
229+
hyperparameters_config=hyperparameters_config,
230+
output_file=ENV_OUTPUT_FILE,
231+
write_to_etc=True,
232+
)

src/sagemaker/modules/templates.py

+4
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929
cat /opt/ml/input/config/hyperparameters.json
3030
echo
3131
32+
echo "Setting up environment variables"
33+
python /opt/ml/input/data/sm_code/environment.py
34+
source /opt/ml/input/data/sm_code/sm_training.env
35+
3236
python --version
3337
{working_dir}
3438
{install_requirements}

0 commit comments

Comments
 (0)