Skip to content

Commit aecb66a

Browse files
laurenyuPiali Das
authored and
Piali Das
committed
Local mode pass training env var (aws#411)
SageMaker Training sets various environment variables for every training job, while Local Mode replicates none of these. This change adds the environment variables for the AWS region and training job name.
1 parent 08eea20 commit aecb66a

File tree

3 files changed

+26
-8
lines changed

3 files changed

+26
-8
lines changed

CHANGELOG.rst

+5
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
CHANGELOG
33
=========
44

5+
1.11.1dev
6+
=========
7+
8+
* enhancement: Local Mode: add training environment variables for AWS region and job name
9+
510
1.11.0
611
======
712

src/sagemaker/local/image.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,13 @@
3434
import sagemaker
3535
from sagemaker.utils import get_config_value
3636

37-
CONTAINER_PREFIX = "algo"
37+
CONTAINER_PREFIX = 'algo'
3838
DOCKER_COMPOSE_FILENAME = 'docker-compose.yaml'
3939

40+
# Environment variables to be set during training
41+
REGION_ENV_NAME = 'AWS_REGION'
42+
TRAINING_JOB_NAME_ENV_NAME = 'TRAINING_JOB_NAME'
43+
4044
logger = logging.getLogger(__name__)
4145
logger.setLevel(logging.WARNING)
4246

@@ -102,7 +106,12 @@ def train(self, input_data_config, hyperparameters):
102106
self.write_config_files(host, hyperparameters, input_data_config)
103107
shutil.copytree(data_dir, os.path.join(self.container_root, host, 'input', 'data'))
104108

105-
compose_data = self._generate_compose_file('train', additional_volumes=volumes)
109+
training_env_vars = {
110+
REGION_ENV_NAME: self.sagemaker_session.boto_region_name,
111+
TRAINING_JOB_NAME_ENV_NAME: json.loads(hyperparameters.get(sagemaker.model.JOB_NAME_PARAM_NAME)),
112+
}
113+
compose_data = self._generate_compose_file('train', additional_volumes=volumes,
114+
additional_env_vars=training_env_vars)
106115
compose_command = self._compose()
107116

108117
_ecr_login_if_needed(self.sagemaker_session.boto_session, self.image)
@@ -149,7 +158,6 @@ def serve(self, model_dir, environment):
149158
logger.info('creating hosting dir in {}'.format(self.container_root))
150159

151160
volumes = self._prepare_serving_volumes(model_dir)
152-
env_vars = ['{}={}'.format(k, v) for k, v in environment.items()]
153161

154162
# If the user script was passed as a file:// mount it to the container.
155163
if sagemaker.estimator.DIR_PARAM_NAME.upper() in environment:
@@ -161,7 +169,7 @@ def serve(self, model_dir, environment):
161169
_ecr_login_if_needed(self.sagemaker_session.boto_session, self.image)
162170

163171
self._generate_compose_file('serve',
164-
additional_env_vars=env_vars,
172+
additional_env_vars=environment,
165173
additional_volumes=volumes)
166174
compose_command = self._compose()
167175
self.container = _HostingContainer(compose_command)
@@ -384,7 +392,8 @@ def _generate_compose_file(self, command, additional_volumes=None, additional_en
384392
if aws_creds is not None:
385393
environment.extend(aws_creds)
386394

387-
environment.extend(additional_env_vars)
395+
additional_env_var_list = ['{}={}'.format(k, v) for k, v in additional_env_vars.items()]
396+
environment.extend(additional_env_var_list)
388397

389398
if command == 'train':
390399
optml_dirs = {'output', 'output/data', 'input'}

tests/unit/test_image.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,14 @@
4949
}
5050
]
5151
HYPERPARAMETERS = {'a': 1,
52-
'b': 'bee',
53-
'sagemaker_submit_directory': json.dumps('s3://my_bucket/code')}
52+
'b': json.dumps('bee'),
53+
'sagemaker_submit_directory': json.dumps('s3://my_bucket/code'),
54+
'sagemaker_job_name': json.dumps('my-job')}
5455

5556
LOCAL_CODE_HYPERPARAMETERS = {'a': 1,
5657
'b': 2,
57-
'sagemaker_submit_directory': json.dumps('file:///tmp/code')}
58+
'sagemaker_submit_directory': json.dumps('file:///tmp/code'),
59+
'sagemaker_job_name': json.dumps('my-job')}
5860

5961

6062
@pytest.fixture()
@@ -244,6 +246,8 @@ def test_train(_download_folder, _cleanup, popen, _stream_output, LocalSession,
244246
for h in sagemaker_container.hosts:
245247
assert config['services'][h]['image'] == image
246248
assert config['services'][h]['command'] == 'train'
249+
assert 'AWS_REGION={}'.format(REGION) in config['services'][h]['environment']
250+
assert 'TRAINING_JOB_NAME=my-job' in config['services'][h]['environment']
247251

248252
# assert that expected by sagemaker container output directories exist
249253
assert os.path.exists(os.path.join(sagemaker_container.container_root, 'output'))

0 commit comments

Comments
 (0)