Skip to content

Commit 89c5cb7

Browse files
committed
Force verification to make at most one request
1 parent 214cc76 commit 89c5cb7

File tree

5 files changed

+336
-30
lines changed

5 files changed

+336
-30
lines changed

elasticsearch/_async/transport.py

+32-5
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
ConnectionError,
2828
ConnectionTimeout,
2929
ElasticsearchWarning,
30+
NotElasticsearchError,
3031
SerializationError,
3132
TransportError,
3233
)
@@ -117,6 +118,10 @@ async def _async_init(self):
117118
self.loop = get_running_loop()
118119
self.kwargs["loop"] = self.loop
119120

121+
# Set our '_verified_once' implementation to one that
122+
# works with 'asyncio' instead of 'threading'
123+
self._verified_once = Once()
124+
120125
# Now that we have a loop we can create all our HTTP connections...
121126
self.set_connections(self.hosts)
122127
self.seed_connections = list(self.connection_pool.connections[:])
@@ -332,8 +337,17 @@ async def perform_request(self, method, url, headers=None, params=None, body=Non
332337
)
333338

334339
# Before we make the actual API call we verify the Elasticsearch instance.
335-
if not self._verified_elasticsearch:
336-
await self._do_verify_elasticsearch(headers=headers, timeout=timeout)
340+
if self._verified_elasticsearch is None:
341+
await self._verified_once.call(
342+
self._do_verify_elasticsearch, headers=headers, timeout=timeout
343+
)
344+
345+
# If '_verified_elasticsearch' is False we know we're not connected to Elasticsearch.
346+
if self._verified_elasticsearch is False:
347+
raise NotElasticsearchError(
348+
"The client noticed that the server is not Elasticsearch "
349+
"and we do not support this unknown product"
350+
)
337351

338352
for attempt in range(self.max_retries + 1):
339353
connection = self.get_connection()
@@ -471,7 +485,20 @@ async def _do_verify_elasticsearch(self, headers, timeout):
471485
raise error
472486

473487
# Check the information we got back from the index request.
474-
_verify_elasticsearch(info_headers, info_response)
488+
self._verified_elasticsearch = _verify_elasticsearch(
489+
info_headers, info_response
490+
)
491+
492+
493+
class Once:
494+
"""Simple class which forces an async function to only execute once."""
495+
496+
def __init__(self):
497+
self._lock = asyncio.Lock()
498+
self._called = False
475499

476-
# If we made it through the above call this config is verified.
477-
self._verified_elasticsearch = True
500+
async def call(self, func, *args, **kwargs):
501+
async with self._lock:
502+
if not self._called:
503+
self._called = True
504+
await func(*args, **kwargs)

elasticsearch/compat.py

+11
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,17 @@ def to_bytes(x, encoding="ascii"):
7070
except (ImportError, AttributeError):
7171
pass
7272

73+
try:
74+
from threading import Lock
75+
except ImportError: # Python <3.7 isn't guaranteed to have threading support.
76+
77+
class Lock:
78+
def __enter__(self):
79+
pass
80+
81+
def __exit__(self, *_):
82+
pass
83+
7384

7485
__all__ = [
7586
"string_types",

elasticsearch/transport.py

+51-17
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from platform import python_version
2323

2424
from ._version import __versionstr__
25+
from .compat import Lock
2526
from .connection import Urllib3HttpConnection
2627
from .connection_pool import ConnectionPool, DummyConnectionPool, EmptyConnectionPool
2728
from .exceptions import (
@@ -204,9 +205,22 @@ def __init__(
204205
if http_client_meta:
205206
self._client_meta += (http_client_meta,)
206207

207-
# Flag which is set after verifying that we're
208-
# connected to Elasticsearch.
209-
self._verified_elasticsearch = False
208+
# Tri-state flag that describes the state of the verification
209+
# of whether we're connected to an Elasticsearch cluster or not.
210+
# The three states are:
211+
# - 'None': Means we've either not started the verification process
212+
# or that the verification is in progress. '_verified_once' ensures
213+
# that multiple requests don't kick off multiple verification processes.
214+
# - 'True': Means we've verified that we're talking to Elasticsearch or
215+
# that we can't rule out Elasticsearch due to auth issues. A warning
216+
# will be raised if we receive 401/403.
217+
# - 'False': Means we've discovered we're not talking to Elasticsearch,
218+
# so we should raise an error for every request in this case.
219+
self._verified_elasticsearch = None
220+
221+
# Ensures that the ES verification request only fires once and that
222+
# all requests block until this request returns.
223+
self._verified_once = Once()
210224

211225
def add_connection(self, host):
212226
"""
@@ -391,7 +405,17 @@ def perform_request(self, method, url, headers=None, params=None, body=None):
391405
)
392406

393407
# Before we make the actual API call we verify the Elasticsearch instance.
394-
self._do_verify_elasticsearch(headers=headers, timeout=timeout)
408+
if self._verified_elasticsearch is None:
409+
self._verified_once.call(
410+
self._do_verify_elasticsearch, headers=headers, timeout=timeout
411+
)
412+
413+
# If '_verified_elasticsearch' is False we know we're not connected to Elasticsearch.
414+
if self._verified_elasticsearch is False:
415+
raise NotElasticsearchError(
416+
"The client noticed that the server is not Elasticsearch "
417+
"and we do not support this unknown product"
418+
)
395419

396420
for attempt in range(self.max_retries + 1):
397421
connection = self.get_connection()
@@ -513,7 +537,7 @@ def _do_verify_elasticsearch(self, headers, timeout):
513537
error we instead emit an 'ElasticsearchWarning'.
514538
"""
515539
# Product check has already been done, no need to do again.
516-
if self._verified_elasticsearch:
540+
if self._verified_elasticsearch is not None:
517541
return
518542

519543
headers = {header.lower(): value for header, value in (headers or {}).items()}
@@ -566,19 +590,16 @@ def _do_verify_elasticsearch(self, headers, timeout):
566590
raise error
567591

568592
# Check the information we got back from the index request.
569-
_verify_elasticsearch(info_headers, info_response)
570-
571-
# If we made it through the above call this config is verified.
572-
self._verified_elasticsearch = True
593+
self._verified_elasticsearch = _verify_elasticsearch(
594+
info_headers, info_response
595+
)
573596

574597

575598
def _verify_elasticsearch(headers, response):
576599
"""Verifies that the server we're talking to is Elasticsearch.
577600
Does this by checking HTTP headers and the deserialized
578-
response to the 'info' API.
579-
580-
If there's a problem this function raises 'NotElasticsearchError'
581-
otherwise doesn't do anything.
601+
response to the 'info' API. Returns 'True' if we're verified
602+
against Elasticsearch, 'False' otherwise.
582603
"""
583604
try:
584605
version = response.get("version", {})
@@ -616,7 +637,20 @@ def _verify_elasticsearch(headers, response):
616637
# 7.14+ and there's a bad 'X-Elastic-Product' HTTP header
617638
or ((7, 14, 0) <= version_number and bad_product_header)
618639
):
619-
raise NotElasticsearchError(
620-
"The client noticed that the server is not Elasticsearch "
621-
"and we do not support this unknown product"
622-
)
640+
return False
641+
642+
return True
643+
644+
645+
class Once:
646+
"""Simple class which forces a function to only execute once."""
647+
648+
def __init__(self):
649+
self._lock = Lock()
650+
self._called = False
651+
652+
def call(self, func, *args, **kwargs):
653+
with self._lock:
654+
if not self._called:
655+
self._called = True
656+
func(*args, **kwargs)

test_elasticsearch/test_async/test_transport.py

+109
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
AuthorizationException,
3434
ConnectionError,
3535
ElasticsearchWarning,
36+
NotElasticsearchError,
3637
TransportError,
3738
)
3839

@@ -714,3 +715,111 @@ async def test_verify_elasticsearch_skips_on_auth_errors(self, exception_cls):
714715
},
715716
),
716717
]
718+
719+
async def test_multiple_requests_verify_elasticsearch_success(self, event_loop):
720+
t = AsyncTransport(
721+
[
722+
{
723+
"data": '{"version":{"number":"7.13.0","build_flavor":"default"},"tagline":"You Know, for Search"}',
724+
"delay": 1,
725+
}
726+
],
727+
connection_class=DummyConnection,
728+
)
729+
730+
results = []
731+
completed_at = []
732+
733+
async def request_task():
734+
try:
735+
results.append(await t.perform_request("GET", "/_search"))
736+
except Exception as e:
737+
results.append(e)
738+
completed_at.append(event_loop.time())
739+
740+
# Execute a bunch of requests concurrently.
741+
tasks = []
742+
start_time = event_loop.time()
743+
for _ in range(10):
744+
tasks.append(event_loop.create_task(request_task()))
745+
await asyncio.gather(*tasks)
746+
end_time = event_loop.time()
747+
748+
# Exactly 10 results completed
749+
assert len(results) == 10
750+
751+
# No errors in the results
752+
assert all(isinstance(result, dict) for result in results)
753+
754+
# Assert that this took longer than 2 seconds but less than 2.1 seconds
755+
duration = end_time - start_time
756+
assert 2 <= duration <= 2.1
757+
758+
# Assert that every result came after ~2 seconds, no fast completions.
759+
assert all(
760+
2 <= completed_time - start_time <= 2.1 for completed_time in completed_at
761+
)
762+
763+
# Assert that the cluster is "verified"
764+
assert t._verified_elasticsearch
765+
766+
# See that the first request is always 'GET /' for ES check
767+
calls = t.connection_pool.connections[0].calls
768+
assert calls[0][0] == ("GET", "/")
769+
770+
# The rest of the requests are 'GET /_search' afterwards
771+
assert all(call[0][:2] == ("GET", "/_search") for call in calls[1:])
772+
773+
async def test_multiple_requests_verify_elasticsearch_errors(self, event_loop):
774+
t = AsyncTransport(
775+
[
776+
{
777+
"data": '{"version":{"number":"7.13.0","build_flavor":"default"},"tagline":"BAD TAGLINE"}',
778+
"delay": 1,
779+
}
780+
],
781+
connection_class=DummyConnection,
782+
)
783+
784+
results = []
785+
completed_at = []
786+
787+
async def request_task():
788+
try:
789+
results.append(await t.perform_request("GET", "/_search"))
790+
except Exception as e:
791+
results.append(e)
792+
completed_at.append(event_loop.time())
793+
794+
# Execute a bunch of requests concurrently.
795+
tasks = []
796+
start_time = event_loop.time()
797+
for _ in range(10):
798+
tasks.append(event_loop.create_task(request_task()))
799+
await asyncio.gather(*tasks)
800+
end_time = event_loop.time()
801+
802+
# Exactly 10 results completed
803+
assert len(results) == 10
804+
805+
# All results were errors
806+
assert all(isinstance(result, NotElasticsearchError) for result in results)
807+
808+
# Assert that one request was made but not 2 requests.
809+
duration = end_time - start_time
810+
assert 1 <= duration <= 1.1
811+
812+
# Assert that every result came after ~1 second, no fast completions.
813+
assert all(
814+
1 <= completed_time - start_time <= 1.1 for completed_time in completed_at
815+
)
816+
817+
# Assert that the cluster is definitely not Elasticsearch
818+
assert t._verified_elasticsearch is False
819+
820+
# See that the first request is always 'GET /' for ES check
821+
calls = t.connection_pool.connections[0].calls
822+
assert calls[0][0] == ("GET", "/")
823+
824+
# The rest of the requests are 'GET /_search' afterwards
825+
assert all(call[0][:2] == ("GET", "/_search") for call in calls[1:])

0 commit comments

Comments
 (0)