34
34
import sagemaker
35
35
from sagemaker .utils import get_config_value
36
36
37
- CONTAINER_PREFIX = " algo"
37
+ CONTAINER_PREFIX = ' algo'
38
38
DOCKER_COMPOSE_FILENAME = 'docker-compose.yaml'
39
39
40
+ # Environment variables to be set during training
41
+ REGION_ENV_NAME = 'AWS_REGION'
42
+ TRAINING_JOB_NAME_ENV_NAME = 'TRAINING_JOB_NAME'
43
+
40
44
logger = logging .getLogger (__name__ )
41
45
logger .setLevel (logging .WARNING )
42
46
@@ -102,7 +106,12 @@ def train(self, input_data_config, hyperparameters):
102
106
self .write_config_files (host , hyperparameters , input_data_config )
103
107
shutil .copytree (data_dir , os .path .join (self .container_root , host , 'input' , 'data' ))
104
108
105
- compose_data = self ._generate_compose_file ('train' , additional_volumes = volumes )
109
+ training_env_vars = {
110
+ REGION_ENV_NAME : self .sagemaker_session .boto_region_name ,
111
+ TRAINING_JOB_NAME_ENV_NAME : json .loads (hyperparameters .get (sagemaker .model .JOB_NAME_PARAM_NAME )),
112
+ }
113
+ compose_data = self ._generate_compose_file ('train' , additional_volumes = volumes ,
114
+ additional_env_vars = training_env_vars )
106
115
compose_command = self ._compose ()
107
116
108
117
_ecr_login_if_needed (self .sagemaker_session .boto_session , self .image )
@@ -149,7 +158,6 @@ def serve(self, model_dir, environment):
149
158
logger .info ('creating hosting dir in {}' .format (self .container_root ))
150
159
151
160
volumes = self ._prepare_serving_volumes (model_dir )
152
- env_vars = ['{}={}' .format (k , v ) for k , v in environment .items ()]
153
161
154
162
# If the user script was passed as a file:// mount it to the container.
155
163
if sagemaker .estimator .DIR_PARAM_NAME .upper () in environment :
@@ -161,7 +169,7 @@ def serve(self, model_dir, environment):
161
169
_ecr_login_if_needed (self .sagemaker_session .boto_session , self .image )
162
170
163
171
self ._generate_compose_file ('serve' ,
164
- additional_env_vars = env_vars ,
172
+ additional_env_vars = environment ,
165
173
additional_volumes = volumes )
166
174
compose_command = self ._compose ()
167
175
self .container = _HostingContainer (compose_command )
@@ -384,7 +392,8 @@ def _generate_compose_file(self, command, additional_volumes=None, additional_en
384
392
if aws_creds is not None :
385
393
environment .extend (aws_creds )
386
394
387
- environment .extend (additional_env_vars )
395
+ additional_env_var_list = ['{}={}' .format (k , v ) for k , v in additional_env_vars .items ()]
396
+ environment .extend (additional_env_var_list )
388
397
389
398
if command == 'train' :
390
399
optml_dirs = {'output' , 'output/data' , 'input' }
0 commit comments