Skip to content

Commit f8ce3f0

Browse files
committed
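fix: logic error in MWMS (multi-worker mirrored strategy) setup: store TF_CONFIG in the env_vars that are passed to the run call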
1 parent f5cf636 commit f8ce3f0

File tree

1 file changed

+6 -2 lines changed

1 file changed

+6 -2 lines changed

src/sagemaker_tensorflow_container/training.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -173,6 +173,8 @@ def train(env, cmd_args):
173 | 173 |           SAGEMAKER_DISTRIBUTED_DATAPARALLEL_ENABLED, False
174 | 174 |       )
175 | 175 |
    | 176 | +     env_vars = env.to_env_vars()
    | 177 | +
176 | 178 |       # Setup
177 | 179 |       if parameter_server_enabled:
178 | 180 |
@@ -181,7 +183,9 @@ def train(env, cmd_args):
181 | 183 |
182 | 184 |       elif multi_worker_mirrored_strategy_enabled:
183 | 185 |
184 |     | -         tf_config = _build_tf_config_for_mwms(hosts=env.hosts, current_host=env.current_host)
    | 186 | +         env_vars["TF_CONFIG"] = _build_tf_config_for_mwms(
    | 187 | +             hosts=env.hosts, current_host=env.current_host
    | 188 | +         )
185 | 189 |           logger.info("Running distributed training job with multi_worker_mirrored_strategy setup")
186 | 190 |
187 | 191 |       # Run
@@ -210,7 +214,7 @@ def train(env, cmd_args):
210 | 214 |           uri=env.module_dir,
211 | 215 |           user_entry_point=env.user_entry_point,
212 | 216 |           args=cmd_args,
213 |     | -         env_vars=env.to_env_vars(),
    | 217 | +         env_vars=env_vars,
214 | 218 |           capture_error=True,
215 | 219 |           runner_type=runner_type,
216 | 220 |       )

0 commit comments

Comments
 (0)