diff --git a/README.md b/README.md index d2f52bfb..fc06cf6f 100644 --- a/README.md +++ b/README.md @@ -264,6 +264,10 @@ client = Finch( See the httpx documentation for information about the [`proxies`](https://www.python-httpx.org/advanced/#http-proxying) and [`transport`](https://www.python-httpx.org/advanced/#custom-transports) keyword arguments. +## Advanced: Managing HTTP resources + +By default we will close the underlying HTTP connections whenever the client is [garbage collected](https://docs.python.org/3/reference/datamodel.html#object.__del__) is called but you can also manually close the client using the `.close()` method if desired, or with a context manager that closes when exiting. + ## Status This package is in beta. Its internals and interfaces are not stable and subject to change without a major semver bump; diff --git a/src/finch/_base_client.py b/src/finch/_base_client.py index 857dcadd..07b0a373 100644 --- a/src/finch/_base_client.py +++ b/src/finch/_base_client.py @@ -5,6 +5,7 @@ import uuid import inspect import platform +from types import TracebackType from random import random from typing import ( Any, @@ -677,6 +678,27 @@ def __init__( headers={"Accept": "application/json"}, ) + def is_closed(self) -> bool: + return self._client.is_closed + + def close(self) -> None: + """Close the underlying HTTPX client. + + The client will *not* be usable after this. + """ + self._client.close() + + def __enter__(self: _T) -> _T: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.close() + @overload def request( self, @@ -1009,6 +1031,27 @@ def __init__( headers={"Accept": "application/json"}, ) + def is_closed(self) -> bool: + return self._client.is_closed + + async def close(self) -> None: + """Close the underlying HTTPX client. + + The client will *not* be usable after this. + """ + await self._client.aclose() + + async def __aenter__(self: _T) -> _T: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.close() + @overload async def request( self, diff --git a/src/finch/_client.py b/src/finch/_client.py index 0180f0f4..49e49696 100644 --- a/src/finch/_client.py +++ b/src/finch/_client.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +import asyncio from typing import Union, Mapping, Optional import httpx @@ -87,6 +88,13 @@ def __init__( - `client_id` from `FINCH_CLIENT_ID` - `client_secret` from `FINCH_CLIENT_SECRET` """ + self.access_token = access_token + + client_id_envvar = os.environ.get("FINCH_CLIENT_ID", None) + self.client_id = client_id or client_id_envvar or None + + client_secret_envvar = os.environ.get("FINCH_CLIENT_SECRET", None) + self.client_secret = client_secret or client_secret_envvar or None if base_url is None: base_url = f"https://api.tryfinch.com" @@ -104,14 +112,6 @@ def __init__( _strict_response_validation=_strict_response_validation, ) - self.access_token = access_token - - client_id_envvar = os.environ.get("FINCH_CLIENT_ID", None) - self.client_id = client_id or client_id_envvar or None - - client_secret_envvar = os.environ.get("FINCH_CLIENT_SECRET", None) - self.client_secret = client_secret or client_secret_envvar or None - self.ats = resources.ATS(self) self.hris = resources.HRIS(self) self.providers = resources.Providers(self) @@ -204,6 +204,9 @@ def copy( # client.with_options(timeout=10).foo.create(...) with_options = copy + def __del__(self) -> None: + self.close() + def get_access_token( self, code: str, @@ -308,6 +311,13 @@ def __init__( - `client_id` from `FINCH_CLIENT_ID` - `client_secret` from `FINCH_CLIENT_SECRET` """ + self.access_token = access_token + + client_id_envvar = os.environ.get("FINCH_CLIENT_ID", None) + self.client_id = client_id or client_id_envvar or None + + client_secret_envvar = os.environ.get("FINCH_CLIENT_SECRET", None) + self.client_secret = client_secret or client_secret_envvar or None if base_url is None: base_url = f"https://api.tryfinch.com" @@ -325,14 +335,6 @@ def __init__( _strict_response_validation=_strict_response_validation, ) - self.access_token = access_token - - client_id_envvar = os.environ.get("FINCH_CLIENT_ID", None) - self.client_id = client_id or client_id_envvar or None - - client_secret_envvar = os.environ.get("FINCH_CLIENT_SECRET", None) - self.client_secret = client_secret or client_secret_envvar or None - self.ats = resources.AsyncATS(self) self.hris = resources.AsyncHRIS(self) self.providers = resources.AsyncProviders(self) @@ -425,6 +427,12 @@ def copy( # client.with_options(timeout=10).foo.create(...) with_options = copy + def __del__(self) -> None: + try: + asyncio.get_running_loop().create_task(self.close()) + except Exception: + pass + async def get_access_token( self, code: str, diff --git a/tests/test_client.py b/tests/test_client.py index 1bc0a33e..1da67a2e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,6 +4,7 @@ import os import json +import asyncio import inspect from typing import Any, Dict, Union, cast @@ -368,6 +369,22 @@ def test_base_url_no_trailing_slash(self) -> None: ) assert request.url == "http://localhost:5000/custom/path/foo" + def test_client_del(self) -> None: + client = Finch(base_url=base_url, access_token=access_token, _strict_response_validation=True) + assert not client.is_closed() + + client.__del__() + + assert client.is_closed() + + def test_client_context_manager(self) -> None: + client = Finch(base_url=base_url, access_token=access_token, _strict_response_validation=True) + with client as c2: + assert c2 is client + assert not c2.is_closed() + assert not client.is_closed() + assert client.is_closed() + class TestAsyncFinch: client = AsyncFinch(base_url=base_url, access_token=access_token, _strict_response_validation=True) @@ -710,3 +727,20 @@ def test_base_url_no_trailing_slash(self) -> None: ), ) assert request.url == "http://localhost:5000/custom/path/foo" + + async def test_client_del(self) -> None: + client = AsyncFinch(base_url=base_url, access_token=access_token, _strict_response_validation=True) + assert not client.is_closed() + + client.__del__() + + await asyncio.sleep(0.2) + assert client.is_closed() + + async def test_client_context_manager(self) -> None: + client = AsyncFinch(base_url=base_url, access_token=access_token, _strict_response_validation=True) + async with client as c2: + assert c2 is client + assert not c2.is_closed() + assert not client.is_closed() + assert client.is_closed()