Skip to content

Commit 30b4ce2

Browse files
feature: support passing Env Vars to local mode training (aws#3015)
1 parent 7c667c6 commit 30b4ce2

File tree

7 files changed

+85
-13
lines changed

7 files changed

+85
-13
lines changed

.githooks/pre-push

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@ start_time=`date +%s`
1212
tox -e sphinx,doc8 --parallel all
1313
./ci-scripts/displaytime.sh 'sphinx,doc8' $start_time
1414
start_time=`date +%s`
15-
tox -e py36,py37,py38 --parallel all -- tests/unit
16-
./ci-scripts/displaytime.sh 'py36,py37,py38 unit' $start_time
15+
tox -e py36,py37,py38,py39 --parallel all -- tests/unit
16+
./ci-scripts/displaytime.sh 'py36,py37,py38,py39 unit' $start_time

src/sagemaker/local/entities.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -175,22 +175,37 @@ def describe(self):
175175

176176

177177
class _LocalTrainingJob(object):
178-
"""Placeholder docstring"""
178+
"""Defines and starts a local training job."""
179179

180180
_STARTING = "Starting"
181181
_TRAINING = "Training"
182182
_COMPLETED = "Completed"
183183
_states = ["Starting", "Training", "Completed"]
184184

185185
def __init__(self, container):
186+
"""Creates a local training job.
187+
188+
Args:
189+
container: the local container object.
190+
"""
186191
self.container = container
187192
self.model_artifacts = None
188193
self.state = "created"
189194
self.start_time = None
190195
self.end_time = None
196+
self.environment = None
197+
198+
def start(self, input_data_config, output_data_config, hyperparameters, environment, job_name):
199+
"""Starts a local training job.
191200
192-
def start(self, input_data_config, output_data_config, hyperparameters, job_name):
193-
"""Placeholder docstring."""
201+
Args:
202+
input_data_config (dict): The Input Data Configuration, this contains data such as the
203+
channels to be used for training.
204+
output_data_config (dict): The configuration of the output data.
205+
hyperparameters (dict): The HyperParameters for the training job.
206+
environment (dict): The collection of environment variables passed to the job.
207+
job_name (str): Name of the local training job being run.
208+
"""
194209
for channel in input_data_config:
195210
if channel["DataSource"] and "S3DataSource" in channel["DataSource"]:
196211
data_distribution = channel["DataSource"]["S3DataSource"]["S3DataDistributionType"]
@@ -216,9 +231,10 @@ def start(self, input_data_config, output_data_config, hyperparameters, job_name
216231

217232
self.start_time = datetime.datetime.now()
218233
self.state = self._TRAINING
234+
self.environment = environment
219235

220236
self.model_artifacts = self.container.train(
221-
input_data_config, output_data_config, hyperparameters, job_name
237+
input_data_config, output_data_config, hyperparameters, environment, job_name
222238
)
223239
self.end_time = datetime.datetime.now()
224240
self.state = self._COMPLETED

src/sagemaker/local/image.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -176,14 +176,15 @@ def process(
176176
# you see this line at the end.
177177
print("===== Job Complete =====")
178178

179-
def train(self, input_data_config, output_data_config, hyperparameters, job_name):
179+
def train(self, input_data_config, output_data_config, hyperparameters, environment, job_name):
180180
"""Run a training job locally using docker-compose.
181181
182182
Args:
183183
input_data_config (dict): The Input Data Configuration, this contains data such as the
184184
channels to be used for training.
185185
output_data_config: The configuration of the output data.
186186
hyperparameters (dict): The HyperParameters for the training job.
187+
environment (dict): The environment variables to pass to the training container.
187188
job_name (str): Name of the local training job being run.
188189
189190
Returns (str): Location of the trained model.
@@ -217,6 +218,7 @@ def train(self, input_data_config, output_data_config, hyperparameters, job_name
217218
REGION_ENV_NAME: self.sagemaker_session.boto_region_name,
218219
TRAINING_JOB_NAME_ENV_NAME: job_name,
219220
}
221+
training_env_vars.update(environment)
220222
if self.sagemaker_session.s3_resource is not None:
221223
training_env_vars[
222224
S3_ENDPOINT_URL_ENV_NAME

src/sagemaker/local/local_session.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def create_training_job(
155155
OutputDataConfig,
156156
ResourceConfig,
157157
InputDataConfig=None,
158+
Environment=None,
158159
**kwargs
159160
):
160161
"""Create a training job in Local Mode.
@@ -167,6 +168,8 @@ def create_training_job(
167168
OutputDataConfig(dict): Identifies the location where you want to save the results of
168169
model training.
169170
ResourceConfig(dict): Identifies the resources to use for local model training.
171+
Environment(dict, optional): Describes the environment variables to pass
172+
to the container. (Default value = None)
170173
HyperParameters(dict) [optional]: Specifies these algorithm-specific parameters to
171174
influence the quality of the final model.
172175
**kwargs:
@@ -175,6 +178,7 @@ def create_training_job(
175178
176179
"""
177180
InputDataConfig = InputDataConfig or {}
181+
Environment = Environment or {}
178182
container = _SageMakerContainer(
179183
ResourceConfig["InstanceType"],
180184
ResourceConfig["InstanceCount"],
@@ -184,7 +188,9 @@ def create_training_job(
184188
training_job = _LocalTrainingJob(container)
185189
hyperparameters = kwargs["HyperParameters"] if "HyperParameters" in kwargs else {}
186190
logger.info("Starting training job")
187-
training_job.start(InputDataConfig, OutputDataConfig, hyperparameters, TrainingJobName)
191+
training_job.start(
192+
InputDataConfig, OutputDataConfig, hyperparameters, Environment, TrainingJobName
193+
)
188194

189195
LocalSagemakerClient._training_jobs[TrainingJobName] = training_job
190196

tests/data/mxnet_mnist/check_env.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License"). You
2+
# may not use this file except in compliance with the License. A copy of
3+
# the License is located at
4+
#
5+
# http://aws.amazon.com/apache2.0/
6+
#
7+
# or in the "license" file accompanying this file. This file is
8+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
9+
# ANY KIND, either express or implied. See the License for the specific
10+
# language governing permissions and limitations under the License.
11+
from __future__ import absolute_import
12+
import os
13+
14+
15+
if __name__ == "__main__":
16+
assert os.environ["MYVAR"] == "HELLO_WORLD"

tests/integ/test_local_mode.py

+22
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,28 @@ def test_mxnet_local_data_local_script(
247247
predictor.delete_endpoint()
248248

249249

250+
@pytest.mark.local_mode
251+
def test_mxnet_local_training_env(mxnet_training_latest_version, mxnet_training_latest_py_version):
252+
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
253+
script_path = os.path.join(data_path, "check_env.py")
254+
255+
mx = MXNet(
256+
entry_point=script_path,
257+
role="SageMakerRole",
258+
instance_count=1,
259+
instance_type="local",
260+
framework_version=mxnet_training_latest_version,
261+
py_version=mxnet_training_latest_py_version,
262+
sagemaker_session=LocalNoS3Session(),
263+
environment={"MYVAR": "HELLO_WORLD"},
264+
)
265+
266+
train_input = "file://" + os.path.join(data_path, "train")
267+
test_input = "file://" + os.path.join(data_path, "test")
268+
269+
mx.fit({"train": train_input, "test": test_input})
270+
271+
250272
@pytest.mark.local_mode
251273
def test_mxnet_training_failure(
252274
sagemaker_local_session, mxnet_training_latest_version, mxnet_training_latest_py_version, tmpdir

tests/unit/test_image.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@
7474
"sagemaker_submit_directory": json.dumps("file:///tmp/code"),
7575
}
7676

77+
ENVIRONMENT = {"MYVAR": "HELLO_WORLD"}
78+
7779

7880
@pytest.fixture()
7981
def sagemaker_session():
@@ -352,7 +354,7 @@ def test_train(
352354
"local", instance_count, image, sagemaker_session=sagemaker_session
353355
)
354356
sagemaker_container.train(
355-
INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, HYPERPARAMETERS, TRAINING_JOB_NAME
357+
INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, HYPERPARAMETERS, ENVIRONMENT, TRAINING_JOB_NAME
356358
)
357359

358360
docker_compose_file = os.path.join(
@@ -415,7 +417,7 @@ def test_train_with_hyperparameters_without_job_name(
415417
"local", instance_count, image, sagemaker_session=sagemaker_session
416418
)
417419
sagemaker_container.train(
418-
INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, HYPERPARAMETERS, TRAINING_JOB_NAME
420+
INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, HYPERPARAMETERS, ENVIRONMENT, TRAINING_JOB_NAME
419421
)
420422

421423
docker_compose_file = os.path.join(
@@ -456,7 +458,11 @@ def test_train_error(
456458

457459
with pytest.raises(RuntimeError) as e:
458460
sagemaker_container.train(
459-
INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, HYPERPARAMETERS, TRAINING_JOB_NAME
461+
INPUT_DATA_CONFIG,
462+
OUTPUT_DATA_CONFIG,
463+
HYPERPARAMETERS,
464+
ENVIRONMENT,
465+
TRAINING_JOB_NAME,
460466
)
461467

462468
assert "this is expected" in str(e)
@@ -486,7 +492,11 @@ def test_train_local_code(get_data_source_instance, tmpdir, sagemaker_session):
486492
)
487493

488494
sagemaker_container.train(
489-
INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, LOCAL_CODE_HYPERPARAMETERS, TRAINING_JOB_NAME
495+
INPUT_DATA_CONFIG,
496+
OUTPUT_DATA_CONFIG,
497+
LOCAL_CODE_HYPERPARAMETERS,
498+
ENVIRONMENT,
499+
TRAINING_JOB_NAME,
490500
)
491501

492502
docker_compose_file = os.path.join(
@@ -538,7 +548,7 @@ def test_train_local_intermediate_output(get_data_source_instance, tmpdir, sagem
538548
hyperparameters = {"sagemaker_s3_output": output_path}
539549

540550
sagemaker_container.train(
541-
INPUT_DATA_CONFIG, output_data_config, hyperparameters, TRAINING_JOB_NAME
551+
INPUT_DATA_CONFIG, output_data_config, hyperparameters, ENVIRONMENT, TRAINING_JOB_NAME
542552
)
543553

544554
docker_compose_file = os.path.join(

0 commit comments

Comments
 (0)