Skip to content

Commit a12bc7d

Browse files
fix: deriver master node from training environment (#238)
* fix: deriver master node from training environment * Fix unit tests
1 parent cfe0a66 commit a12bc7d

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

src/sagemaker_pytorch_container/training.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def train(training_environment):
4747

4848
_set_nccl_environment(training_environment.network_interface_name)
4949

50-
_set_distributed_environment(training_environment.hosts)
50+
_set_distributed_environment(training_environment)
5151

5252
mpi_enabled = training_environment.additional_framework_parameters.get('sagemaker_mpi_enabled')
5353

@@ -88,15 +88,15 @@ def _dns_lookup(host):
8888
return socket.gethostbyname(host)
8989

9090

91-
def _set_distributed_environment(hosts):
91+
def _set_distributed_environment(training_env):
9292
"""Set environment variable for distributed training.
9393
9494
Args:
9595
hosts: list of hosts that are used for training.
9696
"""
9797
# According to https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html
9898
# hosts are sorted lexicographically.
99-
os.environ['MASTER_ADDR'] = hosts[0]
99+
os.environ['MASTER_ADDR'] = training_env.master_hostname
100100
os.environ['MASTER_PORT'] = MASTER_PORT
101101

102102

test/unit/test_train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def fixture_training_env():
3131
env = MagicMock()
3232
env.current_host = 'algo-1'
3333
env.hosts = ['algo-1']
34+
env.master_hostname = 'algo-1'
3435
env.network_interface_name = 'eth0'
3536
tmp = tempfile.mkdtemp()
3637
os.makedirs(os.path.join(tmp, 'model'))
@@ -96,7 +97,7 @@ def test_environment(training_env):
9697

9798
# distributed training specific environment
9899
assert MASTER_PORT == os.environ['MASTER_PORT']
99-
assert training_env.hosts[0] == os.environ['MASTER_ADDR']
100+
assert training_env.master_hostname == os.environ['MASTER_ADDR']
100101

101102
# nccl specific environment
102103
assert training_env.network_interface_name == os.environ['NCCL_SOCKET_IFNAME']

0 commit comments

Comments
 (0)