Skip to content

Commit 012f986

Browse files
bdracowebknjaz
andauthored
Fix AsyncResolver to match ThreadedResolver behavior (#8270)
Co-authored-by: Sviatoslav Sydorenko (Святослав Сидоренко) <[email protected]>
1 parent 28f1fd8 commit 012f986

File tree

10 files changed

+334
-76
lines changed

10 files changed

+334
-76
lines changed

CHANGES/8270.bugfix.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
Fix ``AsyncResolver`` to match ``ThreadedResolver`` behavior
2+
-- by :user:`bdraco`.
3+
4+
On system with IPv6 support, the :py:class:`~aiohttp.resolver.AsyncResolver` would not fallback
5+
to providing A records when AAAA records were not available.
6+
Additionally, unlike the :py:class:`~aiohttp.resolver.ThreadedResolver`, the :py:class:`~aiohttp.resolver.AsyncResolver`
7+
did not handle link-local addresses correctly.
8+
9+
This change makes the behavior consistent with the :py:class:`~aiohttp.resolver.ThreadedResolver`.

aiohttp/abc.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import socket
23
from abc import ABC, abstractmethod
34
from collections.abc import Sized
45
from http.cookies import BaseCookie, Morsel
@@ -13,6 +14,7 @@
1314
List,
1415
Optional,
1516
Tuple,
17+
TypedDict,
1618
)
1719

1820
from multidict import CIMultiDict
@@ -117,11 +119,35 @@ def __await__(self) -> Generator[Any, None, StreamResponse]:
117119
"""Execute the view handler."""
118120

119121

122+
class ResolveResult(TypedDict):
123+
"""Resolve result.
124+
125+
This is the result returned from an AbstractResolver's
126+
resolve method.
127+
128+
:param hostname: The hostname that was provided.
129+
:param host: The IP address that was resolved.
130+
:param port: The port that was resolved.
131+
:param family: The address family that was resolved.
132+
:param proto: The protocol that was resolved.
133+
:param flags: The flags that were resolved.
134+
"""
135+
136+
hostname: str
137+
host: str
138+
port: int
139+
family: int
140+
proto: int
141+
flags: int
142+
143+
120144
class AbstractResolver(ABC):
121145
"""Abstract DNS resolver."""
122146

123147
@abstractmethod
124-
async def resolve(self, host: str, port: int, family: int) -> List[Dict[str, Any]]:
148+
async def resolve(
149+
self, host: str, port: int = 0, family: int = socket.AF_INET
150+
) -> List[ResolveResult]:
125151
"""Return IP address for given hostname"""
126152

127153
@abstractmethod

aiohttp/connector.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
import aiohappyeyeballs
3636

3737
from . import hdrs, helpers
38-
from .abc import AbstractResolver
38+
from .abc import AbstractResolver, ResolveResult
3939
from .client_exceptions import (
4040
ClientConnectionError,
4141
ClientConnectorCertificateError,
@@ -674,14 +674,14 @@ async def _create_connection(
674674

675675
class _DNSCacheTable:
676676
def __init__(self, ttl: Optional[float] = None) -> None:
677-
self._addrs_rr: Dict[Tuple[str, int], Tuple[Iterator[Dict[str, Any]], int]] = {}
677+
self._addrs_rr: Dict[Tuple[str, int], Tuple[Iterator[ResolveResult], int]] = {}
678678
self._timestamps: Dict[Tuple[str, int], float] = {}
679679
self._ttl = ttl
680680

681681
def __contains__(self, host: object) -> bool:
682682
return host in self._addrs_rr
683683

684-
def add(self, key: Tuple[str, int], addrs: List[Dict[str, Any]]) -> None:
684+
def add(self, key: Tuple[str, int], addrs: List[ResolveResult]) -> None:
685685
self._addrs_rr[key] = (cycle(addrs), len(addrs))
686686

687687
if self._ttl is not None:
@@ -697,7 +697,7 @@ def clear(self) -> None:
697697
self._addrs_rr.clear()
698698
self._timestamps.clear()
699699

700-
def next_addrs(self, key: Tuple[str, int]) -> List[Dict[str, Any]]:
700+
def next_addrs(self, key: Tuple[str, int]) -> List[ResolveResult]:
701701
loop, length = self._addrs_rr[key]
702702
addrs = list(islice(loop, length))
703703
# Consume one more element to shift internal state of `cycle`
@@ -813,7 +813,7 @@ def clear_dns_cache(
813813

814814
async def _resolve_host(
815815
self, host: str, port: int, traces: Optional[List["Trace"]] = None
816-
) -> List[Dict[str, Any]]:
816+
) -> List[ResolveResult]:
817817
"""Resolve host and return list of addresses."""
818818
if is_ip_address(host):
819819
return [
@@ -868,7 +868,7 @@ async def _resolve_host(
868868
return await asyncio.shield(resolved_host_task)
869869
except asyncio.CancelledError:
870870

871-
def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
871+
def drop_exception(fut: "asyncio.Future[List[ResolveResult]]") -> None:
872872
with suppress(Exception, asyncio.CancelledError):
873873
fut.result()
874874

@@ -881,7 +881,7 @@ async def _resolve_host_with_throttle(
881881
host: str,
882882
port: int,
883883
traces: Optional[List["Trace"]],
884-
) -> List[Dict[str, Any]]:
884+
) -> List[ResolveResult]:
885885
"""Resolve host with a dns events throttle."""
886886
if key in self._throttle_dns_events:
887887
# get event early, before any await (#4014)
@@ -1129,7 +1129,7 @@ async def _start_tls_connection(
11291129
return tls_transport, tls_proto
11301130

11311131
def _convert_hosts_to_addr_infos(
1132-
self, hosts: List[Dict[str, Any]]
1132+
self, hosts: List[ResolveResult]
11331133
) -> List[aiohappyeyeballs.AddrInfoType]:
11341134
"""Converts the list of hosts to a list of addr_infos.
11351135

aiohttp/resolver.py

Lines changed: 62 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,25 @@
11
import asyncio
22
import socket
3-
from typing import Any, Dict, List, Type, Union
3+
import sys
4+
from typing import Any, List, Tuple, Type, Union
45

5-
from .abc import AbstractResolver
6+
from .abc import AbstractResolver, ResolveResult
67

78
__all__ = ("ThreadedResolver", "AsyncResolver", "DefaultResolver")
89

910
try:
1011
import aiodns
1112

12-
# aiodns_default = hasattr(aiodns.DNSResolver, 'gethostbyname')
13+
# aiodns_default = hasattr(aiodns.DNSResolver, 'getaddrinfo')
1314
except ImportError: # pragma: no cover
1415
aiodns = None
1516

17+
1618
aiodns_default = False
1719

20+
_NUMERIC_SOCKET_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV
21+
_SUPPORTS_SCOPE_ID = sys.version_info >= (3, 9, 0)
22+
1823

1924
class ThreadedResolver(AbstractResolver):
2025
"""Threaded resolver.
@@ -27,45 +32,45 @@ def __init__(self) -> None:
2732
self._loop = asyncio.get_running_loop()
2833

2934
async def resolve(
30-
self, hostname: str, port: int = 0, family: int = socket.AF_INET
31-
) -> List[Dict[str, Any]]:
35+
self, host: str, port: int = 0, family: int = socket.AF_INET
36+
) -> List[ResolveResult]:
3237
infos = await self._loop.getaddrinfo(
33-
hostname,
38+
host,
3439
port,
3540
type=socket.SOCK_STREAM,
3641
family=family,
3742
flags=socket.AI_ADDRCONFIG,
3843
)
3944

40-
hosts = []
45+
hosts: List[ResolveResult] = []
4146
for family, _, proto, _, address in infos:
4247
if family == socket.AF_INET6:
4348
if len(address) < 3:
4449
# IPv6 is not supported by Python build,
4550
# or IPv6 is not enabled in the host
4651
continue
47-
if address[3]:
52+
if address[3] and _SUPPORTS_SCOPE_ID:
4853
# This is essential for link-local IPv6 addresses.
4954
# LL IPv6 is a VERY rare case. Strictly speaking, we should use
5055
# getnameinfo() unconditionally, but performance makes sense.
51-
host, _port = socket.getnameinfo(
52-
address, socket.NI_NUMERICHOST | socket.NI_NUMERICSERV
56+
resolved_host, _port = await self._loop.getnameinfo(
57+
address, _NUMERIC_SOCKET_FLAGS
5358
)
5459
port = int(_port)
5560
else:
56-
host, port = address[:2]
61+
resolved_host, port = address[:2]
5762
else: # IPv4
5863
assert family == socket.AF_INET
59-
host, port = address # type: ignore[misc]
64+
resolved_host, port = address # type: ignore[misc]
6065
hosts.append(
61-
{
62-
"hostname": hostname,
63-
"host": host,
64-
"port": port,
65-
"family": family,
66-
"proto": proto,
67-
"flags": socket.AI_NUMERICHOST | socket.AI_NUMERICSERV,
68-
}
66+
ResolveResult(
67+
hostname=host,
68+
host=resolved_host,
69+
port=port,
70+
family=family,
71+
proto=proto,
72+
flags=_NUMERIC_SOCKET_FLAGS,
73+
)
6974
)
7075

7176
return hosts
@@ -86,23 +91,48 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
8691

8792
async def resolve(
8893
self, host: str, port: int = 0, family: int = socket.AF_INET
89-
) -> List[Dict[str, Any]]:
94+
) -> List[ResolveResult]:
9095
try:
91-
resp = await self._resolver.gethostbyname(host, family)
96+
resp = await self._resolver.getaddrinfo(
97+
host,
98+
port=port,
99+
type=socket.SOCK_STREAM,
100+
family=family,
101+
flags=socket.AI_ADDRCONFIG,
102+
)
92103
except aiodns.error.DNSError as exc:
93104
msg = exc.args[1] if len(exc.args) >= 1 else "DNS lookup failed"
94105
raise OSError(msg) from exc
95-
hosts = []
96-
for address in resp.addresses:
106+
hosts: List[ResolveResult] = []
107+
for node in resp.nodes:
108+
address: Union[Tuple[bytes, int], Tuple[bytes, int, int, int]] = node.addr
109+
family = node.family
110+
if family == socket.AF_INET6:
111+
if len(address) > 3 and address[3] and _SUPPORTS_SCOPE_ID:
112+
# This is essential for link-local IPv6 addresses.
113+
# LL IPv6 is a VERY rare case. Strictly speaking, we should use
114+
# getnameinfo() unconditionally, but performance makes sense.
115+
result = await self._resolver.getnameinfo(
116+
(address[0].decode("ascii"), *address[1:]),
117+
_NUMERIC_SOCKET_FLAGS,
118+
)
119+
resolved_host = result.node
120+
else:
121+
resolved_host = address[0].decode("ascii")
122+
port = address[1]
123+
else: # IPv4
124+
assert family == socket.AF_INET
125+
resolved_host = address[0].decode("ascii")
126+
port = address[1]
97127
hosts.append(
98-
{
99-
"hostname": host,
100-
"host": address,
101-
"port": port,
102-
"family": family,
103-
"proto": 0,
104-
"flags": socket.AI_NUMERICHOST | socket.AI_NUMERICSERV,
105-
}
128+
ResolveResult(
129+
hostname=host,
130+
host=resolved_host,
131+
port=port,
132+
family=family,
133+
proto=0,
134+
flags=_NUMERIC_SOCKET_FLAGS,
135+
)
106136
)
107137

108138
if not hosts:

docs/abc.rst

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,57 @@ Abstract Access Logger
181181
:param response: :class:`aiohttp.web.Response` object.
182182

183183
:param float time: Time taken to serve the request.
184+
185+
186+
Abstract Resolver
187+
-------------------------------
188+
189+
.. class:: AbstractResolver
190+
191+
An abstract class, base for all resolver implementations.
192+
193+
Method ``resolve`` should be overridden.
194+
195+
.. method:: resolve(host, port, family)
196+
197+
Resolve host name to IP address.
198+
199+
:param str host: host name to resolve.
200+
201+
:param int port: port number.
202+
203+
:param int family: socket family.
204+
205+
:return: list of :class:`aiohttp.abc.ResolveResult` instances.
206+
207+
.. method:: close()
208+
209+
Release resolver.
210+
211+
.. class:: ResolveResult
212+
213+
Result of host name resolution.
214+
215+
.. attribute:: hostname
216+
217+
The host name that was provided.
218+
219+
.. attribute:: host
220+
221+
The IP address that was resolved.
222+
223+
.. attribute:: port
224+
225+
The port that was resolved.
226+
227+
.. attribute:: family
228+
229+
The address family that was resolved.
230+
231+
.. attribute:: proto
232+
233+
The protocol that was resolved.
234+
235+
.. attribute:: flags
236+
237+
The flags that were resolved.

docs/conf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,8 @@
393393
("py:class", "aiohttp.protocol.HttpVersion"), # undocumented
394394
("py:class", "aiohttp.ClientRequest"), # undocumented
395395
("py:class", "aiohttp.payload.Payload"), # undocumented
396-
("py:class", "aiohttp.abc.AbstractResolver"), # undocumented
396+
("py:class", "aiohttp.resolver.AsyncResolver"), # undocumented
397+
("py:class", "aiohttp.resolver.ThreadedResolver"), # undocumented
397398
("py:func", "aiohttp.ws_connect"), # undocumented
398399
("py:meth", "start"), # undocumented
399400
("py:exc", "aiohttp.ClientHttpProxyError"), # undocumented

examples/fake_server.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import pathlib
44
import socket
55
import ssl
6-
from typing import Any, Dict, List, Union
6+
from typing import Dict, List, Union
77

88
from aiohttp import ClientSession, TCPConnector, resolver, test_utils, web
9-
from aiohttp.abc import AbstractResolver
9+
from aiohttp.abc import AbstractResolver, ResolveResult
1010

1111

1212
class FakeResolver(AbstractResolver):
@@ -22,7 +22,7 @@ async def resolve(
2222
host: str,
2323
port: int = 0,
2424
family: Union[socket.AddressFamily, int] = socket.AF_INET,
25-
) -> List[Dict[str, Any]]:
25+
) -> List[ResolveResult]:
2626
fake_port = self._fakes.get(host)
2727
if fake_port is not None:
2828
return [

requirements/runtime-deps.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Extracted from `setup.cfg` via `make sync-direct-runtime-deps`
22

3-
aiodns >= 1.1; sys_platform=="linux" or sys_platform=="darwin"
3+
aiodns >= 3.2.0; sys_platform=="linux" or sys_platform=="darwin"
44
aiohappyeyeballs >= 2.3.0
55
aiosignal >= 1.1.2
66
async-timeout >= 4.0, < 5.0 ; python_version < "3.11"

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ install_requires =
6464
[options.extras_require]
6565
speedups =
6666
# required c-ares (aiodns' backend) will not build on windows
67-
aiodns >= 1.1; sys_platform=="linux" or sys_platform=="darwin"
67+
aiodns >= 3.2.0; sys_platform=="linux" or sys_platform=="darwin"
6868
Brotli; platform_python_implementation == 'CPython'
6969
brotlicffi; platform_python_implementation != 'CPython'
7070

0 commit comments

Comments
 (0)