diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 05eaa51d..df519cd4 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -33,6 +33,7 @@ from google.cloud.sql.connector.enums import DriverMapping from google.cloud.sql.connector.enums import IPTypes from google.cloud.sql.connector.enums import RefreshStrategy +from google.cloud.sql.connector.exceptions import ClosedConnectorError from google.cloud.sql.connector.instance import RefreshAheadCache from google.cloud.sql.connector.lazy import LazyRefreshCache from google.cloud.sql.connector.monitored_cache import MonitoredCache @@ -153,6 +154,7 @@ def __init__( # connection name string and enable_iam_auth boolean flag self._cache: dict[tuple[str, bool], MonitoredCache] = {} self._client: Optional[CloudSQLClient] = None + self._closed: bool = False # initialize credentials scopes = ["https://www.googleapis.com/auth/sqlservice.admin"] @@ -242,6 +244,12 @@ def connect( # connect runs sync database connections on background thread. # Async database connections should call 'connect_async' directly to # avoid hanging indefinitely. + + # Check if the connector is closed before attempting to connect. + if self._closed: + raise ClosedConnectorError( + "Connection attempt failed because the connector has already been closed." + ) connect_future = asyncio.run_coroutine_threadsafe( self.connect_async(instance_connection_string, driver, **kwargs), self._loop, @@ -279,7 +287,13 @@ async def connect_async( and then subsequent attempt with IAM database authentication. KeyError: Unsupported database driver Must be one of pymysql, asyncpg, pg8000, and pytds. + RuntimeError: Connector has been closed. Cannot connect using a closed + Connector. """ + if self._closed: + raise ClosedConnectorError( + "Connection attempt failed because the connector has already been closed." + ) if self._keys is None: self._keys = asyncio.create_task(generate_keys()) if self._client is None: @@ -462,6 +476,7 @@ def close(self) -> None: self._loop.call_soon_threadsafe(self._loop.stop) # wait for thread to finish closing (i.e. loop to stop) self._thread.join() + self._closed = True async def close_async(self) -> None: """Helper function to cancel the cache's tasks @@ -469,6 +484,7 @@ async def close_async(self) -> None: await asyncio.gather(*[cache.close() for cache in self._cache.values()]) if self._client: await self._client.close() + self._closed = True async def create_async_connector( diff --git a/google/cloud/sql/connector/exceptions.py b/google/cloud/sql/connector/exceptions.py index da39ea25..1f15ced4 100644 --- a/google/cloud/sql/connector/exceptions.py +++ b/google/cloud/sql/connector/exceptions.py @@ -84,3 +84,10 @@ class CacheClosedError(Exception): Exception to be raised when a ConnectionInfoCache can not be accessed after it is closed. """ + + +class ClosedConnectorError(Exception): + """ + Exception to be raised when a Connector is closed and connect method is + called on it. + """ diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 15769772..bde7f65a 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -28,6 +28,7 @@ from google.cloud.sql.connector import IPTypes from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_name import ConnectionName +from google.cloud.sql.connector.exceptions import ClosedConnectorError from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError from google.cloud.sql.connector.exceptions import IncompatibleDriverError from google.cloud.sql.connector.instance import RefreshAheadCache @@ -468,3 +469,48 @@ def test_configured_quota_project_env_var( assert connector._quota_project == quota_project # unset env var del os.environ["GOOGLE_CLOUD_QUOTA_PROJECT"] + + +@pytest.mark.asyncio +async def test_connect_async_closed_connector( + fake_credentials: Credentials, fake_client: CloudSQLClient +) -> None: + """Test that calling connect_async() on a closed connector raises an error.""" + async with Connector( + credentials=fake_credentials, loop=asyncio.get_running_loop() + ) as connector: + connector._client = fake_client + await connector.close_async() + with pytest.raises(ClosedConnectorError) as exc_info: + await connector.connect_async( + "test-project:test-region:test-instance", + "asyncpg", + user="my-user", + password="my-pass", + db="my-db", + ) + assert ( + exc_info.value.args[0] + == "Connection attempt failed because the connector has already been closed." + ) + + +def test_connect_closed_connector( + fake_credentials: Credentials, fake_client: CloudSQLClient +) -> None: + """Test that calling connect() on a closed connector raises an error.""" + with Connector(credentials=fake_credentials) as connector: + connector._client = fake_client + connector.close() + with pytest.raises(ClosedConnectorError) as exc_info: + connector.connect( + "test-project:test-region:test-instance", + "pg8000", + user="my-user", + password="my-pass", + db="my-db", + ) + assert ( + exc_info.value.args[0] + == "Connection attempt failed because the connector has already been closed." + )