Skip to content

Commit 801a839

Browse files
authored
Verify we're connected to Elasticsearch before requests
1 parent a2bacb2 commit 801a839

File tree

9 files changed

+929
-6
lines changed

9 files changed

+929
-6
lines changed

elasticsearch/_async/transport.py

+105-1
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,20 @@
1818
import asyncio
1919
import logging
2020
import sys
21+
import warnings
2122
from itertools import chain
2223

2324
from ..exceptions import (
25+
AuthenticationException,
26+
AuthorizationException,
2427
ConnectionError,
2528
ConnectionTimeout,
29+
ElasticsearchWarning,
30+
NotElasticsearchError,
2631
SerializationError,
2732
TransportError,
2833
)
29-
from ..transport import Transport
34+
from ..transport import Transport, _verify_elasticsearch
3035
from .compat import get_running_loop
3136
from .http_aiohttp import AIOHttpConnection
3237

@@ -113,6 +118,10 @@ async def _async_init(self):
113118
self.loop = get_running_loop()
114119
self.kwargs["loop"] = self.loop
115120

121+
# Set our 'verified_once' implementation to one that
122+
# works with 'asyncio' instead of 'threading'
123+
self._verified_once = Once()
124+
116125
# Now that we have a loop we can create all our HTTP connections...
117126
self.set_connections(self.hosts)
118127
self.seed_connections = list(self.connection_pool.connections[:])
@@ -327,6 +336,19 @@ async def perform_request(self, method, url, headers=None, params=None, body=Non
327336
method, headers, params, body
328337
)
329338

339+
# Before we make the actual API call we verify the Elasticsearch instance.
340+
if self._verified_elasticsearch is None:
341+
await self._verified_once.call(
342+
self._do_verify_elasticsearch, headers=headers, timeout=timeout
343+
)
344+
345+
# If '_verified_elasticsearch' is False we know we're not connected to Elasticsearch.
346+
if self._verified_elasticsearch is False:
347+
raise NotElasticsearchError(
348+
"The client noticed that the server is not Elasticsearch "
349+
"and we do not support this unknown product"
350+
)
351+
330352
for attempt in range(self.max_retries + 1):
331353
connection = self.get_connection()
332354

@@ -398,3 +420,85 @@ async def close(self):
398420

399421
for connection in self.connection_pool.connections:
400422
await connection.close()
423+
424+
async def _do_verify_elasticsearch(self, headers, timeout):
    """Verifies that we're connected to an Elasticsearch cluster.

    Sends 'GET /' to each known connection in turn until one of them
    answers, then checks the version along with other details of the
    response. The verdict is stored in 'self._verified_elasticsearch'
    ('True' or 'False'); nothing is returned.

    If we're unable to verify we're talking to Elasticsearch
    but we're also unable to rule it out due to a permission
    error we instead emit an 'ElasticsearchWarning' and record 'True'.

    'headers' are the headers of the triggering request (may be None)
    and are reused for the verification request; 'timeout' is passed
    straight through to the connection's 'perform_request'.

    If no connection produced a response, the first connection or
    serialization error that was seen is re-raised.
    """
    # Product check has already been done (either verdict), no need to
    # do it again. Compare against 'None' so that a 'False' verdict is
    # respected as well, matching the synchronous transport.
    if self._verified_elasticsearch is not None:
        return

    headers = {header.lower(): value for header, value in (headers or {}).items()}
    # We know we definitely want JSON so request it via 'accept'
    headers.setdefault("accept", "application/json")

    info_headers = {}
    info_response = {}
    error = None

    for conn in chain(self.connection_pool.connections, self.seed_connections):
        try:
            _, info_headers, info_response = await conn.perform_request(
                "GET", "/", headers=headers, timeout=timeout
            )

            # Lowercase all the header names for consistency in accessing them.
            info_headers = {
                header.lower(): value for header, value in info_headers.items()
            }

            info_response = self.deserializer.loads(
                info_response, mimetype="application/json"
            )
            break

        # Previous versions of 7.x Elasticsearch required a specific
        # permission so if we receive HTTP 401/403 we should warn
        # instead of erroring out.
        except (AuthenticationException, AuthorizationException):
            warnings.warn(
                (
                    "The client is unable to verify that the server is "
                    "Elasticsearch due security privileges on the server side"
                ),
                ElasticsearchWarning,
                stacklevel=4,
            )
            self._verified_elasticsearch = True
            return

        # This connection didn't work, we'll try another.
        # Only the first failure is remembered for re-raising below.
        except (ConnectionError, SerializationError) as err:
            if error is None:
                error = err

    # If we received a connection error and weren't successful
    # anywhere then we reraise the more appropriate error.
    if error and not info_response:
        raise error

    # Check the information we got back from the 'GET /' request.
    self._verified_elasticsearch = _verify_elasticsearch(
        info_headers, info_response
    )
491+
492+
493+
class Once:
    """Simple class which forces an async function to only complete once.

    Concurrent callers are serialized on an 'asyncio.Lock'. Once a call
    to the wrapped function has finished without raising, every later
    'call()' is a no-op. If the function raises, the exception propagates
    to that caller and the next caller gets to try again.
    """

    def __init__(self):
        self._lock = asyncio.Lock()
        self._called = False

    async def call(self, func, *args, **kwargs):
        """Await 'func(*args, **kwargs)' unless an earlier call succeeded."""
        async with self._lock:
            if self._called:
                return
            await func(*args, **kwargs)
            # Only mark as done after success: flagging before the await
            # would make a failed first attempt permanent and silently
            # skip the work for every later caller.
            self._called = True

elasticsearch/compat.py

+11
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,17 @@ def to_bytes(x, encoding="ascii"):
7070
except (ImportError, AttributeError):
7171
pass
7272

73+
try:
    from threading import Lock
except ImportError:  # Python <3.7 isn't guaranteed to have threading support.

    class Lock:
        """No-op stand-in used when the 'threading' module is unavailable.

        Only the context-manager protocol is provided, so it can be
        used in a 'with' statement without doing any locking.
        """

        def __enter__(self):
            return None

        def __exit__(self, *exc_info):
            return None
83+
7384

7485
__all__ = [
7586
"string_types",

elasticsearch/exceptions.py

+6
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ class SerializationError(ElasticsearchException):
5151
"""
5252

5353

54+
class NotElasticsearchError(ElasticsearchException):
    """Raised when the client determines that the server it is
    connected to is not an Elasticsearch cluster.
    """
58+
59+
5460
class TransportError(ElasticsearchException):
5561
"""
5662
Exception raised when ES returns a non-OK (>=400) HTTP status code. Or when

elasticsearch/exceptions.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ from typing import Any, Dict, Union
2020
class ImproperlyConfigured(Exception): ...
2121
class ElasticsearchException(Exception): ...
2222
class SerializationError(ElasticsearchException): ...
23+
class NotElasticsearchError(ElasticsearchException): ...
2324

2425
class TransportError(ElasticsearchException):
2526
@property

elasticsearch/transport.py

+166
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,23 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
import re
1819
import time
20+
import warnings
1921
from itertools import chain
2022
from platform import python_version
2123

2224
from ._version import __versionstr__
25+
from .compat import Lock
2326
from .connection import Urllib3HttpConnection
2427
from .connection_pool import ConnectionPool, DummyConnectionPool, EmptyConnectionPool
2528
from .exceptions import (
29+
AuthenticationException,
30+
AuthorizationException,
2631
ConnectionError,
2732
ConnectionTimeout,
33+
ElasticsearchWarning,
34+
NotElasticsearchError,
2835
SerializationError,
2936
TransportError,
3037
)
@@ -198,6 +205,23 @@ def __init__(
198205
if http_client_meta:
199206
self._client_meta += (http_client_meta,)
200207

208+
# Tri-state flag that describes the state of the verification
209+
# of whether we're connected to an Elasticsearch cluster or not.
210+
# The three states are:
211+
# - 'None': Means we've either not started the verification process
212+
# or that the verification is in progress. '_verified_once' ensures
213+
# that multiple requests don't kick off multiple verification processes.
214+
# - 'True': Means we've verified that we're talking to Elasticsearch or
215+
# that we can't rule out Elasticsearch due to auth issues. A warning
216+
# will be raised if we receive 401/403.
217+
# - 'False': Means we've discovered we're not talking to Elasticsearch,
218+
# should raise an error in this case for every request.
219+
self._verified_elasticsearch = None
220+
221+
# Ensures that the ES verification request only fires once and that
222+
# all requests block until this request returns back.
223+
self._verified_once = Once()
224+
201225
def add_connection(self, host):
202226
"""
203227
Create a new :class:`~elasticsearch.Connection` instance and add it to the pool.
@@ -380,6 +404,19 @@ def perform_request(self, method, url, headers=None, params=None, body=None):
380404
method, headers, params, body
381405
)
382406

407+
# Before we make the actual API call we verify the Elasticsearch instance.
408+
if self._verified_elasticsearch is None:
409+
self._verified_once.call(
410+
self._do_verify_elasticsearch, headers=headers, timeout=timeout
411+
)
412+
413+
# If '_verified_elasticsearch' is False we know we're not connected to Elasticsearch.
414+
if self._verified_elasticsearch is False:
415+
raise NotElasticsearchError(
416+
"The client noticed that the server is not Elasticsearch "
417+
"and we do not support this unknown product"
418+
)
419+
383420
for attempt in range(self.max_retries + 1):
384421
connection = self.get_connection()
385422

@@ -488,3 +525,132 @@ def _resolve_request_args(self, method, headers, params, body):
488525
)
489526

490527
return method, headers, params, body, ignore, timeout
528+
529+
def _do_verify_elasticsearch(self, headers, timeout):
    """Run the product check that confirms the server is Elasticsearch.

    Issues 'GET /' against each known connection in turn until one
    answers, then inspects the response headers and body. The verdict
    ('True' or 'False') is stored in 'self._verified_elasticsearch';
    nothing is returned.

    An HTTP 401/403 answer means we can neither confirm nor rule out
    Elasticsearch, so an 'ElasticsearchWarning' is emitted and the
    verdict is recorded as 'True'.

    If no connection produced a response, the first connection or
    serialization error that was seen is re-raised.
    """
    # Any non-None state means the check already ran.
    if self._verified_elasticsearch is not None:
        return

    # Normalise the caller's header names, then ask for JSON unless
    # the caller already supplied an 'accept' value.
    headers = {name.lower(): value for name, value in (headers or {}).items()}
    if "accept" not in headers:
        headers["accept"] = "application/json"

    info_headers, info_response = {}, {}
    first_error = None

    candidates = chain(self.connection_pool.connections, self.seed_connections)
    for connection in candidates:
        try:
            _, info_headers, info_response = connection.perform_request(
                "GET", "/", headers=headers, timeout=timeout
            )

            # Header names are lowercased so later lookups are uniform.
            info_headers = {
                name.lower(): value for name, value in info_headers.items()
            }

            info_response = self.deserializer.loads(
                info_response, mimetype="application/json"
            )
            break

        except (AuthenticationException, AuthorizationException):
            # Earlier 7.x releases required a specific permission for
            # this endpoint, so 401/403 results in a warning rather
            # than an error.
            message = (
                "The client is unable to verify that the server is "
                "Elasticsearch due security privileges on the server side"
            )
            warnings.warn(message, ElasticsearchWarning, stacklevel=5)
            self._verified_elasticsearch = True
            return

        except (ConnectionError, SerializationError) as err:
            # Remember only the first failure and move on to the
            # next connection.
            if first_error is None:
                first_error = err

    # Nothing usable came back from any connection: surface the
    # first failure instead of judging an empty response.
    if first_error and not info_response:
        raise first_error

    # Judge the headers and body we got back from 'GET /'.
    self._verified_elasticsearch = _verify_elasticsearch(
        info_headers, info_response
    )
)
596+
597+
598+
def _verify_elasticsearch(headers, response):
    """Decide whether an 'info' API answer came from Elasticsearch.

    'headers' are the response headers with lowercased names and
    'response' is the deserialized body of 'GET /'. Which fields are
    checked depends on the reported 'version.number'. Returns 'True'
    when every check for that version passes, 'False' otherwise.
    """
    try:
        version = response.get("version", {})
        match = re.search(r"^([0-9]+)\.([0-9]+)(?:\.([0-9]+))?", version["number"])
        # A missing patch component is treated as 999.
        version_number = tuple(
            999 if part is None else int(part) for part in match.groups()
        )
    except (KeyError, TypeError, ValueError, AttributeError):
        # No valid 'version.number' field, effectively 0.0.0
        version, version_number = {}, (0, 0, 0)

    # Evaluate each field/header; any lookup failure fails all three.
    try:
        tagline_ok = response.get("tagline", None) == "You Know, for Search"
        build_flavor_ok = version.get("build_flavor", None) == "default"
        product_header_ok = headers.get("x-elastic-product", None) == "Elasticsearch"
    except (AttributeError, TypeError):
        tagline_ok = build_flavor_ok = product_header_ok = False

    # No version or version less than 6.x
    if version_number < (6, 0, 0):
        return False

    # 6.x only needs a good 'tagline'
    if version_number < (7, 0, 0):
        return tagline_ok

    # 7.0-7.13 needs a good 'tagline' and 'build_flavor'
    if version_number < (7, 14, 0):
        return tagline_ok and build_flavor_ok

    # 7.14+ needs a good 'X-Elastic-Product' HTTP header
    return product_header_ok
643+
644+
645+
class Once:
    """Simple class which forces a function to only complete once.

    Concurrent callers are serialized on a lock. Once a call to the
    wrapped function has finished without raising, every later 'call()'
    is a no-op. If the function raises, the exception propagates to
    that caller and the next caller gets to try again.
    """

    def __init__(self):
        self._lock = Lock()
        self._called = False

    def call(self, func, *args, **kwargs):
        """Run 'func(*args, **kwargs)' unless an earlier call succeeded."""
        with self._lock:
            if self._called:
                return
            func(*args, **kwargs)
            # Only mark as done after success: flagging before the call
            # would make a failed first attempt permanent and silently
            # skip the work for every later caller.
            self._called = True

test_elasticsearch/test_async/test_connection.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from mock import patch
2929
from multidict import CIMultiDict
3030

31-
from elasticsearch import AIOHttpConnection, __versionstr__
31+
from elasticsearch import AIOHttpConnection, AsyncElasticsearch, __versionstr__
3232
from elasticsearch.compat import reraise_exceptions
3333
from elasticsearch.exceptions import ConnectionError
3434

@@ -410,3 +410,9 @@ async def test_aiohttp_connection_error(self):
410410
conn = AIOHttpConnection("not.a.host.name")
411411
with pytest.raises(ConnectionError):
412412
await conn.perform_request("GET", "/")
413+
414+
async def test_elasticsearch_connection_error(self):
    """A request through the client to an unresolvable host raises
    'ConnectionError'.
    """
    es = AsyncElasticsearch("http://not.a.host.name")

    try:
        with pytest.raises(ConnectionError):
            await es.search()
    finally:
        # Close the client so that whatever connections it opened for
        # the failed attempt are released before the test ends.
        await es.close()

0 commit comments

Comments
 (0)