Skip to content

Commit f0f199c

Browse files
yongyanrao (Yongyan Rao) and Yongyan Rao authored
feature: add PyTorch version environment variable, to facilitate SMTT (#250)
Co-authored-by: Yongyan Rao <[email protected]>
1 parent 85ae2a9 commit f0f199c

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

src/sagemaker_pytorch_container/training.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def train(training_environment):
9696
capture_error = False
9797
logger.info(f'capture_error is {capture_error}. Default is True')
9898

99+
_set_torch_version_environment()
99100
try:
100101
entry_point.run(uri=training_environment.module_dir,
101102
user_entry_point=training_environment.user_entry_point,
@@ -149,5 +150,22 @@ def _set_nccl_environment(network_interface_name):
149150
os.environ['NCCL_DEBUG'] = 'WARN'
150151

151152

153+
def _set_torch_version_environment():
    """Export the installed PyTorch version as ``SM_DLC_TORCH_VERSION``.

    This is the PyTorch version of the DLC.

    The export is best-effort: every failure is logged as a warning and
    swallowed, so a missing or broken PyTorch install never raises out of
    this function.
    """
    try:
        import torch

        # os.environ only accepts str values; coerce so a non-str version
        # object cannot fall through to the catch-all handler below.
        os.environ["SM_DLC_TORCH_VERSION"] = str(torch.__version__)
    except ModuleNotFoundError:
        # Must precede ImportError: ModuleNotFoundError is its subclass.
        logger.warning("PyTorch cannot be found")
    except ImportError:
        logger.warning("PyTorch can be found, but cannot be imported")
    except Exception:
        # Deliberate catch-all: setting this variable must never fail the caller.
        logger.warning("Torch version environment variable cannot be set")
168+
169+
152170
def main():
    """Build an ``environment.Environment()`` and pass it to ``train``."""
    training_env = environment.Environment()
    train(training_env)

test/unit/test_train.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import os
1616
import shutil
17+
import sys
1718
import tempfile
1819

1920
import pytest
@@ -23,7 +24,14 @@
2324
from mock import MagicMock, PropertyMock
2425
from mock import patch
2526

26-
from sagemaker_pytorch_container.training import main, train, _dns_lookup, LAUNCH_PYTORCH_XLA_ENV_NAME, MASTER_PORT
27+
from sagemaker_pytorch_container.training import (
28+
main,
29+
train,
30+
_dns_lookup,
31+
LAUNCH_PYTORCH_XLA_ENV_NAME,
32+
MASTER_PORT,
33+
_set_torch_version_environment,
34+
)
2735

2836

2937
@pytest.fixture(name='training_env')
@@ -218,3 +226,11 @@ def test_user_script_error_raised(run_entry_point, training_env):
218226
)
219227
with pytest.raises(errors.ExecuteUserScriptError):
220228
train(training_env)
229+
230+
231+
def test_set_torch_version_environment():
    """Check that the fake torch version is exported to SM_DLC_TORCH_VERSION."""
    mock_torch = MagicMock()
    mock_torch.__version__ = '2.0.0'
    # patch.dict restores both mappings on exit, so neither the fake torch
    # module nor the exported variable leaks into other tests.
    with patch.dict(sys.modules, {'torch': mock_torch}), patch.dict(os.environ):
        # Drop any pre-existing value so the assertion cannot pass on stale state.
        os.environ.pop("SM_DLC_TORCH_VERSION", None)
        _set_torch_version_environment()
        assert os.environ.get("SM_DLC_TORCH_VERSION") == '2.0.0'

0 commit comments

Comments
 (0)