Skip to content

Commit abf6406

Browse files
committed
fix: avoid leaking memory when Client.with_options is used (#316)
Fixes openai/openai-python#865.
1 parent f80056b commit abf6406

File tree

4 files changed

+141
-17
lines changed

4 files changed

+141
-17
lines changed

pyproject.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,6 @@ select = [
149149
"T203",
150150
]
151151
ignore = [
152-
# lru_cache in methods, will be fixed separately
153-
"B019",
154152
# mutable defaults
155153
"B006",
156154
]

src/modern_treasury/_base_client.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -403,14 +403,12 @@ def _build_headers(self, options: FinalRequestOptions) -> httpx.Headers:
403403
headers_dict = _merge_mappings(self.default_headers, custom_headers)
404404
self._validate_headers(headers_dict, custom_headers)
405405

406+
# headers are case-insensitive while dictionaries are not.
406407
headers = httpx.Headers(headers_dict)
407408

408409
idempotency_header = self._idempotency_header
409410
if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers:
410-
if not options.idempotency_key:
411-
options.idempotency_key = self._idempotency_key()
412-
413-
headers[idempotency_header] = options.idempotency_key
411+
headers[idempotency_header] = options.idempotency_key or self._idempotency_key()
414412

415413
return headers
416414

@@ -594,16 +592,8 @@ def base_url(self) -> URL:
594592
def base_url(self, url: URL | str) -> None:
595593
self._base_url = self._enforce_trailing_slash(url if isinstance(url, URL) else URL(url))
596594

597-
@lru_cache(maxsize=None)
598595
def platform_headers(self) -> Dict[str, str]:
599-
return {
600-
"X-Stainless-Lang": "python",
601-
"X-Stainless-Package-Version": self._version,
602-
"X-Stainless-OS": str(get_platform()),
603-
"X-Stainless-Arch": str(get_architecture()),
604-
"X-Stainless-Runtime": platform.python_implementation(),
605-
"X-Stainless-Runtime-Version": platform.python_version(),
606-
}
596+
return platform_headers(self._version)
607597

608598
def _calculate_retry_timeout(
609599
self,
@@ -1691,6 +1681,18 @@ def get_platform() -> Platform:
16911681
return "Unknown"
16921682

16931683

1684+
@lru_cache(maxsize=None)
1685+
def platform_headers(version: str) -> Dict[str, str]:
1686+
return {
1687+
"X-Stainless-Lang": "python",
1688+
"X-Stainless-Package-Version": version,
1689+
"X-Stainless-OS": str(get_platform()),
1690+
"X-Stainless-Arch": str(get_architecture()),
1691+
"X-Stainless-Runtime": platform.python_implementation(),
1692+
"X-Stainless-Runtime-Version": platform.python_version(),
1693+
}
1694+
1695+
16941696
class OtherArch:
16951697
def __init__(self, name: str) -> None:
16961698
self.name = name

src/modern_treasury/_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ def copy(
291291
api_key=api_key or self.api_key,
292292
organization_id=organization_id or self.organization_id,
293293
webhook_key=webhook_key or self.webhook_key,
294-
base_url=base_url or str(self.base_url),
294+
base_url=base_url or self.base_url,
295295
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
296296
http_client=http_client,
297297
connection_pool_limits=connection_pool_limits,
@@ -609,7 +609,7 @@ def copy(
609609
api_key=api_key or self.api_key,
610610
organization_id=organization_id or self.organization_id,
611611
webhook_key=webhook_key or self.webhook_key,
612-
base_url=base_url or str(self.base_url),
612+
base_url=base_url or self.base_url,
613613
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
614614
http_client=http_client,
615615
connection_pool_limits=connection_pool_limits,

tests/test_client.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22

33
from __future__ import annotations
44

5+
import gc
56
import os
67
import json
78
import asyncio
89
import inspect
10+
import tracemalloc
911
from typing import Any, Union, cast
1012
from unittest import mock
1113

@@ -213,6 +215,67 @@ def test_copy_signature(self) -> None:
213215
copy_param = copy_signature.parameters.get(name)
214216
assert copy_param is not None, f"copy() signature is missing the {name} param"
215217

218+
def test_copy_build_request(self) -> None:
219+
options = FinalRequestOptions(method="get", url="/foo")
220+
221+
def build_request(options: FinalRequestOptions) -> None:
222+
client = self.client.copy()
223+
client._build_request(options)
224+
225+
# ensure that the machinery is warmed up before tracing starts.
226+
build_request(options)
227+
gc.collect()
228+
229+
tracemalloc.start(1000)
230+
231+
snapshot_before = tracemalloc.take_snapshot()
232+
233+
ITERATIONS = 10
234+
for _ in range(ITERATIONS):
235+
build_request(options)
236+
gc.collect()
237+
238+
snapshot_after = tracemalloc.take_snapshot()
239+
240+
tracemalloc.stop()
241+
242+
def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
243+
if diff.count == 0:
244+
# Avoid false positives by considering only leaks (i.e. allocations that persist).
245+
return
246+
247+
if diff.count % ITERATIONS != 0:
248+
# Avoid false positives by considering only leaks that appear per iteration.
249+
return
250+
251+
for frame in diff.traceback:
252+
if any(
253+
frame.filename.endswith(fragment)
254+
for fragment in [
255+
# to_raw_response_wrapper leaks through the @functools.wraps() decorator.
256+
#
257+
# removing the decorator fixes the leak for reasons we don't understand.
258+
"modern_treasury/_response.py",
259+
# pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
260+
"modern_treasury/_compat.py",
261+
# Standard library leaks we don't care about.
262+
"/logging/__init__.py",
263+
]
264+
):
265+
return
266+
267+
leaks.append(diff)
268+
269+
leaks: list[tracemalloc.StatisticDiff] = []
270+
for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
271+
add_leak(leaks, diff)
272+
if leaks:
273+
for leak in leaks:
274+
print("MEMORY LEAK:", leak)
275+
for frame in leak.traceback:
276+
print(frame)
277+
raise AssertionError()
278+
216279
def test_request_timeout(self) -> None:
217280
request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
218281
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
@@ -1061,6 +1124,67 @@ def test_copy_signature(self) -> None:
10611124
copy_param = copy_signature.parameters.get(name)
10621125
assert copy_param is not None, f"copy() signature is missing the {name} param"
10631126

1127+
def test_copy_build_request(self) -> None:
1128+
options = FinalRequestOptions(method="get", url="/foo")
1129+
1130+
def build_request(options: FinalRequestOptions) -> None:
1131+
client = self.client.copy()
1132+
client._build_request(options)
1133+
1134+
# ensure that the machinery is warmed up before tracing starts.
1135+
build_request(options)
1136+
gc.collect()
1137+
1138+
tracemalloc.start(1000)
1139+
1140+
snapshot_before = tracemalloc.take_snapshot()
1141+
1142+
ITERATIONS = 10
1143+
for _ in range(ITERATIONS):
1144+
build_request(options)
1145+
gc.collect()
1146+
1147+
snapshot_after = tracemalloc.take_snapshot()
1148+
1149+
tracemalloc.stop()
1150+
1151+
def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
1152+
if diff.count == 0:
1153+
# Avoid false positives by considering only leaks (i.e. allocations that persist).
1154+
return
1155+
1156+
if diff.count % ITERATIONS != 0:
1157+
# Avoid false positives by considering only leaks that appear per iteration.
1158+
return
1159+
1160+
for frame in diff.traceback:
1161+
if any(
1162+
frame.filename.endswith(fragment)
1163+
for fragment in [
1164+
# to_raw_response_wrapper leaks through the @functools.wraps() decorator.
1165+
#
1166+
# removing the decorator fixes the leak for reasons we don't understand.
1167+
"modern_treasury/_response.py",
1168+
# pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
1169+
"modern_treasury/_compat.py",
1170+
# Standard library leaks we don't care about.
1171+
"/logging/__init__.py",
1172+
]
1173+
):
1174+
return
1175+
1176+
leaks.append(diff)
1177+
1178+
leaks: list[tracemalloc.StatisticDiff] = []
1179+
for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
1180+
add_leak(leaks, diff)
1181+
if leaks:
1182+
for leak in leaks:
1183+
print("MEMORY LEAK:", leak)
1184+
for frame in leak.traceback:
1185+
print(frame)
1186+
raise AssertionError()
1187+
10641188
async def test_request_timeout(self) -> None:
10651189
request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
10661190
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore

0 commit comments

Comments
 (0)