File tree Expand file tree Collapse file tree 1 file changed +6
-2
lines changed
src/sagemaker_tensorflow_container Expand file tree Collapse file tree 1 file changed +6
-2
lines changed Original file line number Diff line number Diff line change @@ -173,6 +173,8 @@ def train(env, cmd_args):
173
173
SAGEMAKER_DISTRIBUTED_DATAPARALLEL_ENABLED , False
174
174
)
175
175
176
+ env_vars = env .to_env_vars ()
177
+
176
178
# Setup
177
179
if parameter_server_enabled :
178
180
@@ -181,7 +183,9 @@ def train(env, cmd_args):
181
183
182
184
elif multi_worker_mirrored_strategy_enabled :
183
185
184
- tf_config = _build_tf_config_for_mwms (hosts = env .hosts , current_host = env .current_host )
186
+ env_vars ["TF_CONFIG" ] = _build_tf_config_for_mwms (
187
+ hosts = env .hosts , current_host = env .current_host
188
+ )
185
189
logger .info ("Running distributed training job with multi_worker_mirrored_strategy setup" )
186
190
187
191
# Run
@@ -210,7 +214,7 @@ def train(env, cmd_args):
210
214
uri = env .module_dir ,
211
215
user_entry_point = env .user_entry_point ,
212
216
args = cmd_args ,
213
- env_vars = env . to_env_vars () ,
217
+ env_vars = env_vars ,
214
218
capture_error = True ,
215
219
runner_type = runner_type ,
216
220
)
You can’t perform that action at this time.
0 commit comments