@@ -69,17 +69,34 @@ def fixture_user_module_with_save():
69
69
return MagicMock (spec = ['train' , 'save' ])
70
70
71
71
72
+ @patch ('sagemaker_pytorch_container.training._dns_lookup' )
72
73
@patch ('sagemaker_training.entry_point.run' )
73
74
@patch ('socket.gethostbyname' , MagicMock ())
74
- def test_train (run_entry_point , training_env ):
75
+ def test_train (run_entry_point , dns_lookup , training_env ):
75
76
train (training_env )
77
+ dns_lookup .assert_called_once_with ('algo-1' )
78
+ run_entry_point .assert_called_with (uri = training_env .module_dir ,
79
+ user_entry_point = training_env .user_entry_point ,
80
+ args = training_env .to_cmd_args (),
81
+ env_vars = training_env .to_env_vars (),
82
+ capture_error = True ,
83
+ runner_type = runner .ProcessRunnerType )
84
+
76
85
86
+ @patch ('sagemaker_pytorch_container.training._dns_lookup' )
87
+ @patch ('sagemaker_training.entry_point.run' )
88
+ @patch ('socket.gethostbyname' , MagicMock ())
89
+ def test_train_with_sm_studio_local_mode_enabled (run_entry_point , dns_lookup , training_env ):
90
+ os .environ ['SM_STUDIO_LOCAL_MODE' ] = 'True'
91
+ train (training_env )
92
+ dns_lookup .assert_not_called ()
77
93
run_entry_point .assert_called_with (uri = training_env .module_dir ,
78
94
user_entry_point = training_env .user_entry_point ,
79
95
args = training_env .to_cmd_args (),
80
96
env_vars = training_env .to_env_vars (),
81
97
capture_error = True ,
82
98
runner_type = runner .ProcessRunnerType )
99
+ del os .environ ['SM_STUDIO_LOCAL_MODE' ]
83
100
84
101
85
102
@patch ('sagemaker_training.entry_point.run' )
0 commit comments