File tree 1 file changed +3
-3
lines changed
src/sagemaker_pytorch_container
1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -47,7 +47,7 @@ def train(training_environment):
47
47
48
48
_set_nccl_environment (training_environment .network_interface_name )
49
49
50
- _set_distributed_environment (training_environment . hosts )
50
+ _set_distributed_environment (training_environment )
51
51
52
52
mpi_enabled = training_environment .additional_framework_parameters .get ('sagemaker_mpi_enabled' )
53
53
@@ -88,15 +88,15 @@ def _dns_lookup(host):
88
88
return socket .gethostbyname (host )
89
89
90
90
91
- def _set_distributed_environment (hosts ):
91
+ def _set_distributed_environment (training_env ):
92
92
"""Set environment variable for distributed training.
93
93
94
94
Args:
95
95
hosts: list of hosts that are used for training.
96
96
"""
97
97
# According to https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html
98
98
# hosts are sorted lexicographically.
99
- os .environ ['MASTER_ADDR' ] = hosts [ 0 ]
99
+ os .environ ['MASTER_ADDR' ] = training_env . master_hostname
100
100
os .environ ['MASTER_PORT' ] = MASTER_PORT
101
101
102
102
You can’t perform that action at this time.
0 commit comments