Skip to content

Commit 0c117ab

Browse files
committed
Force verification to make at most one request
1 parent 214cc76 commit 0c117ab

File tree

6 files changed

+337
-27
lines changed

6 files changed

+337
-27
lines changed

elasticsearch/_async/transport.py

+32-5
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
ConnectionError,
2828
ConnectionTimeout,
2929
ElasticsearchWarning,
30+
NotElasticsearchError,
3031
SerializationError,
3132
TransportError,
3233
)
@@ -117,6 +118,10 @@ async def _async_init(self):
117118
self.loop = get_running_loop()
118119
self.kwargs["loop"] = self.loop
119120

121+
# Set our 'verified_once' implementation to one that
122+
# works with 'asyncio' instead of 'threading'
123+
self._verified_once = Once()
124+
120125
# Now that we have a loop we can create all our HTTP connections...
121126
self.set_connections(self.hosts)
122127
self.seed_connections = list(self.connection_pool.connections[:])
@@ -332,8 +337,17 @@ async def perform_request(self, method, url, headers=None, params=None, body=Non
332337
)
333338

334339
# Before we make the actual API call we verify the Elasticsearch instance.
335-
if not self._verified_elasticsearch:
336-
await self._do_verify_elasticsearch(headers=headers, timeout=timeout)
340+
if self._verified_elasticsearch is None:
341+
await self._verified_once.call(
342+
self._do_verify_elasticsearch, headers=headers, timeout=timeout
343+
)
344+
345+
# If '_verified_elasticsearch' is False we know we're not connected to Elasticsearch.
346+
if self._verified_elasticsearch is False:
347+
raise NotElasticsearchError(
348+
"The client noticed that the server is not Elasticsearch "
349+
"and we do not support this unknown product"
350+
)
337351

338352
for attempt in range(self.max_retries + 1):
339353
connection = self.get_connection()
@@ -471,7 +485,20 @@ async def _do_verify_elasticsearch(self, headers, timeout):
471485
raise error
472486

473487
# Check the information we got back from the index request.
474-
_verify_elasticsearch(info_headers, info_response)
488+
self._verified_elasticsearch = _verify_elasticsearch(
489+
info_headers, info_response
490+
)
491+
492+
493+
class Once:
494+
"""Simple class which forces an async function to only execute once."""
495+
496+
def __init__(self):
497+
self._lock = asyncio.Lock()
498+
self._called = False
475499

476-
# If we made it through the above call this config is verified.
477-
self._verified_elasticsearch = True
500+
async def call(self, func, *args, **kwargs):
501+
async with self._lock:
502+
if not self._called:
503+
self._called = True
504+
await func(*args, **kwargs)

elasticsearch/compat.py

+11
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,17 @@ def to_bytes(x, encoding="ascii"):
7070
except (ImportError, AttributeError):
7171
pass
7272

73+
try:
74+
from threading import Lock
75+
except ImportError: # Python <3.7 isn't guaranteed to have threading support.
76+
77+
class Lock:
78+
def __enter__(self):
79+
pass
80+
81+
def __exit__(self, *_):
82+
pass
83+
7384

7485
__all__ = [
7586
"string_types",

elasticsearch/compat.pyi

+3-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# under the License.
1717

1818
import sys
19-
from typing import Callable, Tuple, Type, Union
19+
from typing import Callable, ContextManager, Tuple, Type, Union
2020

2121
PY2: bool
2222
string_types: Tuple[type, ...]
@@ -25,6 +25,8 @@ to_str: Callable[[Union[str, bytes]], str]
2525
to_bytes: Callable[[Union[str, bytes]], bytes]
2626
reraise_exceptions: Tuple[Type[Exception], ...]
2727

28+
Lock: ContextManager[None]
29+
2830
if sys.version_info[0] == 2:
2931
from itertools import imap as map
3032
from urllib import quote as quote

elasticsearch/transport.py

+49-13
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from platform import python_version
2323

2424
from ._version import __versionstr__
25+
from .compat import Lock
2526
from .connection import Urllib3HttpConnection
2627
from .connection_pool import ConnectionPool, DummyConnectionPool, EmptyConnectionPool
2728
from .exceptions import (
@@ -204,9 +205,22 @@ def __init__(
204205
if http_client_meta:
205206
self._client_meta += (http_client_meta,)
206207

207-
# Flag which is set after verifying that we're
208-
# connected to Elasticsearch.
209-
self._verified_elasticsearch = False
208+
# Tri-state flag that describes what state the verification
209+
# of whether we're connected to an Elasticsearch cluster or not.
210+
# The three states are:
211+
# - 'None': Means we've either not started the verification process
212+
# or that the verification is in progress. '_verified_once' ensures
213+
# that multiple requests don't kick off multiple verification processes.
214+
# - 'True': Means we've verified that we're talking to Elasticsearch or
215+
# that we can't rule out Elasticsearch due to auth issues. A warning
216+
# will be raised if we receive 401/403.
217+
# - 'False': Means we've discovered we're not talking to Elasticsearch,
218+
# should raise an error in this case for every request.
219+
self._verified_elasticsearch = None
220+
221+
# Ensures that the ES verification request only fires once and that
222+
# all requests block until this request returns back.
223+
self._verified_once = Once()
210224

211225
def add_connection(self, host):
212226
"""
@@ -391,7 +405,17 @@ def perform_request(self, method, url, headers=None, params=None, body=None):
391405
)
392406

393407
# Before we make the actual API call we verify the Elasticsearch instance.
394-
self._do_verify_elasticsearch(headers=headers, timeout=timeout)
408+
if self._verified_elasticsearch is None:
409+
self._verified_once.call(
410+
self._do_verify_elasticsearch, headers=headers, timeout=timeout
411+
)
412+
413+
# If '_verified_elasticsearch' is False we know we're not connected to Elasticsearch.
414+
if self._verified_elasticsearch is False:
415+
raise NotElasticsearchError(
416+
"The client noticed that the server is not Elasticsearch "
417+
"and we do not support this unknown product"
418+
)
395419

396420
for attempt in range(self.max_retries + 1):
397421
connection = self.get_connection()
@@ -513,7 +537,7 @@ def _do_verify_elasticsearch(self, headers, timeout):
513537
error we instead emit an 'ElasticsearchWarning'.
514538
"""
515539
# Product check has already been done, no need to do again.
516-
if self._verified_elasticsearch:
540+
if self._verified_elasticsearch is not None:
517541
return
518542

519543
headers = {header.lower(): value for header, value in (headers or {}).items()}
@@ -566,10 +590,9 @@ def _do_verify_elasticsearch(self, headers, timeout):
566590
raise error
567591

568592
# Check the information we got back from the index request.
569-
_verify_elasticsearch(info_headers, info_response)
570-
571-
# If we made it through the above call this config is verified.
572-
self._verified_elasticsearch = True
593+
self._verified_elasticsearch = _verify_elasticsearch(
594+
info_headers, info_response
595+
)
573596

574597

575598
def _verify_elasticsearch(headers, response):
@@ -616,7 +639,20 @@ def _verify_elasticsearch(headers, response):
616639
# 7.14+ and there's a bad 'X-Elastic-Product' HTTP header
617640
or ((7, 14, 0) <= version_number and bad_product_header)
618641
):
619-
raise NotElasticsearchError(
620-
"The client noticed that the server is not Elasticsearch "
621-
"and we do not support this unknown product"
622-
)
642+
return False
643+
644+
return True
645+
646+
647+
class Once:
648+
"""Simple class which forces a function to only execute once."""
649+
650+
def __init__(self):
651+
self._lock = Lock()
652+
self._called = False
653+
654+
def call(self, func, *args, **kwargs):
655+
with self._lock:
656+
if not self._called:
657+
self._called = True
658+
func(*args, **kwargs)

test_elasticsearch/test_async/test_transport.py

+109
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
AuthorizationException,
3434
ConnectionError,
3535
ElasticsearchWarning,
36+
NotElasticsearchError,
3637
TransportError,
3738
)
3839

@@ -714,3 +715,111 @@ async def test_verify_elasticsearch_skips_on_auth_errors(self, exception_cls):
714715
},
715716
),
716717
]
718+
719+
async def test_multiple_requests_verify_elasticsearch_success(self, event_loop):
720+
t = AsyncTransport(
721+
[
722+
{
723+
"data": '{"version":{"number":"7.13.0","build_flavor":"default"},"tagline":"You Know, for Search"}',
724+
"delay": 1,
725+
}
726+
],
727+
connection_class=DummyConnection,
728+
)
729+
730+
results = []
731+
completed_at = []
732+
733+
async def request_task():
734+
try:
735+
results.append(await t.perform_request("GET", "/_search"))
736+
except Exception as e:
737+
results.append(e)
738+
completed_at.append(event_loop.time())
739+
740+
# Execute a bunch of requests concurrently.
741+
tasks = []
742+
start_time = event_loop.time()
743+
for _ in range(10):
744+
tasks.append(event_loop.create_task(request_task()))
745+
await asyncio.gather(*tasks)
746+
end_time = event_loop.time()
747+
748+
# Exactly 10 results completed
749+
assert len(results) == 10
750+
751+
# No errors in the results
752+
assert all(isinstance(result, dict) for result in results)
753+
754+
# Assert that this took longer than 2 seconds but less than 2.1 seconds
755+
duration = end_time - start_time
756+
assert 2 <= duration <= 2.1
757+
758+
# Assert that every result came after ~2 seconds, no fast completions.
759+
assert all(
760+
2 <= completed_time - start_time <= 2.1 for completed_time in completed_at
761+
)
762+
763+
# Assert that the cluster is "verified"
764+
assert t._verified_elasticsearch
765+
766+
# See that the first request is always 'GET /' for ES check
767+
calls = t.connection_pool.connections[0].calls
768+
assert calls[0][0] == ("GET", "/")
769+
770+
# The rest of the requests are 'GET /_search' afterwards
771+
assert all(call[0][:2] == ("GET", "/_search") for call in calls[1:])
772+
773+
async def test_multiple_requests_verify_elasticsearch_errors(self, event_loop):
774+
t = AsyncTransport(
775+
[
776+
{
777+
"data": '{"version":{"number":"7.13.0","build_flavor":"default"},"tagline":"BAD TAGLINE"}',
778+
"delay": 1,
779+
}
780+
],
781+
connection_class=DummyConnection,
782+
)
783+
784+
results = []
785+
completed_at = []
786+
787+
async def request_task():
788+
try:
789+
results.append(await t.perform_request("GET", "/_search"))
790+
except Exception as e:
791+
results.append(e)
792+
completed_at.append(event_loop.time())
793+
794+
# Execute a bunch of requests concurrently.
795+
tasks = []
796+
start_time = event_loop.time()
797+
for _ in range(10):
798+
tasks.append(event_loop.create_task(request_task()))
799+
await asyncio.gather(*tasks)
800+
end_time = event_loop.time()
801+
802+
# Exactly 10 results completed
803+
assert len(results) == 10
804+
805+
# All results were errors
806+
assert all(isinstance(result, NotElasticsearchError) for result in results)
807+
808+
# Assert that one request was made but not 2 requests.
809+
duration = end_time - start_time
810+
assert 1 <= duration <= 1.1
811+
812+
# Assert that every result came after ~1 seconds, no fast completions.
813+
assert all(
814+
1 <= completed_time - start_time <= 1.1 for completed_time in completed_at
815+
)
816+
817+
# Assert that the cluster is definitely not Elasticsearch
818+
assert t._verified_elasticsearch is False
819+
820+
# See that the first request is always 'GET /' for ES check
821+
calls = t.connection_pool.connections[0].calls
822+
assert calls[0][0] == ("GET", "/")
823+
824+
# The rest of the requests are 'GET /_search' afterwards
825+
assert all(call[0][:2] == ("GET", "/_search") for call in calls[1:])

0 commit comments

Comments
 (0)