Skip to content

Commit 4aee7a7

Browse files
chore(internal): minor restructuring of base client (#158)
1 parent e52735d commit 4aee7a7

File tree

2 files changed

+45
-17
lines changed

2 files changed

+45
-17
lines changed

pyproject.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ dependencies = [
1313
"typing-extensions>=4.5, <5",
1414
"anyio>=3.5.0, <4",
1515
"distro>=1.7.0, <2",
16-
16+
1717
]
1818
requires-python = ">= 3.7"
1919

@@ -39,7 +39,7 @@ dev-dependencies = [
3939
"time-machine==2.9.0",
4040
"nox==2023.4.22",
4141
"dirty-equals>=0.6.0",
42-
42+
4343
]
4444

4545
[tool.rye.scripts]

src/finch/_base_client.py

+43-15
Original file line numberDiff line numberDiff line change
@@ -399,18 +399,6 @@ def _build_headers(self, options: FinalRequestOptions) -> httpx.Headers:
399399

400400
return headers
401401

402-
def _prepare_request(
403-
self,
404-
request: httpx.Request, # noqa: ARG002
405-
) -> None:
406-
"""This method is used as a callback for mutating the `Request` object
407-
after it has been constructed.
408-
409-
This is useful for cases where you want to add certain headers based off of
410-
the request properties, e.g. `url`, `method` etc.
411-
"""
412-
return None
413-
414402
def _prepare_url(self, url: str) -> URL:
415403
"""
416404
Merge a URL argument together with any 'base_url' on the client,
@@ -463,7 +451,7 @@ def _build_request(
463451
kwargs["data"] = self._serialize_multipartform(json_data)
464452

465453
# TODO: report this error to httpx
466-
request = self._client.build_request( # pyright: ignore[reportUnknownMemberType]
454+
return self._client.build_request( # pyright: ignore[reportUnknownMemberType]
467455
headers=headers,
468456
timeout=self.timeout if isinstance(options.timeout, NotGiven) else options.timeout,
469457
method=options.method,
@@ -477,8 +465,6 @@ def _build_request(
477465
files=options.files,
478466
**kwargs,
479467
)
480-
self._prepare_request(request)
481-
return request
482468

483469
def _serialize_multipartform(self, data: Mapping[object, object]) -> dict[str, object]:
484470
items = self.qs.stringify_items(
@@ -781,6 +767,24 @@ def __exit__(
781767
) -> None:
782768
self.close()
783769

770+
def _prepare_options(
771+
self,
772+
options: FinalRequestOptions, # noqa: ARG002
773+
) -> None:
774+
"""Hook for mutating the given options"""
775+
return None
776+
777+
def _prepare_request(
778+
self,
779+
request: httpx.Request, # noqa: ARG002
780+
) -> None:
781+
"""This method is used as a callback for mutating the `Request` object
782+
after it has been constructed.
783+
This is useful for cases where you want to add certain headers based off of
784+
the request properties, e.g. `url`, `method` etc.
785+
"""
786+
return None
787+
784788
@overload
785789
def request(
786790
self,
@@ -842,8 +846,11 @@ def _request(
842846
stream: bool,
843847
stream_cls: type[_StreamT] | None,
844848
) -> ResponseT | _StreamT:
849+
self._prepare_options(options)
850+
845851
retries = self._remaining_retries(remaining_retries, options)
846852
request = self._build_request(options)
853+
self._prepare_request(request)
847854

848855
try:
849856
response = self._client.send(request, auth=self.custom_auth, stream=stream)
@@ -1201,6 +1208,24 @@ async def __aexit__(
12011208
) -> None:
12021209
await self.close()
12031210

1211+
async def _prepare_options(
1212+
self,
1213+
options: FinalRequestOptions, # noqa: ARG002
1214+
) -> None:
1215+
"""Hook for mutating the given options"""
1216+
return None
1217+
1218+
async def _prepare_request(
1219+
self,
1220+
request: httpx.Request, # noqa: ARG002
1221+
) -> None:
1222+
"""This method is used as a callback for mutating the `Request` object
1223+
after it has been constructed.
1224+
This is useful for cases where you want to add certain headers based off of
1225+
the request properties, e.g. `url`, `method` etc.
1226+
"""
1227+
return None
1228+
12041229
@overload
12051230
async def request(
12061231
self,
@@ -1262,8 +1287,11 @@ async def _request(
12621287
stream_cls: type[_AsyncStreamT] | None,
12631288
remaining_retries: int | None,
12641289
) -> ResponseT | _AsyncStreamT:
1290+
await self._prepare_options(options)
1291+
12651292
retries = self._remaining_retries(remaining_retries, options)
12661293
request = self._build_request(options)
1294+
await self._prepare_request(request)
12671295

12681296
try:
12691297
response = await self._client.send(request, auth=self.custom_auth, stream=stream)

0 commit comments

Comments
 (0)