28
28
from airflow .exceptions import AirflowException
29
29
from airflow .models import Connection
30
30
from airflow .providers .google .cloud .hooks .compute_ssh import ComputeEngineSSHHook
31
+ from airflow .providers .google .cloud .hooks .os_login import OSLoginHook
31
32
32
33
pytestmark = pytest .mark .db_test
33
34
@@ -48,22 +49,35 @@ def test_delegate_to_runtime_error(self):
48
49
with pytest .raises (RuntimeError ):
49
50
ComputeEngineSSHHook (gcp_conn_id = "gcpssh" , delegate_to = "delegate_to" )
50
51
52
+ def test_os_login_hook (self , mocker ):
53
+ mock_os_login_hook = mocker .patch .object (OSLoginHook , "__init__" , return_value = None , spec = OSLoginHook )
54
+
55
+ # Default values
56
+ assert ComputeEngineSSHHook ()._oslogin_hook
57
+ mock_os_login_hook .assert_called_with (gcp_conn_id = "google_cloud_default" )
58
+
59
+ # Custom conn_id
60
+ assert ComputeEngineSSHHook (gcp_conn_id = "gcpssh" )._oslogin_hook
61
+ mock_os_login_hook .assert_called_with (gcp_conn_id = "gcpssh" )
62
+
51
63
@mock .patch ("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook" )
52
- @mock .patch ("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook" )
53
64
@mock .patch ("airflow.providers.google.cloud.hooks.compute_ssh.paramiko" )
54
65
@mock .patch ("airflow.providers.google.cloud.hooks.compute_ssh._GCloudAuthorizedSSHClient" )
55
- def test_get_conn_default_configuration (
56
- self , mock_ssh_client , mock_paramiko , mock_os_login_hook , mock_compute_hook
57
- ):
58
- mock_paramiko .SSHException = Exception
66
+ def test_get_conn_default_configuration (self , mock_ssh_client , mock_paramiko , mock_compute_hook , mocker ):
67
+ mock_paramiko .SSHException = RuntimeError
59
68
mock_paramiko .RSAKey .generate .return_value .get_name .return_value = "NAME"
60
69
mock_paramiko .RSAKey .generate .return_value .get_base64 .return_value = "AYZ"
61
70
62
71
mock_compute_hook .return_value .project_id = TEST_PROJECT_ID
63
72
mock_compute_hook .return_value .get_instance_address .return_value = EXTERNAL_IP
64
73
65
- mock_os_login_hook .
return_value .
_get_credentials_email .
return_value = "[email protected] "
66
- mock_os_login_hook .return_value .import_ssh_public_key .return_value .login_profile .posix_accounts = [
74
+ mock_os_login_hook = mocker .patch .object (
75
+ ComputeEngineSSHHook , "_oslogin_hook" , spec = OSLoginHook , name = "FakeOsLoginHook"
76
+ )
77
+ type(mock_os_login_hook )._get_credentials_email = mock .PropertyMock (
78
+
79
+ )
80
+ mock_os_login_hook .import_ssh_public_key .return_value .login_profile .posix_accounts = [
67
81
mock .MagicMock (username = "test-username" )
68
82
]
69
83
@@ -83,16 +97,10 @@ def test_get_conn_default_configuration(
83
97
),
84
98
]
85
99
)
86
- mock_os_login_hook .assert_has_calls (
87
- [
88
- mock .call (gcp_conn_id = "google_cloud_default" ),
89
- mock .call ()._get_credentials_email (),
90
- mock .call ().import_ssh_public_key (
91
- ssh_public_key = {"key" : "NAME AYZ root" , "expiration_time_usec" : mock .ANY },
92
- project_id = "test-project-id" ,
93
- user = mock_os_login_hook .return_value ._get_credentials_email .return_value ,
94
- ),
95
- ]
100
+ mock_os_login_hook .import_ssh_public_key .assert_called_once_with (
101
+ ssh_public_key = {"key" : "NAME AYZ root" , "expiration_time_usec" : mock .ANY },
102
+ project_id = "test-project-id" ,
103
+
96
104
)
97
105
mock_ssh_client .assert_has_calls (
98
106
[
@@ -113,7 +121,6 @@ def test_get_conn_default_configuration(
113
121
[(SSHException , r"Error occurred when establishing SSH connection using Paramiko" )],
114
122
)
115
123
@mock .patch ("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook" )
116
- @mock .patch ("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook" )
117
124
@mock .patch ("airflow.providers.google.cloud.hooks.compute_ssh.paramiko" )
118
125
@mock .patch ("airflow.providers.google.cloud.hooks.compute_ssh._GCloudAuthorizedSSHClient" )
119
126
@mock .patch ("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineSSHHook._connect_to_instance" )
@@ -122,21 +129,26 @@ def test_get_conn_default_configuration_test_exceptions(
122
129
mock_connect ,
123
130
mock_ssh_client ,
124
131
mock_paramiko ,
125
- mock_os_login_hook ,
126
132
mock_compute_hook ,
127
133
exception_type ,
128
134
error_message ,
129
135
caplog ,
136
+ mocker ,
130
137
):
131
- mock_paramiko .SSHException = Exception
138
+ mock_paramiko .SSHException = RuntimeError
132
139
mock_paramiko .RSAKey .generate .return_value .get_name .return_value = "NAME"
133
140
mock_paramiko .RSAKey .generate .return_value .get_base64 .return_value = "AYZ"
134
141
135
142
mock_compute_hook .return_value .project_id = TEST_PROJECT_ID
136
143
mock_compute_hook .return_value .get_instance_address .return_value = EXTERNAL_IP
137
144
138
- mock_os_login_hook .
return_value .
_get_credentials_email .
return_value = "[email protected] "
139
- mock_os_login_hook .return_value .import_ssh_public_key .return_value .login_profile .posix_accounts = [
145
+ mock_os_login_hook = mocker .patch .object (
146
+ ComputeEngineSSHHook , "_oslogin_hook" , spec = OSLoginHook , name = "FakeOsLoginHook"
147
+ )
148
+ type(mock_os_login_hook )._get_credentials_email = mock .PropertyMock (
149
+
150
+ )
151
+ mock_os_login_hook .import_ssh_public_key .return_value .login_profile .posix_accounts = [
140
152
mock .MagicMock (username = "test-username" )
141
153
]
142
154
0 commit comments