Skip to content

feature: support passing Env Vars to local mode training #3015

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .githooks/pre-push
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ start_time=`date +%s`
tox -e sphinx,doc8 --parallel all
./ci-scripts/displaytime.sh 'sphinx,doc8' $start_time
start_time=`date +%s`
tox -e py36,py37,py38 --parallel all -- tests/unit
./ci-scripts/displaytime.sh 'py36,py37,py38 unit' $start_time
tox -e py36,py37,py38,py39 --parallel all -- tests/unit
./ci-scripts/displaytime.sh 'py36,py37,py38,py39 unit' $start_time
24 changes: 20 additions & 4 deletions src/sagemaker/local/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,22 +175,37 @@ def describe(self):


class _LocalTrainingJob(object):
"""Placeholder docstring"""
"""Defines and starts a local training job."""

_STARTING = "Starting"
_TRAINING = "Training"
_COMPLETED = "Completed"
_states = ["Starting", "Training", "Completed"]

def __init__(self, container):
    """Set up a local training job that has not been started yet.

    Args:
        container: the local container object the job will run in.
    """
    # Nothing has run yet: timings, artifacts and environment are unset.
    self.start_time = None
    self.end_time = None
    self.model_artifacts = None
    self.environment = None
    self.state = "created"
    self.container = container

def start(self, input_data_config, output_data_config, hyperparameters, environment, job_name):
"""Starts a local training job.

def start(self, input_data_config, output_data_config, hyperparameters, job_name):
"""Placeholder docstring."""
Args:
input_data_config (dict): The Input Data Configuration, this contains data such as the
channels to be used for training.
output_data_config (dict): The configuration of the output data.
hyperparameters (dict): The HyperParameters for the training job.
environment (dict): The collection of environment variables passed to the job.
job_name (str): Name of the local training job being run.
"""
for channel in input_data_config:
if channel["DataSource"] and "S3DataSource" in channel["DataSource"]:
data_distribution = channel["DataSource"]["S3DataSource"]["S3DataDistributionType"]
Expand All @@ -216,9 +231,10 @@ def start(self, input_data_config, output_data_config, hyperparameters, job_name

self.start_time = datetime.datetime.now()
self.state = self._TRAINING
self.environment = environment

self.model_artifacts = self.container.train(
input_data_config, output_data_config, hyperparameters, job_name
input_data_config, output_data_config, hyperparameters, environment, job_name
)
self.end_time = datetime.datetime.now()
self.state = self._COMPLETED
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/local/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,14 +176,15 @@ def process(
# you see this line at the end.
print("===== Job Complete =====")

def train(self, input_data_config, output_data_config, hyperparameters, job_name):
def train(self, input_data_config, output_data_config, hyperparameters, environment, job_name):
"""Run a training job locally using docker-compose.

Args:
input_data_config (dict): The Input Data Configuration, this contains data such as the
channels to be used for training.
output_data_config: The configuration of the output data.
hyperparameters (dict): The HyperParameters for the training job.
environment (dict): The environment collection for the training job.
job_name (str): Name of the local training job being run.

Returns (str): Location of the trained model.
Expand Down Expand Up @@ -217,6 +218,7 @@ def train(self, input_data_config, output_data_config, hyperparameters, job_name
REGION_ENV_NAME: self.sagemaker_session.boto_region_name,
TRAINING_JOB_NAME_ENV_NAME: job_name,
}
training_env_vars.update(environment)
if self.sagemaker_session.s3_resource is not None:
training_env_vars[
S3_ENDPOINT_URL_ENV_NAME
Expand Down
8 changes: 7 additions & 1 deletion src/sagemaker/local/local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def create_training_job(
OutputDataConfig,
ResourceConfig,
InputDataConfig=None,
Environment=None,
**kwargs
):
"""Create a training job in Local Mode.
Expand All @@ -167,6 +168,8 @@ def create_training_job(
OutputDataConfig(dict): Identifies the location where you want to save the results of
model training.
ResourceConfig(dict): Identifies the resources to use for local model training.
Environment(dict, optional): Describes the environment variables to pass
to the container. (Default value = None)
HyperParameters(dict) [optional]: Specifies these algorithm-specific parameters to
influence the quality of the final model.
**kwargs:
Expand All @@ -175,6 +178,7 @@ def create_training_job(

"""
InputDataConfig = InputDataConfig or {}
Environment = Environment or {}
container = _SageMakerContainer(
ResourceConfig["InstanceType"],
ResourceConfig["InstanceCount"],
Expand All @@ -184,7 +188,9 @@ def create_training_job(
training_job = _LocalTrainingJob(container)
hyperparameters = kwargs["HyperParameters"] if "HyperParameters" in kwargs else {}
logger.info("Starting training job")
training_job.start(InputDataConfig, OutputDataConfig, hyperparameters, TrainingJobName)
training_job.start(
InputDataConfig, OutputDataConfig, hyperparameters, Environment, TrainingJobName
)

LocalSagemakerClient._training_jobs[TrainingJobName] = training_job

Expand Down
16 changes: 16 additions & 0 deletions tests/data/mxnet_mnist/check_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import os


def verify_training_env(environ=None):
    """Check that the ``MYVAR`` variable reached the training process.

    Args:
        environ (dict, optional): Mapping to inspect. Defaults to ``os.environ``.

    Raises:
        RuntimeError: If ``MYVAR`` is missing or is not ``"HELLO_WORLD"``.
    """
    if environ is None:
        environ = os.environ
    actual = environ.get("MYVAR")
    # An explicit raise (rather than ``assert``) keeps the check active when
    # Python runs with ``-O``, which strips assert statements.
    if actual != "HELLO_WORLD":
        raise RuntimeError(
            "Expected MYVAR to be 'HELLO_WORLD' but got {!r}".format(actual)
        )


if __name__ == "__main__":
    verify_training_env()
22 changes: 22 additions & 0 deletions tests/integ/test_local_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,28 @@ def test_mxnet_local_data_local_script(
predictor.delete_endpoint()


@pytest.mark.local_mode
def test_mxnet_local_training_env(mxnet_training_latest_version, mxnet_training_latest_py_version):
    """Run a local-mode MXNet training job with a custom environment variable.

    The estimator uses ``check_env.py`` as its entry point and is given
    ``environment={"MYVAR": "HELLO_WORLD"}``; the test passes when ``fit``
    returns without raising.
    """
    mnist_dir = os.path.join(DATA_DIR, "mxnet_mnist")

    estimator = MXNet(
        entry_point=os.path.join(mnist_dir, "check_env.py"),
        role="SageMakerRole",
        instance_count=1,
        instance_type="local",
        framework_version=mxnet_training_latest_version,
        py_version=mxnet_training_latest_py_version,
        sagemaker_session=LocalNoS3Session(),
        environment={"MYVAR": "HELLO_WORLD"},
    )

    # Both channels are read from the local filesystem via file:// URIs.
    channels = {
        name: "file://" + os.path.join(mnist_dir, name) for name in ("train", "test")
    }

    estimator.fit(channels)


@pytest.mark.local_mode
def test_mxnet_training_failure(
sagemaker_local_session, mxnet_training_latest_version, mxnet_training_latest_py_version, tmpdir
Expand Down
20 changes: 15 additions & 5 deletions tests/unit/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@
"sagemaker_submit_directory": json.dumps("file:///tmp/code"),
}

ENVIRONMENT = {"MYVAR": "HELLO_WORLD"}


@pytest.fixture()
def sagemaker_session():
Expand Down Expand Up @@ -352,7 +354,7 @@ def test_train(
"local", instance_count, image, sagemaker_session=sagemaker_session
)
sagemaker_container.train(
INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, HYPERPARAMETERS, TRAINING_JOB_NAME
INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, HYPERPARAMETERS, ENVIRONMENT, TRAINING_JOB_NAME
)

docker_compose_file = os.path.join(
Expand Down Expand Up @@ -415,7 +417,7 @@ def test_train_with_hyperparameters_without_job_name(
"local", instance_count, image, sagemaker_session=sagemaker_session
)
sagemaker_container.train(
INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, HYPERPARAMETERS, TRAINING_JOB_NAME
INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, HYPERPARAMETERS, ENVIRONMENT, TRAINING_JOB_NAME
)

docker_compose_file = os.path.join(
Expand Down Expand Up @@ -456,7 +458,11 @@ def test_train_error(

with pytest.raises(RuntimeError) as e:
sagemaker_container.train(
INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, HYPERPARAMETERS, TRAINING_JOB_NAME
INPUT_DATA_CONFIG,
OUTPUT_DATA_CONFIG,
HYPERPARAMETERS,
ENVIRONMENT,
TRAINING_JOB_NAME,
)

assert "this is expected" in str(e)
Expand Down Expand Up @@ -486,7 +492,11 @@ def test_train_local_code(get_data_source_instance, tmpdir, sagemaker_session):
)

sagemaker_container.train(
INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, LOCAL_CODE_HYPERPARAMETERS, TRAINING_JOB_NAME
INPUT_DATA_CONFIG,
OUTPUT_DATA_CONFIG,
LOCAL_CODE_HYPERPARAMETERS,
ENVIRONMENT,
TRAINING_JOB_NAME,
)

docker_compose_file = os.path.join(
Expand Down Expand Up @@ -538,7 +548,7 @@ def test_train_local_intermediate_output(get_data_source_instance, tmpdir, sagem
hyperparameters = {"sagemaker_s3_output": output_path}

sagemaker_container.train(
INPUT_DATA_CONFIG, output_data_config, hyperparameters, TRAINING_JOB_NAME
INPUT_DATA_CONFIG, output_data_config, hyperparameters, ENVIRONMENT, TRAINING_JOB_NAME
)

docker_compose_file = os.path.join(
Expand Down