Skip to content

Commit f8551a2

Browse files
change: bypass DNS check for studio local exec (#252)
* change: bypass DNS check for studio local exec
* fix: unit test
* fix: unit test
---------
Co-authored-by: Mufaddal Rohawala <[email protected]>
1 parent a792fd0 commit f8551a2

File tree

2 files changed

+27
-5
lines changed

2 files changed

+27
-5
lines changed

src/sagemaker_pytorch_container/training.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -44,10 +44,15 @@ def train(training_environment):
4444
training_environment: training environment object containing environment
4545
variables, training arguments and hyperparameters.
4646
"""
47-
# Block until all host DNS lookups succeed. Relies on retrying dns_lookup.
48-
logger.info('Block until all host DNS lookups succeed.')
49-
for host in training_environment.hosts:
50-
_dns_lookup(host)
47+
_sm_studio_local_mode = os.environ.get("SM_STUDIO_LOCAL_MODE", "False").lower() == "true"
48+
49+
if not _sm_studio_local_mode:
50+
# Block until all host DNS lookups succeed. Relies on retrying dns_lookup.
51+
logger.info('Block until all host DNS lookups succeed.')
52+
for host in training_environment.hosts:
53+
_dns_lookup(host)
54+
else:
55+
logger.info('Bypass DNS check in case of Studio Local Mode execution.')
5156

5257
_set_nccl_environment(training_environment.network_interface_name)
5358

test/unit/test_train.py

Lines changed: 18 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -69,17 +69,34 @@ def fixture_user_module_with_save():
6969
return MagicMock(spec=['train', 'save'])
7070

7171

72+
@patch('sagemaker_pytorch_container.training._dns_lookup')
7273
@patch('sagemaker_training.entry_point.run')
7374
@patch('socket.gethostbyname', MagicMock())
74-
def test_train(run_entry_point, training_env):
75+
def test_train(run_entry_point, dns_lookup, training_env):
7576
train(training_env)
77+
dns_lookup.assert_called_once_with('algo-1')
78+
run_entry_point.assert_called_with(uri=training_env.module_dir,
79+
user_entry_point=training_env.user_entry_point,
80+
args=training_env.to_cmd_args(),
81+
env_vars=training_env.to_env_vars(),
82+
capture_error=True,
83+
runner_type=runner.ProcessRunnerType)
84+
7685

86+
@patch('sagemaker_pytorch_container.training._dns_lookup')
87+
@patch('sagemaker_training.entry_point.run')
88+
@patch('socket.gethostbyname', MagicMock())
89+
def test_train_with_sm_studio_local_mode_enabled(run_entry_point, dns_lookup, training_env):
90+
os.environ['SM_STUDIO_LOCAL_MODE'] = 'True'
91+
train(training_env)
92+
dns_lookup.assert_not_called()
7793
run_entry_point.assert_called_with(uri=training_env.module_dir,
7894
user_entry_point=training_env.user_entry_point,
7995
args=training_env.to_cmd_args(),
8096
env_vars=training_env.to_env_vars(),
8197
capture_error=True,
8298
runner_type=runner.ProcessRunnerType)
99+
del os.environ['SM_STUDIO_LOCAL_MODE']
83100

84101

85102
@patch('sagemaker_training.entry_point.run')

0 commit comments

Comments
 (0)