diff --git a/src/sagemaker_pytorch_container/training.py b/src/sagemaker_pytorch_container/training.py
index 45a163d..9331117 100644
--- a/src/sagemaker_pytorch_container/training.py
+++ b/src/sagemaker_pytorch_container/training.py
@@ -47,7 +47,7 @@ def train(training_environment):
 
     _set_nccl_environment(training_environment.network_interface_name)
 
-    _set_distributed_environment(training_environment.hosts)
+    _set_distributed_environment(training_environment)
 
     mpi_enabled = training_environment.additional_framework_parameters.get('sagemaker_mpi_enabled')
 
@@ -88,7 +88,7 @@ def _dns_lookup(host):
     return socket.gethostbyname(host)
 
 
-def _set_distributed_environment(hosts):
+def _set_distributed_environment(training_env):
     """Set environment variable for distributed training.
 
     Args:
@@ -96,7 +96,7 @@ def _set_distributed_environment(hosts):
     """
     # According to https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html
     # hosts are sorted lexicographically.
-    os.environ['MASTER_ADDR'] = hosts[0]
+    os.environ['MASTER_ADDR'] = training_env.master_hostname
    os.environ['MASTER_PORT'] = MASTER_PORT
 
 
diff --git a/test/unit/test_train.py b/test/unit/test_train.py
index b47e03f..a82ef51 100644
--- a/test/unit/test_train.py
+++ b/test/unit/test_train.py
@@ -31,6 +31,7 @@ def fixture_training_env():
     env = MagicMock()
     env.current_host = 'algo-1'
     env.hosts = ['algo-1']
+    env.master_hostname = 'algo-1'
     env.network_interface_name = 'eth0'
     tmp = tempfile.mkdtemp()
     os.makedirs(os.path.join(tmp, 'model'))
@@ -96,7 +97,7 @@ def test_environment(training_env):
 
     # distributed training specific environment
     assert MASTER_PORT == os.environ['MASTER_PORT']
-    assert training_env.hosts[0] == os.environ['MASTER_ADDR']
+    assert training_env.master_hostname == os.environ['MASTER_ADDR']
 
     # nccl specific environment
     assert training_env.network_interface_name == os.environ['NCCL_SOCKET_IFNAME']