Commit cbdfc00

Add AsyncTransport

1 parent 2f32100 commit cbdfc00

5 files changed: +831 −38 lines changed

elasticsearch/__init__.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -72,7 +72,8 @@
         raise ImportError
 
     from ._async.http_aiohttp import AIOHttpConnection
+    from ._async.transport import AsyncTransport
 
-    __all__ += ["AIOHttpConnection"]
+    __all__ += ["AIOHttpConnection", "AsyncTransport"]
 except (ImportError, SyntaxError):
     pass
```
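
Since the new name is only exported when the async machinery imports cleanly (Python 3.6+ with `aiohttp` available), downstream code can feature-detect it. A minimal sketch, not part of this commit:

```python
# Hypothetical consumer code: feature-detect the optional async exports.
# If aiohttp is missing (or Python is too old), the names are absent from
# the `elasticsearch` package and this import raises ImportError.
try:
    from elasticsearch import AIOHttpConnection, AsyncTransport

    HAS_ASYNC = True
except ImportError:
    HAS_ASYNC = False

print("async transport available:", HAS_ASYNC)
```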

elasticsearch/_async/transport.py

Lines changed: 327 additions & 0 deletions

@@ -0,0 +1,327 @@

```python
# Licensed to Elasticsearch B.V under one or more agreements.
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
# See the LICENSE file in the project root for more information

import asyncio
import logging
from itertools import chain

from .compat import get_running_loop
from .http_aiohttp import AIOHttpConnection
from ..transport import Transport
from ..exceptions import (
    TransportError,
    ConnectionTimeout,
    ConnectionError,
    SerializationError,
)


logger = logging.getLogger("elasticsearch")


class AsyncTransport(Transport):
    """
    Encapsulation of transport-related logic. Handles instantiation of the
    individual connections as well as creating a connection pool to hold them.

    Main interface is the `perform_request` method.
    """

    DEFAULT_CONNECTION_CLASS = AIOHttpConnection

    def __init__(self, hosts, *args, sniff_on_start=False, **kwargs):
        """
        :arg hosts: list of dictionaries, each containing keyword arguments to
            create a `connection_class` instance
        :arg connection_class: subclass of :class:`~elasticsearch.Connection` to use
        :arg connection_pool_class: subclass of :class:`~elasticsearch.ConnectionPool` to use
        :arg host_info_callback: callback responsible for taking the node information from
            `/_cluster/nodes`, along with already extracted information, and
            producing a list of arguments (same as `hosts` parameter)
        :arg sniff_on_start: flag indicating whether to obtain a list of nodes
            from the cluster at startup time
        :arg sniffer_timeout: number of seconds between automatic sniffs
        :arg sniff_on_connection_fail: flag controlling if connection failure triggers a sniff
        :arg sniff_timeout: timeout used for the sniff request - it should be a
            fast api call and we are talking potentially to more nodes so we want
            to fail quickly. Not used during initial sniffing (if
            ``sniff_on_start`` is on) when the connection still isn't
            initialized.
        :arg serializer: serializer instance
        :arg serializers: optional dict of serializer instances that will be
            used for deserializing data coming from the server. (key is the mimetype)
        :arg default_mimetype: when no mimetype is specified by the server
            response assume this mimetype, defaults to `'application/json'`
        :arg max_retries: maximum number of retries before an exception is propagated
        :arg retry_on_status: set of HTTP status codes on which we should retry
            on a different node. defaults to ``(502, 503, 504)``
        :arg retry_on_timeout: should timeout trigger a retry on different
            node? (default `False`)
        :arg send_get_body_as: for GET requests with body this option allows
            you to specify an alternate way of execution for environments that
            don't support passing bodies with GET requests. If you set this to
            'POST' a POST method will be used instead, if to 'source' then the body
            will be serialized and passed as a query parameter `source`.

        Any extra keyword arguments will be passed to the `connection_class`
        when creating an instance unless overridden by that connection's
        options provided as part of the hosts parameter.
        """
        self.sniffing_task = None
        self.loop = None
        self._async_init_called = False

        super(AsyncTransport, self).__init__(
            *args, hosts=[], sniff_on_start=False, **kwargs
        )

        # Don't enable sniffing on Cloud instances.
        if kwargs.get("cloud_id", False):
            sniff_on_start = False

        # Since we defer connections / sniffing to not occur
        # within the constructor we never want to signal to
        # our parent to 'sniff_on_start' or non-empty 'hosts'.
        self.hosts = hosts
        self.sniff_on_start = sniff_on_start

    async def _async_init(self):
        """This is our stand-in for an async constructor. Everything
        that was deferred within __init__() should be done here now.

        This method will only be called once per AsyncTransport instance
        and is called from one of AsyncElasticsearch.__aenter__(),
        AsyncTransport.perform_request() or AsyncTransport.get_connection()
        """
        # Detect the async loop we're running in and set it
        # on all already created HTTP connections.
        self.loop = get_running_loop()
        self.kwargs["loop"] = self.loop

        # Now that we have a loop we can create all our HTTP connections
        self.set_connections(self.hosts)
        self.seed_connections = list(self.connection_pool.connections[:])

        # ... and we can start sniffing in the background.
        if self.sniffing_task is None and self.sniff_on_start:
            self.last_sniff = self.loop.time()
            self.create_sniff_task(initial=True)

    async def _async_call(self):
        """This method is called within any async method of AsyncTransport
        where the transport is not closing. This will check to see if we should
        call our _async_init() or create a new sniffing task
        """
        if not self._async_init_called:
            self._async_init_called = True
            await self._async_init()

        if self.sniffer_timeout:
            if self.loop.time() >= self.last_sniff + self.sniffer_timeout:
                self.create_sniff_task()

    async def _get_node_info(self, conn, initial):
        try:
            # use small timeout for the sniffing request, should be a fast api call
            _, headers, node_info = await conn.perform_request(
                "GET",
                "/_nodes/_all/http",
                timeout=self.sniff_timeout if not initial else None,
            )
            return self.deserializer.loads(node_info, headers.get("content-type"))
        except Exception:
            pass
        return None

    async def _get_sniff_data(self, initial=False):
        previous_sniff = self.last_sniff

        # reset last_sniff timestamp
        self.last_sniff = self.loop.time()

        # use small timeout for the sniffing request, should be a fast api call
        timeout = self.sniff_timeout if not initial else None

        def _sniff_request(conn):
            return self.loop.create_task(
                conn.perform_request("GET", "/_nodes/_all/http", timeout=timeout)
            )

        # Go through all current connections as well as the
        # seed_connections for good measure
        tasks = []
        for conn in self.connection_pool.connections:
            tasks.append(_sniff_request(conn))
        for conn in self.seed_connections:
            # Ensure that we don't have any duplication within seed_connections.
            if conn in self.connection_pool.connections:
                continue
            tasks.append(_sniff_request(conn))

        done = ()
        try:
            while tasks:
                # execute sniff requests in parallel, wait for first to return
                done, tasks = await asyncio.wait(
                    tasks, return_when=asyncio.FIRST_COMPLETED, loop=self.loop
                )
                # go through all the finished tasks
                for t in done:
                    try:
                        _, headers, node_info = t.result()
                        node_info = self.deserializer.loads(
                            node_info, headers.get("content-type")
                        )
                    except (ConnectionError, SerializationError):
                        continue
                    node_info = list(node_info["nodes"].values())
                    return node_info
            else:
                # no task has finished completely
                raise TransportError("N/A", "Unable to sniff hosts.")
        except Exception:
            # keep the previous value on error
            self.last_sniff = previous_sniff
            raise
        finally:
            # Cancel all the pending tasks
            for task in chain(done, tasks):
                task.cancel()

    async def sniff_hosts(self, initial=False):
        """Either spawns a sniffing_task which does regular sniffing
        over time or does a single sniffing session and awaits the results.
        """
        # Without a loop we can't do anything.
        if not self.loop:
            return

        node_info = await self._get_sniff_data(initial)
        hosts = list(filter(None, (self._get_host_info(n) for n in node_info)))

        # we weren't able to get any nodes, maybe using an incompatible
        # transport_schema or host_info_callback blocked all - raise error.
        if not hosts:
            raise TransportError(
                "N/A", "Unable to sniff hosts - no viable hosts found."
            )

        # remember current live connections
        orig_connections = self.connection_pool.connections[:]
        self.set_connections(hosts)
        # close those connections that are not in use any more
        for c in orig_connections:
            if c not in self.connection_pool.connections:
                await c.close()

    def create_sniff_task(self, initial=False):
        """
        Initiate a sniffing task. Make sure we only have one sniff request
        running at any given time. If a finished sniffing request is around,
        collect its result (which can raise its exception).
        """
        if self.sniffing_task and self.sniffing_task.done():
            try:
                if self.sniffing_task is not None:
                    self.sniffing_task.result()
            finally:
                self.sniffing_task = None

        if self.sniffing_task is None:
            self.sniffing_task = self.loop.create_task(self.sniff_hosts(initial))

    def mark_dead(self, connection):
        """
        Mark a connection as dead (failed) in the connection pool. If sniffing
        on failure is enabled this will initiate the sniffing process.

        :arg connection: instance of :class:`~elasticsearch.Connection` that failed
        """
        self.connection_pool.mark_dead(connection)
        if self.sniff_on_connection_fail:
            self.create_sniff_task()

    def get_connection(self):
        return self.connection_pool.get_connection()

    async def perform_request(self, method, url, headers=None, params=None, body=None):
        """
        Perform the actual request. Retrieve a connection from the connection
        pool, pass all the information to its perform_request method and
        return the data.

        If an exception was raised, mark the connection as failed and retry (up
        to `max_retries` times).

        If the operation was successful and the connection used was previously
        marked as dead, mark it as live, resetting its failure count.

        :arg method: HTTP method to use
        :arg url: absolute url (without host) to target
        :arg headers: dictionary of headers, will be handed over to the
            underlying :class:`~elasticsearch.Connection` class
        :arg params: dictionary of query parameters, will be handed over to the
            underlying :class:`~elasticsearch.Connection` class for serialization
        :arg body: body of the request, will be serialized using serializer and
            passed to the connection
        """
        await self._async_call()

        method, params, body, ignore, timeout = self._resolve_request_args(
            method, params, body
        )

        for attempt in range(self.max_retries + 1):
            connection = self.get_connection()

            try:
                status, headers_response, data = await connection.perform_request(
                    method,
                    url,
                    params,
                    body,
                    headers=headers,
                    ignore=ignore,
                    timeout=timeout,
                )
            except TransportError as e:
                if method == "HEAD" and e.status_code == 404:
                    return False

                retry = False
                if isinstance(e, ConnectionTimeout):
                    retry = self.retry_on_timeout
                elif isinstance(e, ConnectionError):
                    retry = True
                elif e.status_code in self.retry_on_status:
                    retry = True

                if retry:
                    # only mark as dead if we are retrying
                    self.mark_dead(connection)
                    # raise exception on last retry
                    if attempt == self.max_retries:
                        raise
                else:
                    raise

            else:
                if method == "HEAD":
                    return 200 <= status < 300

                # connection didn't fail, confirm its live status
                self.connection_pool.mark_live(connection)
                if data:
                    data = self.deserializer.loads(
                        data, headers_response.get("content-type")
                    )
                return data

    async def close(self):
        """
        Explicitly closes connections
        """
        if self.sniffing_task:
            self.sniffing_task.cancel()
            self.sniffing_task = None
        for connection in self.connection_pool.connections:
            await connection.close()
```
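
Because `__init__()` defers all connection creation, an `AsyncTransport` can be constructed outside of an event loop; the loop is only captured on the first awaited call, when `_async_call()` runs `_async_init()`. A rough usage sketch under assumed defaults (`aiohttp` installed, a node reachable at `localhost:9200`; the host and port are illustrative only):

```python
import asyncio

from elasticsearch import AsyncTransport


async def main():
    # Constructing the transport touches neither the network nor the loop;
    # the first perform_request() triggers _async_init(), which calls
    # get_running_loop() and builds the connection pool.
    transport = AsyncTransport([{"host": "localhost", "port": 9200}])
    try:
        # perform_request() returns the deserialized body (a dict for JSON).
        info = await transport.perform_request("GET", "/")
        print(info["version"]["number"])
    finally:
        # close() cancels any sniffing task and closes every pooled connection.
        await transport.close()


asyncio.run(main())
```

Retries behave as in the synchronous `Transport`: connection errors and the configured `retry_on_status` codes mark the node dead and retry on another connection, up to `max_retries` times.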

elasticsearch/connection_pool.py

Lines changed: 20 additions & 1 deletion

```diff
@@ -256,9 +256,12 @@ def close(self):
         """
         Explicitly closes connections
         """
-        for conn in self.orig_connections:
+        for conn in self.connections:
             conn.close()
 
+    def __repr__(self):
+        return "<%s: %r>" % (type(self).__name__, self.connections)
+
 
 class DummyConnectionPool(ConnectionPool):
     def __init__(self, connections, **kwargs):
```
```diff
@@ -284,3 +287,19 @@ def _noop(self, *args, **kwargs):
         pass
 
     mark_dead = mark_live = resurrect = _noop
+
+
+class EmptyConnectionPool(ConnectionPool):
+    """A connection pool that is empty. Errors out if used."""
+
+    def __init__(self, *_, **__):
+        self.connections = []
+        self.connection_opts = []
+
+    def get_connection(self):
+        raise ImproperlyConfigured("No connections were configured")
+
+    def _noop(self, *args, **kwargs):
+        pass
+
+    close = mark_dead = mark_live = resurrect = _noop
```
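
`EmptyConnectionPool` presumably serves as a safe placeholder pool for the deferred-init transport before any connections exist. A small sketch of its behavior (not part of this diff; the import paths are assumptions based on the files shown):

```python
from elasticsearch.connection_pool import EmptyConnectionPool
from elasticsearch.exceptions import ImproperlyConfigured

pool = EmptyConnectionPool()       # accepts and ignores any constructor arguments
print(repr(pool))                  # -> <EmptyConnectionPool: []> via the new __repr__

try:
    pool.get_connection()
except ImproperlyConfigured as e:
    print(e)                       # No connections were configured

pool.close()                       # a no-op, like mark_dead / mark_live / resurrect
```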
