Skip to content

Commit 7e29e3f

Browse files
committed
fix: avoid leaking memory when Client.with_options is used (#97)
Fixes openai/openai-python#865.
1 parent e2feae9 commit 7e29e3f

File tree

4 files changed

+141
-17
lines changed

4 files changed

+141
-17
lines changed

pyproject.toml

-2
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,6 @@ select = [
151151
"T203",
152152
]
153153
ignore = [
154-
# lru_cache in methods, will be fixed separately
155-
"B019",
156154
# mutable defaults
157155
"B006",
158156
]

src/anthropic_bedrock/_base_client.py

+15-13
Original file line numberDiff line numberDiff line change
@@ -403,14 +403,12 @@ def _build_headers(self, options: FinalRequestOptions) -> httpx.Headers:
403403
headers_dict = _merge_mappings(self.default_headers, custom_headers)
404404
self._validate_headers(headers_dict, custom_headers)
405405

406+
# headers are case-insensitive while dictionaries are not.
406407
headers = httpx.Headers(headers_dict)
407408

408409
idempotency_header = self._idempotency_header
409410
if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers:
410-
if not options.idempotency_key:
411-
options.idempotency_key = self._idempotency_key()
412-
413-
headers[idempotency_header] = options.idempotency_key
411+
headers[idempotency_header] = options.idempotency_key or self._idempotency_key()
414412

415413
return headers
416414

@@ -594,16 +592,8 @@ def base_url(self) -> URL:
594592
def base_url(self, url: URL | str) -> None:
595593
self._base_url = self._enforce_trailing_slash(url if isinstance(url, URL) else URL(url))
596594

597-
@lru_cache(maxsize=None)
598595
def platform_headers(self) -> Dict[str, str]:
599-
return {
600-
"X-Stainless-Lang": "python",
601-
"X-Stainless-Package-Version": self._version,
602-
"X-Stainless-OS": str(get_platform()),
603-
"X-Stainless-Arch": str(get_architecture()),
604-
"X-Stainless-Runtime": platform.python_implementation(),
605-
"X-Stainless-Runtime-Version": platform.python_version(),
606-
}
596+
return platform_headers(self._version)
607597

608598
def _calculate_retry_timeout(
609599
self,
@@ -1691,6 +1681,18 @@ def get_platform() -> Platform:
16911681
return "Unknown"
16921682

16931683

1684+
@lru_cache(maxsize=None)
1685+
def platform_headers(version: str) -> Dict[str, str]:
1686+
return {
1687+
"X-Stainless-Lang": "python",
1688+
"X-Stainless-Package-Version": version,
1689+
"X-Stainless-OS": str(get_platform()),
1690+
"X-Stainless-Arch": str(get_architecture()),
1691+
"X-Stainless-Runtime": platform.python_implementation(),
1692+
"X-Stainless-Runtime-Version": platform.python_version(),
1693+
}
1694+
1695+
16941696
class OtherArch:
16951697
def __init__(self, name: str) -> None:
16961698
self.name = name

src/anthropic_bedrock/_client.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def copy(
191191
aws_access_key=aws_access_key or self.aws_access_key,
192192
aws_region=aws_region or self.aws_region,
193193
aws_session_token=aws_session_token or self.aws_session_token,
194-
base_url=base_url or str(self.base_url),
194+
base_url=base_url or self.base_url,
195195
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
196196
http_client=http_client,
197197
max_retries=max_retries if is_given(max_retries) else self.max_retries,
@@ -410,7 +410,7 @@ def copy(
410410
aws_access_key=aws_access_key or self.aws_access_key,
411411
aws_region=aws_region or self.aws_region,
412412
aws_session_token=aws_session_token or self.aws_session_token,
413-
base_url=base_url or str(self.base_url),
413+
base_url=base_url or self.base_url,
414414
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
415415
http_client=http_client,
416416
max_retries=max_retries if is_given(max_retries) else self.max_retries,

tests/test_client.py

+124
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22

33
from __future__ import annotations
44

5+
import gc
56
import os
67
import json
78
import asyncio
89
import inspect
10+
import tracemalloc
911
from typing import Any, Union, cast
1012
from unittest import mock
1113

@@ -224,6 +226,67 @@ def test_copy_signature(self) -> None:
224226
copy_param = copy_signature.parameters.get(name)
225227
assert copy_param is not None, f"copy() signature is missing the {name} param"
226228

229+
def test_copy_build_request(self) -> None:
230+
options = FinalRequestOptions(method="get", url="/foo")
231+
232+
def build_request(options: FinalRequestOptions) -> None:
233+
client = self.client.copy()
234+
client._build_request(options)
235+
236+
# ensure that the machinery is warmed up before tracing starts.
237+
build_request(options)
238+
gc.collect()
239+
240+
tracemalloc.start(1000)
241+
242+
snapshot_before = tracemalloc.take_snapshot()
243+
244+
ITERATIONS = 10
245+
for _ in range(ITERATIONS):
246+
build_request(options)
247+
gc.collect()
248+
249+
snapshot_after = tracemalloc.take_snapshot()
250+
251+
tracemalloc.stop()
252+
253+
def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
254+
if diff.count == 0:
255+
# Avoid false positives by considering only leaks (i.e. allocations that persist).
256+
return
257+
258+
if diff.count % ITERATIONS != 0:
259+
# Avoid false positives by considering only leaks that appear per iteration.
260+
return
261+
262+
for frame in diff.traceback:
263+
if any(
264+
frame.filename.endswith(fragment)
265+
for fragment in [
266+
# to_raw_response_wrapper leaks through the @functools.wraps() decorator.
267+
#
268+
# removing the decorator fixes the leak for reasons we don't understand.
269+
"anthropic_bedrock/_response.py",
270+
# pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
271+
"anthropic_bedrock/_compat.py",
272+
# Standard library leaks we don't care about.
273+
"/logging/__init__.py",
274+
]
275+
):
276+
return
277+
278+
leaks.append(diff)
279+
280+
leaks: list[tracemalloc.StatisticDiff] = []
281+
for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
282+
add_leak(leaks, diff)
283+
if leaks:
284+
for leak in leaks:
285+
print("MEMORY LEAK:", leak)
286+
for frame in leak.traceback:
287+
print(frame)
288+
raise AssertionError()
289+
227290
def test_request_timeout(self) -> None:
228291
request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
229292
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
@@ -997,6 +1060,67 @@ def test_copy_signature(self) -> None:
9971060
copy_param = copy_signature.parameters.get(name)
9981061
assert copy_param is not None, f"copy() signature is missing the {name} param"
9991062

1063+
def test_copy_build_request(self) -> None:
1064+
options = FinalRequestOptions(method="get", url="/foo")
1065+
1066+
def build_request(options: FinalRequestOptions) -> None:
1067+
client = self.client.copy()
1068+
client._build_request(options)
1069+
1070+
# ensure that the machinery is warmed up before tracing starts.
1071+
build_request(options)
1072+
gc.collect()
1073+
1074+
tracemalloc.start(1000)
1075+
1076+
snapshot_before = tracemalloc.take_snapshot()
1077+
1078+
ITERATIONS = 10
1079+
for _ in range(ITERATIONS):
1080+
build_request(options)
1081+
gc.collect()
1082+
1083+
snapshot_after = tracemalloc.take_snapshot()
1084+
1085+
tracemalloc.stop()
1086+
1087+
def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
1088+
if diff.count == 0:
1089+
# Avoid false positives by considering only leaks (i.e. allocations that persist).
1090+
return
1091+
1092+
if diff.count % ITERATIONS != 0:
1093+
# Avoid false positives by considering only leaks that appear per iteration.
1094+
return
1095+
1096+
for frame in diff.traceback:
1097+
if any(
1098+
frame.filename.endswith(fragment)
1099+
for fragment in [
1100+
# to_raw_response_wrapper leaks through the @functools.wraps() decorator.
1101+
#
1102+
# removing the decorator fixes the leak for reasons we don't understand.
1103+
"anthropic_bedrock/_response.py",
1104+
# pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
1105+
"anthropic_bedrock/_compat.py",
1106+
# Standard library leaks we don't care about.
1107+
"/logging/__init__.py",
1108+
]
1109+
):
1110+
return
1111+
1112+
leaks.append(diff)
1113+
1114+
leaks: list[tracemalloc.StatisticDiff] = []
1115+
for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
1116+
add_leak(leaks, diff)
1117+
if leaks:
1118+
for leak in leaks:
1119+
print("MEMORY LEAK:", leak)
1120+
for frame in leak.traceback:
1121+
print(frame)
1122+
raise AssertionError()
1123+
10001124
async def test_request_timeout(self) -> None:
10011125
request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
10021126
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore

0 commit comments

Comments
 (0)