Skip to content

Commit b2c05d4

Browse files
[7.x] AsyncTransport(sniff_on_start=True) should block further requests until complete
Co-authored-by: Seth Michael Larson <[email protected]>
1 parent 20d23a3 commit b2c05d4

File tree

3 files changed

+142
-5
lines changed

3 files changed

+142
-5
lines changed

elasticsearch/_async/transport.py

+51-5
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def __init__(self, hosts, *args, sniff_on_start=False, **kwargs):
8484
self.sniffing_task = None
8585
self.loop = None
8686
self._async_init_called = False
87+
self._sniff_on_start_event = None # type: asyncio.Event
8788

8889
super(AsyncTransport, self).__init__(
8990
*args, hosts=[], sniff_on_start=False, **kwargs
@@ -112,14 +113,35 @@ async def _async_init(self):
112113
self.loop = get_running_loop()
113114
self.kwargs["loop"] = self.loop
114115

115-
# Now that we have a loop we can create all our HTTP connections
116+
# Now that we have a loop we can create all our HTTP connections...
116117
self.set_connections(self.hosts)
117118
self.seed_connections = list(self.connection_pool.connections[:])
118119

119120
# ... and we can start sniffing in the background.
120121
if self.sniffing_task is None and self.sniff_on_start:
121-
self.last_sniff = self.loop.time()
122-
self.create_sniff_task(initial=True)
122+
123+
# Create an asyncio.Event for future calls to block on
124+
# until the initial sniffing task completes.
125+
self._sniff_on_start_event = asyncio.Event()
126+
127+
try:
128+
self.last_sniff = self.loop.time()
129+
self.create_sniff_task(initial=True)
130+
131+
# Since this is the first one we wait for it to complete
132+
# in case there's an error it'll get raised here.
133+
await self.sniffing_task
134+
135+
# If the task gets cancelled here it likely means the
136+
# transport got closed.
137+
except asyncio.CancelledError:
138+
pass
139+
140+
# Once we exit this section we want to unblock any _async_calls()
141+
# that are blocking on our initial sniff attempt regardless of it
142+
# was successful or not.
143+
finally:
144+
self._sniff_on_start_event.set()
123145

124146
async def _async_call(self):
125147
"""This method is called within any async method of AsyncTransport
@@ -130,6 +152,14 @@ async def _async_call(self):
130152
self._async_init_called = True
131153
await self._async_init()
132154

155+
# If the initial sniff_on_start hasn't returned yet
156+
# then we need to wait for node information to come back
157+
# or for the task to be cancelled via AsyncTransport.close()
158+
if self._sniff_on_start_event and not self._sniff_on_start_event.is_set():
159+
# This is already a no-op if the event is set but we try to
160+
# avoid an 'await' by checking 'not event.is_set()' above first.
161+
await self._sniff_on_start_event.wait()
162+
133163
if self.sniffer_timeout:
134164
if self.loop.time() >= self.last_sniff + self.sniffer_timeout:
135165
self.create_sniff_task()
@@ -187,6 +217,12 @@ def _sniff_request(conn):
187217
for t in done:
188218
try:
189219
_, headers, node_info = t.result()
220+
221+
# Lowercase all the header names for consistency in accessing them.
222+
headers = {
223+
header.lower(): value for header, value in headers.items()
224+
}
225+
190226
node_info = self.deserializer.loads(
191227
node_info, headers.get("content-type")
192228
)
@@ -212,6 +248,8 @@ async def sniff_hosts(self, initial=False):
212248
"""
213249
# Without a loop we can't do anything.
214250
if not self.loop:
251+
if initial:
252+
raise RuntimeError("Event loop not running on initial sniffing task")
215253
return
216254

217255
node_info = await self._get_sniff_data(initial)
@@ -293,7 +331,7 @@ async def perform_request(self, method, url, headers=None, params=None, body=Non
293331
connection = self.get_connection()
294332

295333
try:
296-
status, headers, data = await connection.perform_request(
334+
status, headers_response, data = await connection.perform_request(
297335
method,
298336
url,
299337
params,
@@ -302,6 +340,11 @@ async def perform_request(self, method, url, headers=None, params=None, body=Non
302340
ignore=ignore,
303341
timeout=timeout,
304342
)
343+
344+
# Lowercase all the header names for consistency in accessing them.
345+
headers_response = {
346+
header.lower(): value for header, value in headers_response.items()
347+
}
305348
except TransportError as e:
306349
if method == "HEAD" and e.status_code == 404:
307350
return False
@@ -336,7 +379,9 @@ async def perform_request(self, method, url, headers=None, params=None, body=Non
336379
return 200 <= status < 300
337380

338381
if data:
339-
data = self.deserializer.loads(data, headers.get("content-type"))
382+
data = self.deserializer.loads(
383+
data, headers_response.get("content-type")
384+
)
340385
return data
341386

342387
async def close(self):
@@ -350,5 +395,6 @@ async def close(self):
350395
except asyncio.CancelledError:
351396
pass
352397
self.sniffing_task = None
398+
353399
for connection in self.connection_pool.connections:
354400
await connection.close()

elasticsearch/transport.py

+11
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,12 @@ def _get_sniff_data(self, initial=False):
278278
"/_nodes/_all/http",
279279
timeout=self.sniff_timeout if not initial else None,
280280
)
281+
282+
# Lowercase all the header names for consistency in accessing them.
283+
headers = {
284+
header.lower(): value for header, value in headers.items()
285+
}
286+
281287
node_info = self.deserializer.loads(
282288
node_info, headers.get("content-type")
283289
)
@@ -388,6 +394,11 @@ def perform_request(self, method, url, headers=None, params=None, body=None):
388394
timeout=timeout,
389395
)
390396

397+
# Lowercase all the header names for consistency in accessing them.
398+
headers_response = {
399+
header.lower(): value for header, value in headers_response.items()
400+
}
401+
391402
except TransportError as e:
392403
if method == "HEAD" and e.status_code == 404:
393404
return False

test_elasticsearch/test_async/test_transport.py

+80
Original file line numberDiff line numberDiff line change
@@ -494,3 +494,83 @@ async def test_transport_close_closes_all_pool_connections(self):
494494
assert not any([conn.closed for conn in t.connection_pool.connections])
495495
await t.close()
496496
assert all([conn.closed for conn in t.connection_pool.connections])
497+
498+
async def test_sniff_on_start_error_if_no_sniffed_hosts(self, event_loop):
499+
t = AsyncTransport(
500+
[
501+
{"data": ""},
502+
{"data": ""},
503+
{"data": ""},
504+
],
505+
connection_class=DummyConnection,
506+
sniff_on_start=True,
507+
)
508+
509+
# If our initial sniffing attempt comes back
510+
# empty then we raise an error.
511+
with pytest.raises(TransportError) as e:
512+
await t._async_call()
513+
assert str(e.value) == "TransportError(N/A, 'Unable to sniff hosts.')"
514+
515+
async def test_sniff_on_start_waits_for_sniff_to_complete(self, event_loop):
516+
t = AsyncTransport(
517+
[
518+
{"delay": 1, "data": ""},
519+
{"delay": 1, "data": ""},
520+
{"delay": 1, "data": CLUSTER_NODES},
521+
],
522+
connection_class=DummyConnection,
523+
sniff_on_start=True,
524+
)
525+
526+
# Start the timer right before the first task
527+
# and have a bunch of tasks come in immediately.
528+
tasks = []
529+
start_time = event_loop.time()
530+
for _ in range(5):
531+
tasks.append(event_loop.create_task(t._async_call()))
532+
await asyncio.sleep(0) # Yield to the loop
533+
534+
assert t.sniffing_task is not None
535+
536+
# Tasks streaming in later.
537+
for _ in range(5):
538+
tasks.append(event_loop.create_task(t._async_call()))
539+
await asyncio.sleep(0.1)
540+
541+
# Now that all the API calls have come in we wait for
542+
# them all to resolve before
543+
await asyncio.gather(*tasks)
544+
end_time = event_loop.time()
545+
duration = end_time - start_time
546+
547+
# All the tasks blocked on the sniff of each node
548+
# and then resolved immediately after.
549+
assert 1 <= duration < 2
550+
551+
async def test_sniff_on_start_close_unlocks_async_calls(self, event_loop):
552+
t = AsyncTransport(
553+
[
554+
{"delay": 10, "data": CLUSTER_NODES},
555+
],
556+
connection_class=DummyConnection,
557+
sniff_on_start=True,
558+
)
559+
560+
# Start making _async_calls() before we cancel
561+
tasks = []
562+
start_time = event_loop.time()
563+
for _ in range(3):
564+
tasks.append(event_loop.create_task(t._async_call()))
565+
await asyncio.sleep(0)
566+
567+
# Close the transport while the sniffing task is active! :(
568+
await t.close()
569+
570+
# Now we start waiting on all those _async_calls()
571+
await asyncio.gather(*tasks)
572+
end_time = event_loop.time()
573+
duration = end_time - start_time
574+
575+
# A lot quicker than 10 seconds defined in 'delay'
576+
assert duration < 1

0 commit comments

Comments
 (0)