Skip to content

Commit ae2344e

Browse files
fix: derive master node from training environment
1 parent cfe0a66 commit ae2344e

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/sagemaker_pytorch_container/training.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def train(training_environment):
4747

4848
_set_nccl_environment(training_environment.network_interface_name)
4949

50-
_set_distributed_environment(training_environment.hosts)
50+
_set_distributed_environment(training_environment)
5151

5252
mpi_enabled = training_environment.additional_framework_parameters.get('sagemaker_mpi_enabled')
5353

@@ -88,15 +88,15 @@ def _dns_lookup(host):
8888
return socket.gethostbyname(host)
8989

9090

91-
def _set_distributed_environment(hosts):
91+
def _set_distributed_environment(training_env):
9292
"""Set environment variable for distributed training.
9393
9494
Args:
9595
training_env: sagemaker training environment object; its master_hostname is used as the master node address.
9696
"""
9797
# According to https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html
9898
# hosts are sorted lexicographically.
99-
os.environ['MASTER_ADDR'] = hosts[0]
99+
os.environ['MASTER_ADDR'] = training_env.master_hostname
100100
os.environ['MASTER_PORT'] = MASTER_PORT
101101

102102

0 commit comments

Comments
 (0)