diff --git a/tests/data/horovod/test_hvd_basic.py b/tests/data/horovod/test_hvd_basic.py index 6de7700eae..d37214defc 100644 --- a/tests/data/horovod/test_hvd_basic.py +++ b/tests/data/horovod/test_hvd_basic.py @@ -1,11 +1,15 @@ import json import os + import horovod.tensorflow as hvd -hvd.init() +if __name__ == '__main__': + + hvd.init() + + with open(os.path.join('/opt/ml/model/rank-%s' % hvd.rank()), 'w+') as f: + basic_info = {'rank': hvd.rank(), 'size': hvd.size()} -with open(os.path.join('/opt/ml/model/rank-%s' % hvd.rank()), 'w+') as f: - basic_info = {'rank': hvd.rank(), 'size': hvd.size()} + json.dump(basic_info, f) + print('Saved file "rank-%s": %s' % (hvd.rank(), basic_info)) - print(basic_info) - json.dump(basic_info, f)