Skip to content

Commit 48b1558

Browse files
Fix regression with connection upgrade (#7879)
Fixes #7867.
1 parent 28d0b06 commit 48b1558

File tree

4 files changed

+32
-11
lines changed

4 files changed

+32
-11
lines changed

CHANGES/7879.bugfix

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed a regression where connection may get closed during upgrade. -- by :user:`Dreamsorcerer`

aiohttp/client_reqrep.py

+8-11
Original file line numberDiff line numberDiff line change
@@ -940,19 +940,14 @@ def _response_eof(self) -> None:
940940
if self._closed:
941941
return
942942

943-
if self._connection is not None:
944-
# websocket, protocol could be None because
945-
# connection could be detached
946-
if (
947-
self._connection.protocol is not None
948-
and self._connection.protocol.upgraded
949-
):
950-
return
951-
952-
self._release_connection()
943+
# protocol could be None because connection could be detached
944+
protocol = self._connection and self._connection.protocol
945+
if protocol is not None and protocol.upgraded:
946+
return
953947

954948
self._closed = True
955949
self._cleanup_writer()
950+
self._release_connection()
956951

957952
@property
958953
def closed(self) -> bool:
@@ -1048,7 +1043,9 @@ async def read(self) -> bytes:
10481043
elif self._released: # Response explicitly released
10491044
raise ClientConnectionError("Connection closed")
10501045

1051-
await self._wait_released() # Underlying connection released
1046+
protocol = self._connection and self._connection.protocol
1047+
if protocol is None or not protocol.upgraded:
1048+
await self._wait_released() # Underlying connection released
10521049
return self._body
10531050

10541051
def get_encoding(self) -> str:

aiohttp/connector.py

+4
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ def __del__(self, _warnings: Any = warnings) -> None:
107107
context["source_traceback"] = self._source_traceback
108108
self._loop.call_exception_handler(context)
109109

110+
def __bool__(self) -> Literal[True]:
111+
"""Force subclasses to not be falsy, to make checks simpler."""
112+
return True
113+
110114
@property
111115
def transport(self) -> Optional[asyncio.Transport]:
112116
if self._protocol is None:

tests/test_client_functional.py

+19
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,25 @@ async def handler(request):
174174
assert 1 == len(client._session.connector._conns)
175175

176176

177+
async def test_upgrade_connection_not_released_after_read(aiohttp_client: Any) -> None:
178+
async def handler(request: web.Request) -> web.Response:
179+
body = await request.read()
180+
assert b"" == body
181+
return web.Response(
182+
status=101, headers={"Connection": "Upgrade", "Upgrade": "tcp"}
183+
)
184+
185+
app = web.Application()
186+
app.router.add_route("GET", "/", handler)
187+
188+
client = await aiohttp_client(app)
189+
190+
resp = await client.get("/")
191+
await resp.read()
192+
assert resp.connection is not None
193+
assert not resp.closed
194+
195+
177196
async def test_keepalive_server_force_close_connection(aiohttp_client: Any) -> None:
178197
async def handler(request):
179198
body = await request.read()

0 commit comments

Comments
 (0)