Skip to content

Commit 34f23ce

Browse files
Taragolisephraimbuddy
authored andcommitted
Avoid to use functools.lru_cache in class methods in google provider (#38652)
(cherry picked from commit d3dc88f)
1 parent 74abd01 commit 34f23ce

File tree

3 files changed

+36
-24
lines changed

3 files changed

+36
-24
lines changed

airflow/providers/google/cloud/hooks/compute_ssh.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ def _authorize_compute_engine_instance_metadata(self, pubkey):
335335
)
336336

337337
def _authorize_os_login(self, pubkey):
338-
username = self._oslogin_hook._get_credentials_email()
338+
username = self._oslogin_hook._get_credentials_email
339339
self.log.info("Importing SSH public key using OSLogin: user=%s", username)
340340
expiration = int((time.time() + self.expire_time) * 1000000)
341341
ssh_public_key = {"key": pubkey, "expiration_time_usec": expiration}

airflow/providers/google/common/hooks/base_google.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def _get_access_token(self) -> str:
317317
credentials.refresh(auth_req)
318318
return credentials.token
319319

320-
@functools.lru_cache(maxsize=None)
320+
@functools.cached_property
321321
def _get_credentials_email(self) -> str:
322322
"""
323323
Return the email address associated with the currently logged in account.

tests/providers/google/cloud/hooks/test_compute_ssh.py

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from airflow.exceptions import AirflowException
2929
from airflow.models import Connection
3030
from airflow.providers.google.cloud.hooks.compute_ssh import ComputeEngineSSHHook
31+
from airflow.providers.google.cloud.hooks.os_login import OSLoginHook
3132

3233
pytestmark = pytest.mark.db_test
3334

@@ -48,22 +49,35 @@ def test_delegate_to_runtime_error(self):
4849
with pytest.raises(RuntimeError):
4950
ComputeEngineSSHHook(gcp_conn_id="gcpssh", delegate_to="delegate_to")
5051

52+
def test_os_login_hook(self, mocker):
53+
mock_os_login_hook = mocker.patch.object(OSLoginHook, "__init__", return_value=None, spec=OSLoginHook)
54+
55+
# Default values
56+
assert ComputeEngineSSHHook()._oslogin_hook
57+
mock_os_login_hook.assert_called_with(gcp_conn_id="google_cloud_default")
58+
59+
# Custom conn_id
60+
assert ComputeEngineSSHHook(gcp_conn_id="gcpssh")._oslogin_hook
61+
mock_os_login_hook.assert_called_with(gcp_conn_id="gcpssh")
62+
5163
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook")
52-
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook")
5364
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.paramiko")
5465
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh._GCloudAuthorizedSSHClient")
55-
def test_get_conn_default_configuration(
56-
self, mock_ssh_client, mock_paramiko, mock_os_login_hook, mock_compute_hook
57-
):
58-
mock_paramiko.SSHException = Exception
66+
def test_get_conn_default_configuration(self, mock_ssh_client, mock_paramiko, mock_compute_hook, mocker):
67+
mock_paramiko.SSHException = RuntimeError
5968
mock_paramiko.RSAKey.generate.return_value.get_name.return_value = "NAME"
6069
mock_paramiko.RSAKey.generate.return_value.get_base64.return_value = "AYZ"
6170

6271
mock_compute_hook.return_value.project_id = TEST_PROJECT_ID
6372
mock_compute_hook.return_value.get_instance_address.return_value = EXTERNAL_IP
6473

65-
mock_os_login_hook.return_value._get_credentials_email.return_value = "[email protected]"
66-
mock_os_login_hook.return_value.import_ssh_public_key.return_value.login_profile.posix_accounts = [
74+
mock_os_login_hook = mocker.patch.object(
75+
ComputeEngineSSHHook, "_oslogin_hook", spec=OSLoginHook, name="FakeOsLoginHook"
76+
)
77+
type(mock_os_login_hook)._get_credentials_email = mock.PropertyMock(
78+
return_value="[email protected]"
79+
)
80+
mock_os_login_hook.import_ssh_public_key.return_value.login_profile.posix_accounts = [
6781
mock.MagicMock(username="test-username")
6882
]
6983

@@ -83,16 +97,10 @@ def test_get_conn_default_configuration(
8397
),
8498
]
8599
)
86-
mock_os_login_hook.assert_has_calls(
87-
[
88-
mock.call(gcp_conn_id="google_cloud_default"),
89-
mock.call()._get_credentials_email(),
90-
mock.call().import_ssh_public_key(
91-
ssh_public_key={"key": "NAME AYZ root", "expiration_time_usec": mock.ANY},
92-
project_id="test-project-id",
93-
user=mock_os_login_hook.return_value._get_credentials_email.return_value,
94-
),
95-
]
100+
mock_os_login_hook.import_ssh_public_key.assert_called_once_with(
101+
ssh_public_key={"key": "NAME AYZ root", "expiration_time_usec": mock.ANY},
102+
project_id="test-project-id",
103+
96104
)
97105
mock_ssh_client.assert_has_calls(
98106
[
@@ -113,7 +121,6 @@ def test_get_conn_default_configuration(
113121
[(SSHException, r"Error occurred when establishing SSH connection using Paramiko")],
114122
)
115123
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook")
116-
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook")
117124
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.paramiko")
118125
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh._GCloudAuthorizedSSHClient")
119126
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineSSHHook._connect_to_instance")
@@ -122,21 +129,26 @@ def test_get_conn_default_configuration_test_exceptions(
122129
mock_connect,
123130
mock_ssh_client,
124131
mock_paramiko,
125-
mock_os_login_hook,
126132
mock_compute_hook,
127133
exception_type,
128134
error_message,
129135
caplog,
136+
mocker,
130137
):
131-
mock_paramiko.SSHException = Exception
138+
mock_paramiko.SSHException = RuntimeError
132139
mock_paramiko.RSAKey.generate.return_value.get_name.return_value = "NAME"
133140
mock_paramiko.RSAKey.generate.return_value.get_base64.return_value = "AYZ"
134141

135142
mock_compute_hook.return_value.project_id = TEST_PROJECT_ID
136143
mock_compute_hook.return_value.get_instance_address.return_value = EXTERNAL_IP
137144

138-
mock_os_login_hook.return_value._get_credentials_email.return_value = "[email protected]"
139-
mock_os_login_hook.return_value.import_ssh_public_key.return_value.login_profile.posix_accounts = [
145+
mock_os_login_hook = mocker.patch.object(
146+
ComputeEngineSSHHook, "_oslogin_hook", spec=OSLoginHook, name="FakeOsLoginHook"
147+
)
148+
type(mock_os_login_hook)._get_credentials_email = mock.PropertyMock(
149+
return_value="[email protected]"
150+
)
151+
mock_os_login_hook.import_ssh_public_key.return_value.login_profile.posix_accounts = [
140152
mock.MagicMock(username="test-username")
141153
]
142154

0 commit comments

Comments
 (0)