Skip to content

Commit b63d005

Browse files
authored
Don't swallow unexpected errors during Elasticsearch verification
1 parent 4be190e commit b63d005

File tree

4 files changed

+245
-139
lines changed

4 files changed

+245
-139
lines changed

elasticsearch/_async/transport.py

+67-70
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ async def _async_init(self):
120120

121121
# Set our 'verified_once' implementation to one that
122122
# works with 'asyncio' instead of 'threading'
123-
self._verified_once = Once()
123+
self._verify_elasticsearch_lock = asyncio.Lock()
124124

125125
# Now that we have a loop we can create all our HTTP connections...
126126
self.set_connections(self.hosts)
@@ -338,9 +338,7 @@ async def perform_request(self, method, url, headers=None, params=None, body=Non
338338

339339
# Before we make the actual API call we verify the Elasticsearch instance.
340340
if self._verified_elasticsearch is None:
341-
await self._verified_once.call(
342-
self._do_verify_elasticsearch, headers=headers, timeout=timeout
343-
)
341+
await self._do_verify_elasticsearch(headers=headers, timeout=timeout)
344342

345343
# If '_verified_elasticsearch' is False we know we're not connected to Elasticsearch.
346344
if self._verified_elasticsearch is False:
@@ -431,74 +429,73 @@ async def _do_verify_elasticsearch(self, headers, timeout):
431429
but we're also unable to rule it out due to a permission
432430
error we instead emit an 'ElasticsearchWarning'.
433431
"""
434-
# Product check has already been done, no need to do again.
435-
if self._verified_elasticsearch:
436-
return
437-
438-
headers = {header.lower(): value for header, value in (headers or {}).items()}
439-
# We know we definitely want JSON so request it via 'accept'
440-
headers.setdefault("accept", "application/json")
432+
# Ensure that there's only one coroutine within this section
433+
# at a time to not emit unnecessary index API calls.
434+
async with self._verify_elasticsearch_lock:
441435

442-
info_headers = {}
443-
info_response = {}
444-
error = None
445-
446-
for conn in chain(self.connection_pool.connections, self.seed_connections):
447-
try:
448-
_, info_headers, info_response = await conn.perform_request(
449-
"GET", "/", headers=headers, timeout=timeout
450-
)
451-
452-
# Lowercase all the header names for consistency in accessing them.
453-
info_headers = {
454-
header.lower(): value for header, value in info_headers.items()
455-
}
456-
457-
info_response = self.deserializer.loads(
458-
info_response, mimetype="application/json"
459-
)
460-
break
461-
462-
# Previous versions of 7.x Elasticsearch required a specific
463-
# permission so if we receive HTTP 401/403 we should warn
464-
# instead of erroring out.
465-
except (AuthenticationException, AuthorizationException):
466-
warnings.warn(
467-
(
468-
"The client is unable to verify that the server is "
469-
"Elasticsearch due security privileges on the server side"
470-
),
471-
ElasticsearchWarning,
472-
stacklevel=4,
473-
)
474-
self._verified_elasticsearch = True
436+
# Product check has already been completed while we were
437+
# waiting our turn, no need to do again.
438+
if self._verified_elasticsearch is not None:
475439
return
476440

477-
# This connection didn't work, we'll try another.
478-
except (ConnectionError, SerializationError) as err:
479-
if error is None:
480-
error = err
481-
482-
# If we received a connection error and weren't successful
483-
# anywhere then we reraise the more appropriate error.
484-
if error and not info_response:
485-
raise error
486-
487-
# Check the information we got back from the index request.
488-
self._verified_elasticsearch = _verify_elasticsearch(
489-
info_headers, info_response
490-
)
491-
492-
493-
class Once:
494-
"""Simple class which forces an async function to only execute once."""
441+
headers = {
442+
header.lower(): value for header, value in (headers or {}).items()
443+
}
444+
# We know we definitely want JSON so request it via 'accept'
445+
headers.setdefault("accept", "application/json")
446+
447+
info_headers = {}
448+
info_response = {}
449+
error = None
450+
451+
attempted_conns = []
452+
for conn in chain(self.connection_pool.connections, self.seed_connections):
453+
# Only attempt once per connection max.
454+
if conn in attempted_conns:
455+
continue
456+
attempted_conns.append(conn)
457+
458+
try:
459+
_, info_headers, info_response = await conn.perform_request(
460+
"GET", "/", headers=headers, timeout=timeout
461+
)
495462

496-
def __init__(self):
497-
self._lock = asyncio.Lock()
498-
self._called = False
463+
# Lowercase all the header names for consistency in accessing them.
464+
info_headers = {
465+
header.lower(): value for header, value in info_headers.items()
466+
}
499467

500-
async def call(self, func, *args, **kwargs):
501-
async with self._lock:
502-
if not self._called:
503-
self._called = True
504-
await func(*args, **kwargs)
468+
info_response = self.deserializer.loads(
469+
info_response, mimetype="application/json"
470+
)
471+
break
472+
473+
# Previous versions of 7.x Elasticsearch required a specific
474+
# permission so if we receive HTTP 401/403 we should warn
475+
# instead of erroring out.
476+
except (AuthenticationException, AuthorizationException):
477+
warnings.warn(
478+
(
479+
"The client is unable to verify that the server is "
480+
"Elasticsearch due security privileges on the server side"
481+
),
482+
ElasticsearchWarning,
483+
stacklevel=4,
484+
)
485+
self._verified_elasticsearch = True
486+
return
487+
488+
# This connection didn't work, we'll try another.
489+
except (ConnectionError, SerializationError, TransportError) as err:
490+
if error is None:
491+
error = err
492+
493+
# If we received a connection error and weren't successful
494+
# anywhere then we re-raise the more appropriate error.
495+
if error and not info_response:
496+
raise error
497+
498+
# Check the information we got back from the index request.
499+
self._verified_elasticsearch = _verify_elasticsearch(
500+
info_headers, info_response
501+
)

elasticsearch/transport.py

+64-67
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def __init__(
220220

221221
# Ensures that the ES verification request only fires once and that
222222
# all requests block until this request returns back.
223-
self._verified_once = Once()
223+
self._verify_elasticsearch_lock = Lock()
224224

225225
def add_connection(self, host):
226226
"""
@@ -406,9 +406,7 @@ def perform_request(self, method, url, headers=None, params=None, body=None):
406406

407407
# Before we make the actual API call we verify the Elasticsearch instance.
408408
if self._verified_elasticsearch is None:
409-
self._verified_once.call(
410-
self._do_verify_elasticsearch, headers=headers, timeout=timeout
411-
)
409+
self._do_verify_elasticsearch(headers=headers, timeout=timeout)
412410

413411
# If '_verified_elasticsearch' is False we know we're not connected to Elasticsearch.
414412
if self._verified_elasticsearch is False:
@@ -536,63 +534,76 @@ def _do_verify_elasticsearch(self, headers, timeout):
536534
but we're also unable to rule it out due to a permission
537535
error we instead emit an 'ElasticsearchWarning'.
538536
"""
539-
# Product check has already been done, no need to do again.
540-
if self._verified_elasticsearch is not None:
541-
return
537+
# Ensure that there's only one thread within this section
538+
# at a time to not emit unnecessary index API calls.
539+
with self._verify_elasticsearch_lock:
542540

543-
headers = {header.lower(): value for header, value in (headers or {}).items()}
544-
# We know we definitely want JSON so request it via 'accept'
545-
headers.setdefault("accept", "application/json")
541+
# Product check has already been completed while we were
542+
# waiting our turn, no need to do again.
543+
if self._verified_elasticsearch is not None:
544+
return
546545

547-
info_headers = {}
548-
info_response = {}
549-
error = None
546+
headers = {
547+
header.lower(): value for header, value in (headers or {}).items()
548+
}
549+
# We know we definitely want JSON so request it via 'accept'
550+
headers.setdefault("accept", "application/json")
550551

551-
for conn in chain(self.connection_pool.connections, self.seed_connections):
552-
try:
553-
_, info_headers, info_response = conn.perform_request(
554-
"GET", "/", headers=headers, timeout=timeout
555-
)
552+
info_headers = {}
553+
info_response = {}
554+
error = None
556555

557-
# Lowercase all the header names for consistency in accessing them.
558-
info_headers = {
559-
header.lower(): value for header, value in info_headers.items()
560-
}
556+
attempted_conns = []
557+
for conn in chain(self.connection_pool.connections, self.seed_connections):
558+
# Only attempt once per connection max.
559+
if conn in attempted_conns:
560+
continue
561+
attempted_conns.append(conn)
561562

562-
info_response = self.deserializer.loads(
563-
info_response, mimetype="application/json"
564-
)
565-
break
566-
567-
# Previous versions of 7.x Elasticsearch required a specific
568-
# permission so if we receive HTTP 401/403 we should warn
569-
# instead of erroring out.
570-
except (AuthenticationException, AuthorizationException):
571-
warnings.warn(
572-
(
573-
"The client is unable to verify that the server is "
574-
"Elasticsearch due security privileges on the server side"
575-
),
576-
ElasticsearchWarning,
577-
stacklevel=5,
578-
)
579-
self._verified_elasticsearch = True
580-
return
563+
try:
564+
_, info_headers, info_response = conn.perform_request(
565+
"GET", "/", headers=headers, timeout=timeout
566+
)
581567

582-
# This connection didn't work, we'll try another.
583-
except (ConnectionError, SerializationError) as err:
584-
if error is None:
585-
error = err
568+
# Lowercase all the header names for consistency in accessing them.
569+
info_headers = {
570+
header.lower(): value for header, value in info_headers.items()
571+
}
586572

587-
# If we received a connection error and weren't successful
588-
# anywhere then we reraise the more appropriate error.
589-
if error and not info_response:
590-
raise error
573+
info_response = self.deserializer.loads(
574+
info_response, mimetype="application/json"
575+
)
576+
break
591577

592-
# Check the information we got back from the index request.
593-
self._verified_elasticsearch = _verify_elasticsearch(
594-
info_headers, info_response
595-
)
578+
# Previous versions of 7.x Elasticsearch required a specific
579+
# permission so if we receive HTTP 401/403 we should warn
580+
# instead of erroring out.
581+
except (AuthenticationException, AuthorizationException):
582+
warnings.warn(
583+
(
584+
"The client is unable to verify that the server is "
585+
"Elasticsearch due security privileges on the server side"
586+
),
587+
ElasticsearchWarning,
588+
stacklevel=5,
589+
)
590+
self._verified_elasticsearch = True
591+
return
592+
593+
# This connection didn't work, we'll try another.
594+
except (ConnectionError, SerializationError, TransportError) as err:
595+
if error is None:
596+
error = err
597+
598+
# If we received a connection error and weren't successful
599+
# anywhere then we re-raise the more appropriate error.
600+
if error and not info_response:
601+
raise error
602+
603+
# Check the information we got back from the index request.
604+
self._verified_elasticsearch = _verify_elasticsearch(
605+
info_headers, info_response
606+
)
596607

597608

598609
def _verify_elasticsearch(headers, response):
@@ -640,17 +651,3 @@ def _verify_elasticsearch(headers, response):
640651
return False
641652

642653
return True
643-
644-
645-
class Once:
646-
"""Simple class which forces a function to only execute once."""
647-
648-
def __init__(self):
649-
self._lock = Lock()
650-
self._called = False
651-
652-
def call(self, func, *args, **kwargs):
653-
with self._lock:
654-
if not self._called:
655-
self._called = True
656-
func(*args, **kwargs)

test_elasticsearch/test_async/test_transport.py

+54-1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
ConnectionError,
3535
ElasticsearchWarning,
3636
NotElasticsearchError,
37+
NotFoundError,
3738
TransportError,
3839
)
3940

@@ -770,7 +771,9 @@ async def request_task():
770771
# The rest of the requests are 'GET /_search' afterwards
771772
assert all(call[0][:2] == ("GET", "/_search") for call in calls[1:])
772773

773-
async def test_multiple_requests_verify_elasticsearch_errors(self, event_loop):
774+
async def test_multiple_requests_verify_elasticsearch_product_error(
775+
self, event_loop
776+
):
774777
t = AsyncTransport(
775778
[
776779
{
@@ -823,3 +826,53 @@ async def request_task():
823826

824827
# The rest of the requests are 'GET /_search' afterwards
825828
assert all(call[0][:2] == ("GET", "/_search") for call in calls[1:])
829+
830+
@pytest.mark.parametrize("error_cls", [ConnectionError, NotFoundError])
831+
async def test_multiple_requests_verify_elasticsearch_retry_on_errors(
832+
self, event_loop, error_cls
833+
):
834+
t = AsyncTransport(
835+
[
836+
{
837+
"exception": error_cls(),
838+
"delay": 0.1,
839+
}
840+
],
841+
connection_class=DummyConnection,
842+
)
843+
844+
results = []
845+
completed_at = []
846+
847+
async def request_task():
848+
try:
849+
results.append(await t.perform_request("GET", "/_search"))
850+
except Exception as e:
851+
results.append(e)
852+
completed_at.append(event_loop.time())
853+
854+
# Execute a bunch of requests concurrently.
855+
tasks = []
856+
start_time = event_loop.time()
857+
for _ in range(5):
858+
tasks.append(event_loop.create_task(request_task()))
859+
await asyncio.gather(*tasks)
860+
end_time = event_loop.time()
861+
862+
# Exactly 5 results completed
863+
assert len(results) == 5
864+
865+
# All results were errors and not wrapped in 'NotElasticsearchError'
866+
assert all(isinstance(result, error_cls) for result in results)
867+
868+
# Assert that 5 requests were made in total (5 transport requests x 0.1s per connection request)
869+
duration = end_time - start_time
870+
assert 0.5 <= duration <= 0.6
871+
872+
# Assert that the cluster is still in the unknown/unverified stage.
873+
assert t._verified_elasticsearch is None
874+
875+
# See that the API isn't hit; instead, it's the index requests that are failing.
876+
calls = t.connection_pool.connections[0].calls
877+
assert len(calls) == 5
878+
assert all(call[0] == ("GET", "/") for call in calls)

0 commit comments

Comments
 (0)