From 477530acff66f1d69adb3467db53d95b5fa25939 Mon Sep 17 00:00:00 2001 From: Seth Michael Larson Date: Mon, 28 Jun 2021 12:41:16 -0500 Subject: [PATCH 1/3] Verify connection to Elasticsearch --- elasticsearch/_async/transport.py | 74 ++++- elasticsearch/exceptions.py | 6 + elasticsearch/exceptions.pyi | 1 + elasticsearch/transport.py | 125 +++++++++ .../test_async/test_transport.py | 142 +++++++++- test_elasticsearch/test_transport.py | 253 +++++++++++++++++- 6 files changed, 597 insertions(+), 4 deletions(-) diff --git a/elasticsearch/_async/transport.py b/elasticsearch/_async/transport.py index 3fd637335..d193a2832 100644 --- a/elasticsearch/_async/transport.py +++ b/elasticsearch/_async/transport.py @@ -18,15 +18,19 @@ import asyncio import logging import sys +import warnings from itertools import chain from ..exceptions import ( + AuthenticationException, + AuthorizationException, ConnectionError, ConnectionTimeout, + ElasticsearchWarning, SerializationError, TransportError, ) -from ..transport import Transport +from ..transport import Transport, _verify_elasticsearch from .compat import get_running_loop from .http_aiohttp import AIOHttpConnection @@ -327,6 +331,10 @@ async def perform_request(self, method, url, headers=None, params=None, body=Non method, headers, params, body ) + # Before we make the actual API call we verify the Elasticsearch instance. + if not self._verified_elasticsearch: + await self._do_verify_elasticsearch(headers=headers, timeout=timeout) + for attempt in range(self.max_retries + 1): connection = self.get_connection() @@ -398,3 +406,67 @@ async def close(self): for connection in self.connection_pool.connections: await connection.close() + + async def _do_verify_elasticsearch(self, headers, timeout): + """Verifies that we're connected to an Elasticsearch cluster. + This is done at least once before the first actual API call + and makes a single request to the 'GET /' API endpoint and + check version along with other details of the response. + + If we're unable to verify we're talking to Elasticsearch + but we're also unable to rule it out due to a permission + error we instead emit an 'ElasticsearchWarning'. + """ + # Product check has already been done, no need to do again. + if self._verified_elasticsearch: + return + + headers = {header.lower(): value for header, value in (headers or {}).items()} + # We know we definitely want JSON so request it via 'accept' + headers.setdefault("accept", "application/json") + + info_headers = {} + info_response = {} + info_error = None + + for conn in chain(self.connection_pool.connections, self.seed_connections): + try: + _, info_headers, info_response = await conn.perform_request( + "GET", "/", headers=headers, timeout=timeout + ) + + # Lowercase all the header names for consistency in accessing them. + info_headers = { + header.lower(): value for header, value in info_headers.items() + } + + info_response = self.deserializer.loads( + info_response, mimetype="application/json" + ) + break + + # Previous versions of 7.x Elasticsearch required a specific + # permission so if we receive HTTP 401/403 we should warn + # instead of erroring out. + except (AuthenticationException, AuthorizationException): + warnings.warn( + ( + "The client is unable to verify that the server is " + "Elasticsearch due security privileges on the server side" + ), + ElasticsearchWarning, + stacklevel=3, + ) + self._verified_elasticsearch = True + return + + # This connection didn't work, we'll try another. 
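+ # (Only connection and serialization failures are swallowed by the
+ # 'except' below; 401/403 were handled above, and any other
+ # 'TransportError' from the probe propagates to the caller.)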
+ except (ConnectionError, SerializationError): + if info_error is None: + info_error = info_error + + # Check the information we got back from the index request. + _verify_elasticsearch(info_headers, info_response) + + # If we made it through the above call this config is verified. + self._verified_elasticsearch = True diff --git a/elasticsearch/exceptions.py b/elasticsearch/exceptions.py index 1a7a438dc..2b3c5ec2d 100644 --- a/elasticsearch/exceptions.py +++ b/elasticsearch/exceptions.py @@ -51,6 +51,12 @@ class SerializationError(ElasticsearchException): """ +class NotElasticsearchError(ElasticsearchException): + """Error which is raised when the client detects + it's not connected to an Elasticsearch cluster. + """ + + class TransportError(ElasticsearchException): """ Exception raised when ES returns a non-OK (>=400) HTTP status code. Or when diff --git a/elasticsearch/exceptions.pyi b/elasticsearch/exceptions.pyi index fd28fd80e..1dcab8c09 100644 --- a/elasticsearch/exceptions.pyi +++ b/elasticsearch/exceptions.pyi @@ -20,6 +20,7 @@ from typing import Any, Dict, Union class ImproperlyConfigured(Exception): ... class ElasticsearchException(Exception): ... class SerializationError(ElasticsearchException): ... +class NotElasticsearchError(ElasticsearchException): ... class TransportError(ElasticsearchException): @property diff --git a/elasticsearch/transport.py b/elasticsearch/transport.py index cf46e9b2d..c0f7ff8f2 100644 --- a/elasticsearch/transport.py +++ b/elasticsearch/transport.py @@ -15,7 +15,9 @@ # specific language governing permissions and limitations # under the License. +import re import time +import warnings from itertools import chain from platform import python_version @@ -23,8 +25,12 @@ from .connection import Urllib3HttpConnection from .connection_pool import ConnectionPool, DummyConnectionPool, EmptyConnectionPool from .exceptions import ( + AuthenticationException, + AuthorizationException, ConnectionError, ConnectionTimeout, + ElasticsearchWarning, + NotElasticsearchError, SerializationError, TransportError, ) @@ -198,6 +204,10 @@ def __init__( if http_client_meta: self._client_meta += (http_client_meta,) + # Flag which is set after verifying that we're + # connected to Elasticsearch. + self._verified_elasticsearch = False + def add_connection(self, host): """ Create a new :class:`~elasticsearch.Connection` instance and add it to the pool. @@ -380,6 +390,9 @@ def perform_request(self, method, url, headers=None, params=None, body=None): method, headers, params, body ) + # Before we make the actual API call we verify the Elasticsearch instance. + self._do_verify_elasticsearch(headers=headers, timeout=timeout) + for attempt in range(self.max_retries + 1): connection = self.get_connection() @@ -488,3 +501,115 @@ def _resolve_request_args(self, method, headers, params, body): ) return method, headers, params, body, ignore, timeout + + def _do_verify_elasticsearch(self, headers, timeout): + """Verifies that we're connected to an Elasticsearch cluster. + This is done at least once before the first actual API call + and makes a single request to the 'GET /' API endpoint to + check the version along with other details of the response. + + If we're unable to verify we're talking to Elasticsearch + but we're also unable to rule it out due to a permission + error we instead emit an 'ElasticsearchWarning'. + """ + # Product check has already been done, no need to do again. 
+ if self._verified_elasticsearch: + return + + headers = {header.lower(): value for header, value in (headers or {}).items()} + # We know we definitely want JSON so request it via 'accept' + headers.setdefault("accept", "application/json") + + info_headers = {} + info_response = {} + + for conn in chain(self.connection_pool.connections, self.seed_connections): + try: + _, info_headers, info_response = conn.perform_request( + "GET", "/", headers=headers, timeout=timeout + ) + + # Lowercase all the header names for consistency in accessing them. + info_headers = { + header.lower(): value for header, value in info_headers.items() + } + + info_response = self.deserializer.loads( + info_response, mimetype="application/json" + ) + break + + # Previous versions of 7.x Elasticsearch required a specific + # permission so if we receive HTTP 401/403 we should warn + # instead of erroring out. + except (AuthenticationException, AuthorizationException): + warnings.warn( + ( + "The client is unable to verify that the server is " + "Elasticsearch due security privileges on the server side" + ), + ElasticsearchWarning, + stacklevel=3, + ) + self._verified_elasticsearch = True + return + + # This connection didn't work, we'll try another. + except (ConnectionError, SerializationError): + pass + + # Check the information we got back from the index request. + _verify_elasticsearch(info_headers, info_response) + + # If we made it through the above call this config is verified. + self._verified_elasticsearch = True + + +def _verify_elasticsearch(headers, response): + """Verifies that the server we're talking to is Elasticsearch. + Does this by checking HTTP headers and the deserialized + response to the 'info' API. + + If there's a problem this function raises 'NotElasticsearchError' + otherwise doesn't do anything. + """ + try: + version = response.get("version", {}) + version_number = tuple( + int(x) if x is not None else 999 + for x in re.search( + r"^([0-9]+)\.([0-9]+)(?:\.([0-9]+))?", version["number"] + ).groups() + ) + except (KeyError, TypeError, ValueError, AttributeError): + # No valid 'version.number' field, effectively 0.0.0 + version = {} + version_number = (0, 0, 0) + + # Check all of the fields and headers for missing/valid values. 
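+ # 'response' isn't guaranteed to be a mapping at all (an unknown
+ # service can return any JSON type), so lookup failures below are
+ # caught and every field is treated as bad.
+ # Worked example (hypothetical values): a response of
+ #   {"version": {"number": "7.13.2", "build_flavor": "default"},
+ #    "tagline": "You Know, for Search"}
+ # parses to version_number == (7, 13, 2) above and then passes the
+ # 7.0-7.13 branch of the check below.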
+ try: + bad_tagline = response.get("tagline", None) != "You Know, for Search" + bad_build_flavor = version.get("build_flavor", None) != "default" + bad_product_header = headers.get("x-elastic-product", None) != "Elasticsearch" + except (AttributeError, TypeError): + bad_tagline = True + bad_build_flavor = True + bad_product_header = True + + if ( + # No version or version less than 6.x + version_number < (6, 0, 0) + # 6.x and there's a bad 'tagline' + or ((6, 0, 0) <= version_number < (7, 0, 0) and bad_tagline) + # 7.0-7.13 and there's a bad 'tagline' or 'build_flavor' + or ( + (7, 0, 0) <= version_number < (7, 14, 0) + and (bad_tagline or bad_build_flavor) + ) + # 7.14+ and there's a bad 'X-Elastic-Product' HTTP header + or ((7, 14, 0) <= version_number and bad_product_header) + ): + raise NotElasticsearchError( + "The client noticed that the server is not Elasticsearch " + "and we do not support this unknown product" + ) diff --git a/test_elasticsearch/test_async/test_transport.py b/test_elasticsearch/test_async/test_transport.py index 13bc492f8..c55012f54 100644 --- a/test_elasticsearch/test_async/test_transport.py +++ b/test_elasticsearch/test_async/test_transport.py @@ -28,7 +28,13 @@ from elasticsearch import AsyncTransport from elasticsearch.connection import Connection from elasticsearch.connection_pool import DummyConnectionPool -from elasticsearch.exceptions import ConnectionError, TransportError +from elasticsearch.exceptions import ( + AuthenticationException, + AuthorizationException, + ConnectionError, + ElasticsearchWarning, + TransportError, +) pytestmark = pytest.mark.asyncio @@ -121,6 +127,7 @@ async def test_single_connection_uses_dummy_connection_pool(self): async def test_request_timeout_extracted_from_params_and_passed(self): t = AsyncTransport([{}], connection_class=DummyConnection, meta_header=False) + t._verified_elasticsearch = True await t.perform_request("GET", "/", params={"request_timeout": 42}) assert 1 == len(t.get_connection().calls) @@ -135,6 +142,7 @@ async def test_opaque_id(self): t = AsyncTransport( [{}], opaque_id="app-1", connection_class=DummyConnection, meta_header=False ) + t._verified_elasticsearch = True await t.perform_request("GET", "/") assert 1 == len(t.get_connection().calls) @@ -157,6 +165,7 @@ async def test_opaque_id(self): async def test_request_with_custom_user_agent_header(self): t = AsyncTransport([{}], connection_class=DummyConnection, meta_header=False) + t._verified_elasticsearch = True await t.perform_request( "GET", "/", headers={"user-agent": "my-custom-value/1.2.3"} @@ -172,6 +181,7 @@ async def test_send_get_body_as_source(self): t = AsyncTransport( [{}], send_get_body_as="source", connection_class=DummyConnection ) + t._verified_elasticsearch = True await t.perform_request("GET", "/", body={}) assert 1 == len(t.get_connection().calls) @@ -181,6 +191,7 @@ async def test_send_get_body_as_post(self): t = AsyncTransport( [{}], send_get_body_as="POST", connection_class=DummyConnection ) + t._verified_elasticsearch = True await t.perform_request("GET", "/", body={}) assert 1 == len(t.get_connection().calls) @@ -188,6 +199,7 @@ async def test_send_get_body_as_post(self): async def test_client_meta_header(self): t = AsyncTransport([{}], connection_class=DummyConnection) + t._verified_elasticsearch = True await t.perform_request("GET", "/", body={}) assert len(t.get_connection().calls) == 1 @@ -201,6 +213,7 @@ class DummyConnectionWithMeta(DummyConnection): HTTP_CLIENT_META = ("dm", "1.2.3") t = AsyncTransport([{}], 
connection_class=DummyConnectionWithMeta) + t._verified_elasticsearch = True await t.perform_request("GET", "/", body={}, headers={"Custom": "header"}) assert len(t.get_connection().calls) == 1 @@ -213,6 +226,7 @@ class DummyConnectionWithMeta(DummyConnection): async def test_client_meta_header_not_sent(self): t = AsyncTransport([{}], meta_header=False, connection_class=DummyConnection) + t._verified_elasticsearch = True await t.perform_request("GET", "/", body={}) assert len(t.get_connection().calls) == 1 @@ -221,6 +235,7 @@ async def test_client_meta_header_not_sent(self): async def test_body_gets_encoded_into_bytes(self): t = AsyncTransport([{}], connection_class=DummyConnection) + t._verified_elasticsearch = True await t.perform_request("GET", "/", body="你好") assert 1 == len(t.get_connection().calls) @@ -233,6 +248,7 @@ async def test_body_gets_encoded_into_bytes(self): async def test_body_bytes_get_passed_untouched(self): t = AsyncTransport([{}], connection_class=DummyConnection) + t._verified_elasticsearch = True body = b"\xe4\xbd\xa0\xe5\xa5\xbd" await t.perform_request("GET", "/", body=body) @@ -241,6 +257,7 @@ async def test_body_bytes_get_passed_untouched(self): async def test_body_surrogates_replaced_encoded_into_bytes(self): t = AsyncTransport([{}], connection_class=DummyConnection) + t._verified_elasticsearch = True await t.perform_request("GET", "/", body="你好\uda6a") assert 1 == len(t.get_connection().calls) @@ -253,6 +270,8 @@ async def test_body_surrogates_replaced_encoded_into_bytes(self): async def test_kwargs_passed_on_to_connections(self): t = AsyncTransport([{"host": "google.com"}], port=123) + t._verified_elasticsearch = True + await t._async_call() assert 1 == len(t.connection_pool.connections) assert "http://google.com:123" == t.connection_pool.connections[0].host @@ -260,6 +279,8 @@ async def test_kwargs_passed_on_to_connections(self): async def test_kwargs_passed_on_to_connection_pool(self): dt = object() t = AsyncTransport([{}, {}], dead_timeout=dt) + t._verified_elasticsearch = True + await t._async_call() assert dt is t.connection_pool.dead_timeout @@ -269,12 +290,15 @@ def __init__(self, **kwargs): self.kwargs = kwargs t = AsyncTransport([{}], connection_class=MyConnection) + t._verified_elasticsearch = True + await t._async_call() assert 1 == len(t.connection_pool.connections) assert isinstance(t.connection_pool.connections[0], MyConnection) def test_add_connection(self): t = AsyncTransport([{}], randomize_hosts=False) + t._verified_elasticsearch = True t.add_connection({"host": "google.com", "port": 1234}) assert 2 == len(t.connection_pool.connections) @@ -285,6 +309,7 @@ async def test_request_will_fail_after_X_retries(self): [{"exception": ConnectionError("abandon ship")}], connection_class=DummyConnection, ) + t._verified_elasticsearch = True connection_error = False try: @@ -300,6 +325,7 @@ async def test_failed_connection_will_be_marked_as_dead(self): [{"exception": ConnectionError("abandon ship")}] * 2, connection_class=DummyConnection, ) + t._verified_elasticsearch = True connection_error = False try: @@ -313,6 +339,8 @@ async def test_failed_connection_will_be_marked_as_dead(self): async def test_resurrected_connection_will_be_marked_as_live_on_success(self): for method in ("GET", "HEAD"): t = AsyncTransport([{}, {}], connection_class=DummyConnection) + t._verified_elasticsearch = True + await t._async_call() con1 = t.connection_pool.get_connection() con2 = t.connection_pool.get_connection() @@ -338,6 +366,7 @@ async def 
test_sniff_on_start_fetches_and_uses_nodes_list(self): connection_class=DummyConnection, sniff_on_start=True, ) + await t._async_call() await t.sniffing_task # Need to wait for the sniffing task to complete @@ -351,6 +380,7 @@ async def test_sniff_on_start_ignores_sniff_timeout(self): sniff_on_start=True, sniff_timeout=12, ) + await t._async_call() await t.sniffing_task # Need to wait for the sniffing task to complete @@ -393,6 +423,8 @@ async def test_sniff_on_fail_triggers_sniffing_on_fail(self): max_retries=0, randomize_hosts=False, ) + t._verified_elasticsearch = True + await t._async_call() connection_error = False @@ -417,6 +449,7 @@ async def test_sniff_on_fail_failing_does_not_prevent_retires(self, sniff_hosts) max_retries=3, randomize_hosts=False, ) + t._verified_elasticsearch = True await t._async_init() conn_err, conn_data = t.connection_pool.connections @@ -432,6 +465,7 @@ async def test_sniff_after_n_seconds(self, event_loop): connection_class=DummyConnection, sniffer_timeout=5, ) + t._verified_elasticsearch = True await t._async_call() for _ in range(4): @@ -455,7 +489,9 @@ async def test_sniff_7x_publish_host(self): connection_class=DummyConnection, sniff_timeout=42, ) + t._verified_elasticsearch = True await t._async_call() + await t.sniff_hosts() # Ensure we parsed out the fqdn and port from the fqdn/ip:port string. assert t.connection_pool.connection_opts[0][1] == { @@ -472,6 +508,7 @@ async def test_sniffing_disabled_on_cloud_instances(self, sniff_hosts): connection_class=DummyConnection, cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==", ) + t._verified_elasticsearch = True await t._async_call() assert not t.sniff_on_connection_fail @@ -482,6 +519,7 @@ async def test_sniffing_disabled_on_cloud_instances(self, sniff_hosts): async def test_transport_close_closes_all_pool_connections(self): t = AsyncTransport([{}], connection_class=DummyConnection) + t._verified_elasticsearch = True await t._async_call() assert not any([conn.closed for conn in t.connection_pool.connections]) @@ -489,6 +527,7 @@ async def test_transport_close_closes_all_pool_connections(self): assert all([conn.closed for conn in t.connection_pool.connections]) t = AsyncTransport([{}, {}], connection_class=DummyConnection) + t._verified_elasticsearch = True await t._async_call() assert not any([conn.closed for conn in t.connection_pool.connections]) @@ -505,6 +544,7 @@ async def test_sniff_on_start_error_if_no_sniffed_hosts(self, event_loop): connection_class=DummyConnection, sniff_on_start=True, ) + t._verified_elasticsearch = True # If our initial sniffing attempt comes back # empty then we raise an error. @@ -522,6 +562,7 @@ async def test_sniff_on_start_waits_for_sniff_to_complete(self, event_loop): connection_class=DummyConnection, sniff_on_start=True, ) + t._verified_elasticsearch = True # Start the timer right before the first task # and have a bunch of tasks come in immediately. 
@@ -556,6 +597,7 @@ async def test_sniff_on_start_close_unlocks_async_calls(self, event_loop): connection_class=DummyConnection, sniff_on_start=True, ) + t._verified_elasticsearch = True # Start making _async_calls() before we cancel tasks = [] @@ -574,3 +616,101 @@ async def test_sniff_on_start_close_unlocks_async_calls(self, event_loop): # A lot quicker than 10 seconds defined in 'delay' assert duration < 1 + + @pytest.mark.parametrize( + ["headers", "data"], + [ + ( + {}, + '{"version":{"number":"6.99.0"},"tagline":"You Know, for Search"}', + ), + ( + {}, + '{"version":{"number":"7.13.0","build_flavor":"default"},"tagline":"You Know, for Search"}', + ), + ( + {"X-elastic-product": "Elasticsearch"}, + '{"version":{"number":"7.14.0","build_flavor":"default"},"tagline":"You Know, for Search"}', + ), + ], + ) + async def test_verify_elasticsearch(self, headers, data): + t = AsyncTransport( + [{"data": data, "headers": headers}], connection_class=DummyConnection + ) + await t.perform_request("GET", "/_search") + assert t._verified_elasticsearch + + calls = t.connection_pool.connections[0].calls + _ = [call[1]["headers"].pop("x-elastic-client-meta") for call in calls] + + assert calls == [ + ( + ("GET", "/"), + { + "headers": { + "accept": "application/json", + }, + "timeout": None, + }, + ), + ( + ("GET", "/_search", None, None), + { + "headers": {}, + "ignore": (), + "timeout": None, + }, + ), + ] + + @pytest.mark.parametrize( + "exception_cls", [AuthorizationException, AuthenticationException] + ) + async def test_verify_elasticsearch_skips_on_auth_errors(self, exception_cls): + t = AsyncTransport( + [{"exception": exception_cls(exception_cls.status_code)}], + connection_class=DummyConnection, + ) + + with pytest.warns(ElasticsearchWarning) as warns: + with pytest.raises(exception_cls): + await t.perform_request( + "GET", "/_search", headers={"Authorization": "testme"} + ) + + # Assert that a warning was raised due to security privileges + assert [str(w.message) for w in warns] == [ + "The client is unable to verify that the server is " + "Elasticsearch due security privileges on the server side" + ] + + # Assert that the cluster is "verified" + assert t._verified_elasticsearch + + # See that the headers were passed along to the "info" request made + calls = t.connection_pool.connections[0].calls + _ = [call[1]["headers"].pop("x-elastic-client-meta") for call in calls] + + assert calls == [ + ( + ("GET", "/"), + { + "headers": { + "accept": "application/json", + "authorization": "testme", + }, + "timeout": None, + }, + ), + ( + ("GET", "/_search", None, None), + { + "headers": { + "Authorization": "testme", + }, + "ignore": (), + "timeout": None, + }, + ), + ] diff --git a/test_elasticsearch/test_transport.py b/test_elasticsearch/test_transport.py index c5cd3f0c0..6043222ba 100644 --- a/test_elasticsearch/test_transport.py +++ b/test_elasticsearch/test_transport.py @@ -26,8 +26,17 @@ from elasticsearch.connection import Connection from elasticsearch.connection_pool import DummyConnectionPool -from elasticsearch.exceptions import ConnectionError, TransportError -from elasticsearch.transport import Transport, get_host_info +from elasticsearch.exceptions import ( + AuthenticationException, + AuthorizationException, + ConnectionError, + ElasticsearchWarning, + NotElasticsearchError, + TransportError, +) +from elasticsearch.transport import Transport +from elasticsearch.transport import _verify_elasticsearch as verify_elasticsearch +from elasticsearch.transport import get_host_info from 
.test_cases import TestCase @@ -118,12 +127,16 @@ def test_master_only_nodes_are_ignored(self): class TestTransport(TestCase): def test_single_connection_uses_dummy_connection_pool(self): t = Transport([{}]) + t._verified_elasticsearch = True self.assertIsInstance(t.connection_pool, DummyConnectionPool) + t = Transport([{"host": "localhost"}]) + t._verified_elasticsearch = True self.assertIsInstance(t.connection_pool, DummyConnectionPool) def test_request_timeout_extracted_from_params_and_passed(self): t = Transport([{}], meta_header=False, connection_class=DummyConnection) + t._verified_elasticsearch = True t.perform_request("GET", "/", params={"request_timeout": 42}) self.assertEqual(1, len(t.get_connection().calls)) @@ -137,6 +150,7 @@ def test_opaque_id(self): t = Transport( [{}], opaque_id="app-1", meta_header=False, connection_class=DummyConnection ) + t._verified_elasticsearch = True t.perform_request("GET", "/") self.assertEqual(1, len(t.get_connection().calls)) @@ -157,6 +171,7 @@ def test_opaque_id(self): def test_request_with_custom_user_agent_header(self): t = Transport([{}], meta_header=False, connection_class=DummyConnection) + t._verified_elasticsearch = True t.perform_request("GET", "/", headers={"user-agent": "my-custom-value/1.2.3"}) self.assertEqual(1, len(t.get_connection().calls)) @@ -171,6 +186,7 @@ def test_request_with_custom_user_agent_header(self): def test_send_get_body_as_source(self): t = Transport([{}], send_get_body_as="source", connection_class=DummyConnection) + t._verified_elasticsearch = True t.perform_request("GET", "/", body={}) self.assertEqual(1, len(t.get_connection().calls)) @@ -180,6 +196,7 @@ def test_send_get_body_as_source(self): def test_send_get_body_as_post(self): t = Transport([{}], send_get_body_as="POST", connection_class=DummyConnection) + t._verified_elasticsearch = True t.perform_request("GET", "/", body={}) self.assertEqual(1, len(t.get_connection().calls)) @@ -187,6 +204,7 @@ def test_send_get_body_as_post(self): def test_client_meta_header(self): t = Transport([{}], connection_class=DummyConnection) + t._verified_elasticsearch = True t.perform_request("GET", "/", body={}) self.assertEqual(1, len(t.get_connection().calls)) @@ -199,6 +217,7 @@ class DummyConnectionWithMeta(DummyConnection): HTTP_CLIENT_META = ("dm", "1.2.3") t = Transport([{}], connection_class=DummyConnectionWithMeta) + t._verified_elasticsearch = True t.perform_request("GET", "/", body={}, headers={"Custom": "header"}) self.assertEqual(1, len(t.get_connection().calls)) @@ -211,6 +230,7 @@ class DummyConnectionWithMeta(DummyConnection): def test_client_meta_header_not_sent(self): t = Transport([{}], meta_header=False, connection_class=DummyConnection) + t._verified_elasticsearch = True t.perform_request("GET", "/", body={}) self.assertEqual(1, len(t.get_connection().calls)) @@ -224,6 +244,7 @@ def test_meta_header_type_error(self): def test_body_gets_encoded_into_bytes(self): t = Transport([{}], connection_class=DummyConnection) + t._verified_elasticsearch = True t.perform_request("GET", "/", body="你好") self.assertEqual(1, len(t.get_connection().calls)) @@ -234,6 +255,7 @@ def test_body_gets_encoded_into_bytes(self): def test_body_bytes_get_passed_untouched(self): t = Transport([{}], connection_class=DummyConnection) + t._verified_elasticsearch = True body = b"\xe4\xbd\xa0\xe5\xa5\xbd" t.perform_request("GET", "/", body=body) @@ -242,6 +264,7 @@ def test_body_bytes_get_passed_untouched(self): def test_body_surrogates_replaced_encoded_into_bytes(self): t = 
Transport([{}], connection_class=DummyConnection) + t._verified_elasticsearch = True t.perform_request("GET", "/", body="你好\uda6a") self.assertEqual(1, len(t.get_connection().calls)) @@ -252,12 +275,14 @@ def test_body_surrogates_replaced_encoded_into_bytes(self): def test_kwargs_passed_on_to_connections(self): t = Transport([{"host": "google.com"}], port=123) + t._verified_elasticsearch = True self.assertEqual(1, len(t.connection_pool.connections)) self.assertEqual("http://google.com:123", t.connection_pool.connections[0].host) def test_kwargs_passed_on_to_connection_pool(self): dt = object() t = Transport([{}, {}], dead_timeout=dt) + t._verified_elasticsearch = True self.assertIs(dt, t.connection_pool.dead_timeout) def test_custom_connection_class(self): @@ -266,11 +291,13 @@ def __init__(self, **kwargs): self.kwargs = kwargs t = Transport([{}], connection_class=MyConnection) + t._verified_elasticsearch = True self.assertEqual(1, len(t.connection_pool.connections)) self.assertIsInstance(t.connection_pool.connections[0], MyConnection) def test_add_connection(self): t = Transport([{}], randomize_hosts=False) + t._verified_elasticsearch = True t.add_connection({"host": "google.com", "port": 1234}) self.assertEqual(2, len(t.connection_pool.connections)) @@ -283,6 +310,7 @@ def test_request_will_fail_after_X_retries(self): [{"exception": ConnectionError("abandon ship")}], connection_class=DummyConnection, ) + t._verified_elasticsearch = True self.assertRaises(ConnectionError, t.perform_request, "GET", "/") self.assertEqual(4, len(t.get_connection().calls)) @@ -292,6 +320,7 @@ def test_failed_connection_will_be_marked_as_dead(self): [{"exception": ConnectionError("abandon ship")}] * 2, connection_class=DummyConnection, ) + t._verified_elasticsearch = True self.assertRaises(ConnectionError, t.perform_request, "GET", "/") self.assertEqual(0, len(t.connection_pool.connections)) @@ -299,6 +328,7 @@ def test_failed_connection_will_be_marked_as_dead(self): def test_resurrected_connection_will_be_marked_as_live_on_success(self): for method in ("GET", "HEAD"): t = Transport([{}, {}], connection_class=DummyConnection) + t._verified_elasticsearch = True con1 = t.connection_pool.get_connection() con2 = t.connection_pool.get_connection() t.connection_pool.mark_dead(con1) @@ -310,6 +340,7 @@ def test_resurrected_connection_will_be_marked_as_live_on_success(self): def test_sniff_will_use_seed_connections(self): t = Transport([{"data": CLUSTER_NODES}], connection_class=DummyConnection) + t._verified_elasticsearch = True t.set_connections([{"data": "invalid"}]) t.sniff_hosts() @@ -322,6 +353,7 @@ def test_sniff_on_start_fetches_and_uses_nodes_list(self): connection_class=DummyConnection, sniff_on_start=True, ) + t._verified_elasticsearch = True self.assertEqual(1, len(t.connection_pool.connections)) self.assertEqual("http://1.1.1.1:123", t.get_connection().host) @@ -332,6 +364,7 @@ def test_sniff_on_start_ignores_sniff_timeout(self): sniff_on_start=True, sniff_timeout=12, ) + t._verified_elasticsearch = True self.assertEqual( (("GET", "/_nodes/_all/http"), {"timeout": None}), t.seed_connections[0].calls[0], @@ -343,6 +376,7 @@ def test_sniff_uses_sniff_timeout(self): connection_class=DummyConnection, sniff_timeout=42, ) + t._verified_elasticsearch = True t.sniff_hosts() self.assertEqual( (("GET", "/_nodes/_all/http"), {"timeout": 42}), @@ -355,6 +389,7 @@ def test_sniff_reuses_connection_instances_if_possible(self): connection_class=DummyConnection, randomize_hosts=False, ) + t._verified_elasticsearch = True 
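+ # (As throughout these tests, pre-setting the flag skips the
+ # 'GET /' probe so it doesn't skew the recorded connection calls.)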
connection = t.connection_pool.connections[1] t.sniff_hosts() @@ -369,6 +404,7 @@ def test_sniff_on_fail_triggers_sniffing_on_fail(self): max_retries=0, randomize_hosts=False, ) + t._verified_elasticsearch = True self.assertRaises(ConnectionError, t.perform_request, "GET", "/") self.assertEqual(1, len(t.connection_pool.connections)) @@ -384,6 +420,7 @@ def test_sniff_on_fail_failing_does_not_prevent_retires(self, sniff_hosts): max_retries=3, randomize_hosts=False, ) + t._verified_elasticsearch = True conn_err, conn_data = t.connection_pool.connections response = t.perform_request("GET", "/") @@ -398,6 +435,7 @@ def test_sniff_after_n_seconds(self): connection_class=DummyConnection, sniffer_timeout=5, ) + t._verified_elasticsearch = True for _ in range(4): t.perform_request("GET", "/") @@ -418,6 +456,7 @@ def test_sniff_7x_publish_host(self): connection_class=DummyConnection, sniff_timeout=42, ) + t._verified_elasticsearch = True t.sniff_hosts() # Ensure we parsed out the fqdn and port from the fqdn/ip:port string. self.assertEqual( @@ -433,6 +472,216 @@ def test_sniffing_disabled_on_cloud_instances(self, sniff_hosts): sniff_on_connection_fail=True, cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==", ) + t._verified_elasticsearch = True self.assertFalse(t.sniff_on_connection_fail) self.assertIs(sniff_hosts.call_args, None) # Assert not called. + + +TAGLINE = "You Know, for Search" + + +@pytest.mark.parametrize( + ["headers", "response"], + [ + # All empty. + ({}, {}), + # Don't check the product header immediately, need to check version first. + ({"x-elastic-product": "Elasticsearch"}, {}), + # Version not there. + ({}, {"tagline": TAGLINE}), + # Version is nonsense + ({}, {"version": "1.0.0", "tagline": TAGLINE}), + # Version number not there + ({}, {"version": {}, "tagline": TAGLINE}), + # Version number is nonsense + ({}, {"version": {"number": "nonsense"}, "tagline": TAGLINE}), + # Version number way in the past + ({}, {"version": {"number": "1.0.0"}, "tagline": TAGLINE}), + # Version number way in the future + ({}, {"version": {"number": "999.0.0"}, "tagline": TAGLINE}), + # Build flavor not supposed to be missing + ({}, {"version": {"number": "7.13.0"}, "tagline": TAGLINE}), + # Build flavor is 'oss' + ( + {}, + { + "version": {"number": "7.10.0", "build_flavor": "oss"}, + "tagline": TAGLINE, + }, + ), + # Build flavor is nonsense + ( + {}, + { + "version": {"number": "7.13.0", "build_flavor": "nonsense"}, + "tagline": TAGLINE, + }, + ), + # Tagline is nonsense + ({}, {"version": {"number": "7.1.0-SNAPSHOT"}, "tagline": "nonsense"}), + # Product header is not supposed to be missing + ({}, {"version": {"number": "7.14.0"}, "tagline": "You Know, for Search"}), + # Product header is nonsense + ( + {"x-elastic-product": "nonsense"}, + {"version": {"number": "7.15.0"}, "tagline": TAGLINE}, + ), + ], +) +def test_verify_elasticsearch_errors(headers, response): + with pytest.raises(NotElasticsearchError) as e: + verify_elasticsearch(headers, response) + + assert str(e.value) == ( + "The client noticed that the server is not Elasticsearch " + "and we do not support this unknown product" + ) + + +@pytest.mark.parametrize( + ["headers", "response"], + [ + ({}, {"version": {"number": "6.0.0"}, "tagline": TAGLINE}), + ({}, {"version": {"number": "6.99.99"}, "tagline": TAGLINE}), + ( + {}, + { + "version": {"number": "7.0.0", "build_flavor": "default"}, + "tagline": TAGLINE, + }, + ), + ( + {}, + { + 
"version": {"number": "7.13.99", "build_flavor": "default"}, + "tagline": TAGLINE, + }, + ), + ( + {"x-elastic-product": "Elasticsearch"}, + { + "version": {"number": "7.14.0", "build_flavor": "default"}, + "tagline": TAGLINE, + }, + ), + ( + {"x-elastic-product": "Elasticsearch"}, + { + "version": {"number": "7.99.99", "build_flavor": "default"}, + "tagline": TAGLINE, + }, + ), + ( + {"x-elastic-product": "Elasticsearch"}, + { + "version": {"number": "8.0.0"}, + }, + ), + ], +) +def test_verify_elasticsearch_passes(headers, response): + assert verify_elasticsearch(headers, response) is None + + +@pytest.mark.parametrize( + ["headers", "data"], + [ + ( + {}, + '{"version":{"number":"6.99.0"},"tagline":"You Know, for Search"}', + ), + ( + {}, + '{"version":{"number":"7.13.0","build_flavor":"default"},"tagline":"You Know, for Search"}', + ), + ( + {"X-elastic-product": "Elasticsearch"}, + '{"version":{"number":"7.14.0","build_flavor":"default"},"tagline":"You Know, for Search"}', + ), + ], +) +def test_verify_elasticsearch(headers, data): + t = Transport( + [{"data": data, "headers": headers}], connection_class=DummyConnection + ) + t.perform_request("GET", "/_search") + assert t._verified_elasticsearch + + calls = t.connection_pool.connections[0].calls + _ = [call[1]["headers"].pop("x-elastic-client-meta") for call in calls] + + assert calls == [ + ( + ("GET", "/"), + { + "headers": { + "accept": "application/json", + }, + "timeout": None, + }, + ), + ( + ("GET", "/_search", None, None), + { + "headers": {}, + "ignore": (), + "timeout": None, + }, + ), + ] + + +@pytest.mark.parametrize( + "exception_cls", [AuthorizationException, AuthenticationException] +) +def test_verify_elasticsearch_skips_on_auth_errors(exception_cls): + t = Transport( + [{"exception": exception_cls(exception_cls.status_code)}], + connection_class=DummyConnection, + ) + + with pytest.warns(ElasticsearchWarning) as warns: + with pytest.raises(exception_cls): + t.perform_request( + "GET", + "/_search", + headers={"Authorization": "testme"}, + params={"request_timeout": 3}, + ) + + # Assert that a warning was raised due to security privileges + assert [str(w.message) for w in warns] == [ + "The client is unable to verify that the server is " + "Elasticsearch due security privileges on the server side" + ] + + # Assert that the cluster is "verified" + assert t._verified_elasticsearch + + # See that the headers were passed along to the "info" request made + calls = t.connection_pool.connections[0].calls + _ = [call[1]["headers"].pop("x-elastic-client-meta") for call in calls] + + assert calls == [ + ( + ("GET", "/"), + { + "headers": { + "accept": "application/json", + "authorization": "testme", + }, + "timeout": 3, + }, + ), + ( + ("GET", "/_search", {}, None), + { + "headers": { + "Authorization": "testme", + }, + "ignore": (), + "timeout": 3, + }, + ), + ] From 214cc766db6b8d23f626528c987fe6d6b0d75616 Mon Sep 17 00:00:00 2001 From: Seth Michael Larson Date: Tue, 29 Jun 2021 17:17:23 -0500 Subject: [PATCH 2/3] Verify a ConnectionError is reraised in that situation --- elasticsearch/_async/transport.py | 15 ++++++++++----- elasticsearch/transport.py | 13 ++++++++++--- test_elasticsearch/test_async/test_connection.py | 8 +++++++- test_elasticsearch/test_connection.py | 8 +++++++- 4 files changed, 34 insertions(+), 10 deletions(-) diff --git a/elasticsearch/_async/transport.py b/elasticsearch/_async/transport.py index d193a2832..fbb083d10 100644 --- a/elasticsearch/_async/transport.py +++ b/elasticsearch/_async/transport.py 
@@ -427,7 +427,7 @@ async def _do_verify_elasticsearch(self, headers, timeout): info_headers = {} info_response = {} - info_error = None + error = None for conn in chain(self.connection_pool.connections, self.seed_connections): try: @@ -455,15 +455,20 @@ async def _do_verify_elasticsearch(self, headers, timeout): "Elasticsearch due security privileges on the server side" ), ElasticsearchWarning, - stacklevel=3, + stacklevel=4, ) self._verified_elasticsearch = True return # This connection didn't work, we'll try another. - except (ConnectionError, SerializationError): - if info_error is None: - info_error = info_error + except (ConnectionError, SerializationError) as err: + if error is None: + error = err + + # If we received a connection error and weren't successful + # anywhere then we reraise the more appropriate error. + if error and not info_response: + raise error # Check the information we got back from the index request. _verify_elasticsearch(info_headers, info_response) diff --git a/elasticsearch/transport.py b/elasticsearch/transport.py index c0f7ff8f2..8cb1317ac 100644 --- a/elasticsearch/transport.py +++ b/elasticsearch/transport.py @@ -522,6 +522,7 @@ def _do_verify_elasticsearch(self, headers, timeout): info_headers = {} info_response = {} + error = None for conn in chain(self.connection_pool.connections, self.seed_connections): try: @@ -549,14 +550,20 @@ def _do_verify_elasticsearch(self, headers, timeout): "Elasticsearch due security privileges on the server side" ), ElasticsearchWarning, - stacklevel=3, + stacklevel=5, ) self._verified_elasticsearch = True return # This connection didn't work, we'll try another. - except (ConnectionError, SerializationError): - pass + except (ConnectionError, SerializationError) as err: + if error is None: + error = err + + # If we received a connection error and weren't successful + # anywhere then we reraise the more appropriate error. + if error and not info_response: + raise error # Check the information we got back from the index request. 
_verify_elasticsearch(info_headers, info_response) diff --git a/test_elasticsearch/test_async/test_connection.py b/test_elasticsearch/test_async/test_connection.py index d0646e1c7..1d7d1ae6b 100644 --- a/test_elasticsearch/test_async/test_connection.py +++ b/test_elasticsearch/test_async/test_connection.py @@ -28,7 +28,7 @@ from mock import patch from multidict import CIMultiDict -from elasticsearch import AIOHttpConnection, __versionstr__ +from elasticsearch import AIOHttpConnection, AsyncElasticsearch, __versionstr__ from elasticsearch.compat import reraise_exceptions from elasticsearch.exceptions import ConnectionError @@ -410,3 +410,9 @@ async def test_aiohttp_connection_error(self): conn = AIOHttpConnection("not.a.host.name") with pytest.raises(ConnectionError): await conn.perform_request("GET", "/") + + async def test_elasticsearch_connection_error(self): + es = AsyncElasticsearch("http://not.a.host.name") + + with pytest.raises(ConnectionError): + await es.search() diff --git a/test_elasticsearch/test_connection.py b/test_elasticsearch/test_connection.py index 2ade950b5..5a1f5449f 100644 --- a/test_elasticsearch/test_connection.py +++ b/test_elasticsearch/test_connection.py @@ -31,7 +31,7 @@ from requests.auth import AuthBase from urllib3._collections import HTTPHeaderDict -from elasticsearch import __versionstr__ +from elasticsearch import Elasticsearch, __versionstr__ from elasticsearch.compat import reraise_exceptions from elasticsearch.connection import ( Connection, @@ -1045,3 +1045,9 @@ def test_requests_connection_error(self): conn = RequestsHttpConnection("not.a.host.name") with pytest.raises(ConnectionError): conn.perform_request("GET", "/") + + def test_elasticsearch_connection_error(self): + es = Elasticsearch("http://not.a.host.name") + + with pytest.raises(ConnectionError): + es.search() From 89c5cb7630fec6f4a8c53908623772e7e6130989 Mon Sep 17 00:00:00 2001 From: Seth Michael Larson Date: Wed, 30 Jun 2021 10:29:10 -0500 Subject: [PATCH 3/3] Force verification to make at most one request --- elasticsearch/_async/transport.py | 37 ++++- elasticsearch/compat.py | 11 ++ elasticsearch/transport.py | 68 ++++++--- .../test_async/test_transport.py | 109 ++++++++++++++ test_elasticsearch/test_transport.py | 141 +++++++++++++++++- 5 files changed, 336 insertions(+), 30 deletions(-) diff --git a/elasticsearch/_async/transport.py b/elasticsearch/_async/transport.py index fbb083d10..3b969b1ea 100644 --- a/elasticsearch/_async/transport.py +++ b/elasticsearch/_async/transport.py @@ -27,6 +27,7 @@ ConnectionError, ConnectionTimeout, ElasticsearchWarning, + NotElasticsearchError, SerializationError, TransportError, ) @@ -117,6 +118,10 @@ async def _async_init(self): self.loop = get_running_loop() self.kwargs["loop"] = self.loop + # Set our 'verified_once' implementation to one that + # works with 'asyncio' instead of 'threading' + self._verified_once = Once() + # Now that we have a loop we can create all our HTTP connections... self.set_connections(self.hosts) self.seed_connections = list(self.connection_pool.connections[:]) @@ -332,8 +337,17 @@ async def perform_request(self, method, url, headers=None, params=None, body=Non ) # Before we make the actual API call we verify the Elasticsearch instance. 
- if not self._verified_elasticsearch: - await self._do_verify_elasticsearch(headers=headers, timeout=timeout) + if self._verified_elasticsearch is None: + await self._verified_once.call( + self._do_verify_elasticsearch, headers=headers, timeout=timeout + ) + + # If '_verified_elasticsearch' is False we know we're not connected to Elasticsearch. + if self._verified_elasticsearch is False: + raise NotElasticsearchError( + "The client noticed that the server is not Elasticsearch " + "and we do not support this unknown product" + ) for attempt in range(self.max_retries + 1): connection = self.get_connection() @@ -471,7 +485,20 @@ async def _do_verify_elasticsearch(self, headers, timeout): raise error # Check the information we got back from the index request. - _verify_elasticsearch(info_headers, info_response) + self._verified_elasticsearch = _verify_elasticsearch( + info_headers, info_response + ) + + +class Once: + """Simple class which forces an async function to only execute once.""" + + def __init__(self): + self._lock = asyncio.Lock() + self._called = False - # If we made it through the above call this config is verified. - self._verified_elasticsearch = True + async def call(self, func, *args, **kwargs): + async with self._lock: + if not self._called: + self._called = True + await func(*args, **kwargs) diff --git a/elasticsearch/compat.py b/elasticsearch/compat.py index 99425ce6c..912d2c72a 100644 --- a/elasticsearch/compat.py +++ b/elasticsearch/compat.py @@ -70,6 +70,17 @@ def to_bytes(x, encoding="ascii"): except (ImportError, AttributeError): pass +try: + from threading import Lock +except ImportError: # Python <3.7 isn't guaranteed to have threading support. + + class Lock: + def __enter__(self): + pass + + def __exit__(self, *_): + pass + __all__ = [ "string_types", diff --git a/elasticsearch/transport.py b/elasticsearch/transport.py index 8cb1317ac..6e5fb0124 100644 --- a/elasticsearch/transport.py +++ b/elasticsearch/transport.py @@ -22,6 +22,7 @@ from platform import python_version from ._version import __versionstr__ +from .compat import Lock from .connection import Urllib3HttpConnection from .connection_pool import ConnectionPool, DummyConnectionPool, EmptyConnectionPool from .exceptions import ( @@ -204,9 +205,22 @@ def __init__( if http_client_meta: self._client_meta += (http_client_meta,) - # Flag which is set after verifying that we're - # connected to Elasticsearch. - self._verified_elasticsearch = False + # Tri-state flag that describes what state the verification + # of whether we're connected to an Elasticsearch cluster or not. + # The three states are: + # - 'None': Means we've either not started the verification process + # or that the verification is in progress. '_verified_once' ensures + # that multiple requests don't kick off multiple verification processes. + # - 'True': Means we've verified that we're talking to Elasticsearch or + # that we can't rule out Elasticsearch due to auth issues. A warning + # will be raised if we receive 401/403. + # - 'False': Means we've discovered we're not talking to Elasticsearch, + # should raise an error in this case for every request. + self._verified_elasticsearch = None + + # Ensures that the ES verification request only fires once and that + # all requests block until this request returns back. 
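+ # ('Once' is defined at the bottom of this module: the first caller
+ # runs the wrapped function while holding a lock, later callers
+ # block on that lock and then return without calling again.)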
+ self._verified_once = Once() def add_connection(self, host): """ @@ -391,7 +405,17 @@ def perform_request(self, method, url, headers=None, params=None, body=None): ) # Before we make the actual API call we verify the Elasticsearch instance. - self._do_verify_elasticsearch(headers=headers, timeout=timeout) + if self._verified_elasticsearch is None: + self._verified_once.call( + self._do_verify_elasticsearch, headers=headers, timeout=timeout + ) + + # If '_verified_elasticsearch' is False we know we're not connected to Elasticsearch. + if self._verified_elasticsearch is False: + raise NotElasticsearchError( + "The client noticed that the server is not Elasticsearch " + "and we do not support this unknown product" + ) for attempt in range(self.max_retries + 1): connection = self.get_connection() @@ -513,7 +537,7 @@ def _do_verify_elasticsearch(self, headers, timeout): error we instead emit an 'ElasticsearchWarning'. """ # Product check has already been done, no need to do again. - if self._verified_elasticsearch: + if self._verified_elasticsearch is not None: return headers = {header.lower(): value for header, value in (headers or {}).items()} @@ -566,19 +590,16 @@ def _do_verify_elasticsearch(self, headers, timeout): raise error # Check the information we got back from the index request. - _verify_elasticsearch(info_headers, info_response) - - # If we made it through the above call this config is verified. - self._verified_elasticsearch = True + self._verified_elasticsearch = _verify_elasticsearch( + info_headers, info_response + ) def _verify_elasticsearch(headers, response): """Verifies that the server we're talking to is Elasticsearch. Does this by checking HTTP headers and the deserialized - response to the 'info' API. - - If there's a problem this function raises 'NotElasticsearchError' - otherwise doesn't do anything. + response to the 'info' API. Returns 'True' if we're verified + against Elasticsearch, 'False' otherwise. 
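+
+ This function doesn't raise; 'perform_request' is responsible
+ for raising 'NotElasticsearchError' based on the stored result.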
""" try: version = response.get("version", {}) @@ -616,7 +637,20 @@ def _verify_elasticsearch(headers, response): # 7.14+ and there's a bad 'X-Elastic-Product' HTTP header or ((7, 14, 0) <= version_number and bad_product_header) ): - raise NotElasticsearchError( - "The client noticed that the server is not Elasticsearch " - "and we do not support this unknown product" - ) + return False + + return True + + +class Once: + """Simple class which forces a function to only execute once.""" + + def __init__(self): + self._lock = Lock() + self._called = False + + def call(self, func, *args, **kwargs): + with self._lock: + if not self._called: + self._called = True + func(*args, **kwargs) diff --git a/test_elasticsearch/test_async/test_transport.py b/test_elasticsearch/test_async/test_transport.py index c55012f54..30db9401b 100644 --- a/test_elasticsearch/test_async/test_transport.py +++ b/test_elasticsearch/test_async/test_transport.py @@ -33,6 +33,7 @@ AuthorizationException, ConnectionError, ElasticsearchWarning, + NotElasticsearchError, TransportError, ) @@ -714,3 +715,111 @@ async def test_verify_elasticsearch_skips_on_auth_errors(self, exception_cls): }, ), ] + + async def test_multiple_requests_verify_elasticsearch_success(self, event_loop): + t = AsyncTransport( + [ + { + "data": '{"version":{"number":"7.13.0","build_flavor":"default"},"tagline":"You Know, for Search"}', + "delay": 1, + } + ], + connection_class=DummyConnection, + ) + + results = [] + completed_at = [] + + async def request_task(): + try: + results.append(await t.perform_request("GET", "/_search")) + except Exception as e: + results.append(e) + completed_at.append(event_loop.time()) + + # Execute a bunch of requests concurrently. + tasks = [] + start_time = event_loop.time() + for _ in range(10): + tasks.append(event_loop.create_task(request_task())) + await asyncio.gather(*tasks) + end_time = event_loop.time() + + # Exactly 10 results completed + assert len(results) == 10 + + # No errors in the results + assert all(isinstance(result, dict) for result in results) + + # Assert that this took longer than 2 seconds but less than 2.1 seconds + duration = end_time - start_time + assert 2 <= duration <= 2.1 + + # Assert that every result came after ~2 seconds, no fast completions. + assert all( + 2 <= completed_time - start_time <= 2.1 for completed_time in completed_at + ) + + # Assert that the cluster is "verified" + assert t._verified_elasticsearch + + # See that the first request is always 'GET /' for ES check + calls = t.connection_pool.connections[0].calls + assert calls[0][0] == ("GET", "/") + + # The rest of the requests are 'GET /_search' afterwards + assert all(call[0][:2] == ("GET", "/_search") for call in calls[1:]) + + async def test_multiple_requests_verify_elasticsearch_errors(self, event_loop): + t = AsyncTransport( + [ + { + "data": '{"version":{"number":"7.13.0","build_flavor":"default"},"tagline":"BAD TAGLINE"}', + "delay": 1, + } + ], + connection_class=DummyConnection, + ) + + results = [] + completed_at = [] + + async def request_task(): + try: + results.append(await t.perform_request("GET", "/_search")) + except Exception as e: + results.append(e) + completed_at.append(event_loop.time()) + + # Execute a bunch of requests concurrently. 
+ tasks = [] + start_time = event_loop.time() + for _ in range(10): + tasks.append(event_loop.create_task(request_task())) + await asyncio.gather(*tasks) + end_time = event_loop.time() + + # Exactly 10 results completed + assert len(results) == 10 + + # All results were errors + assert all(isinstance(result, NotElasticsearchError) for result in results) + + # Assert that one request was made but not 2 requests. + duration = end_time - start_time + assert 1 <= duration <= 1.1 + + # Assert that every result came after ~1 seconds, no fast completions. + assert all( + 1 <= completed_time - start_time <= 1.1 for completed_time in completed_at + ) + + # Assert that the cluster is definitely not Elasticsearch + assert t._verified_elasticsearch is False + + # See that the first request is always 'GET /' for ES check + calls = t.connection_pool.connections[0].calls + assert calls[0][0] == ("GET", "/") + + # The rest of the requests are 'GET /_search' afterwards + assert all(call[0][:2] == ("GET", "/_search") for call in calls[1:]) diff --git a/test_elasticsearch/test_transport.py b/test_elasticsearch/test_transport.py index 6043222ba..17c308a81 100644 --- a/test_elasticsearch/test_transport.py +++ b/test_elasticsearch/test_transport.py @@ -46,11 +46,14 @@ def __init__(self, **kwargs): self.exception = kwargs.pop("exception", None) self.status, self.data = kwargs.pop("status", 200), kwargs.pop("data", "{}") self.headers = kwargs.pop("headers", {}) + self.delay = kwargs.pop("delay", None) self.calls = [] super(DummyConnection, self).__init__(**kwargs) def perform_request(self, *args, **kwargs): self.calls.append((args, kwargs)) + if self.delay is not None: + time.sleep(self.delay) if self.exception: raise self.exception return self.status, self.headers, self.data @@ -530,13 +533,7 @@ def test_sniffing_disabled_on_cloud_instances(self, sniff_hosts): ], ) def test_verify_elasticsearch_errors(headers, response): - with pytest.raises(NotElasticsearchError) as e: - verify_elasticsearch(headers, response) - - assert str(e.value) == ( - "The client noticed that the server is not Elasticsearch " - "and we do not support this unknown product" - ) + assert verify_elasticsearch(headers, response) is False @pytest.mark.parametrize( @@ -581,7 +578,7 @@ def test_verify_elasticsearch_errors(headers, response): ], ) def test_verify_elasticsearch_passes(headers, response): - assert verify_elasticsearch(headers, response) is None + assert verify_elasticsearch(headers, response) is True @pytest.mark.parametrize( @@ -685,3 +682,131 @@ def test_verify_elasticsearch_skips_on_auth_errors(exception_cls): }, ), ] + + +def test_multiple_requests_verify_elasticsearch_success(): + try: + import threading + except ImportError: + return pytest.skip("Requires the 'threading' module") + + t = Transport( + [ + { + "data": '{"version":{"number":"7.13.0","build_flavor":"default"},"tagline":"You Know, for Search"}', + "delay": 1, + } + ], + connection_class=DummyConnection, + ) + + results = [] + completed_at = [] + + class RequestThread(threading.Thread): + def run(self): + try: + results.append(t.perform_request("GET", "/_search")) + except Exception as e: + results.append(e) + completed_at.append(time.time()) + + # Execute a bunch of requests concurrently. 
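+ # Timing expectation: every thread blocks while the single 1-second
+ # 'GET /' probe runs, then each thread performs its own 1-second
+ # 'GET /_search' concurrently, so all completions land at roughly
+ # the 2 second mark.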
+ threads = [] + start_time = time.time() + for _ in range(10): + thread = RequestThread() + thread.start() + threads.append(thread) + for thread in threads: + thread.join() + end_time = time.time() + + # Exactly 10 results completed + assert len(results) == 10 + + # No errors in the results + assert all(isinstance(result, dict) for result in results) + + # Assert that this took longer than 2 seconds but less than 2.1 seconds + duration = end_time - start_time + assert 2 <= duration <= 2.1 + + # Assert that every result came after ~2 seconds, no fast completions. + assert all( + 2 <= completed_time - start_time <= 2.1 for completed_time in completed_at + ) + + # Assert that the cluster is "verified" + assert t._verified_elasticsearch + + # See that the first request is always 'GET /' for ES check + calls = t.connection_pool.connections[0].calls + assert calls[0][0] == ("GET", "/") + + # The rest of the requests are 'GET /_search' afterwards + assert all(call[0][:2] == ("GET", "/_search") for call in calls[1:]) + + +def test_multiple_requests_verify_elasticsearch_errors(): + try: + import threading + except ImportError: + return pytest.skip("Requires the 'threading' module") + + t = Transport( + [ + { + "data": '{"version":{"number":"7.13.0","build_flavor":"default"},"tagline":"BAD TAGLINE"}', + "delay": 1, + } + ], + connection_class=DummyConnection, + ) + + results = [] + completed_at = [] + + class RequestThread(threading.Thread): + def run(self): + try: + results.append(t.perform_request("GET", "/_search")) + except Exception as e: + results.append(e) + completed_at.append(time.time()) + + # Execute a bunch of requests concurrently. + threads = [] + start_time = time.time() + for _ in range(10): + thread = RequestThread() + thread.start() + threads.append(thread) + for thread in threads: + thread.join() + end_time = time.time() + + # Exactly 10 results completed + assert len(results) == 10 + + # All results were errors + assert all(isinstance(result, NotElasticsearchError) for result in results) + + # Assert that one request was made but not 2 requests. + duration = end_time - start_time + assert 1 <= duration <= 1.1 + + # Assert that every result came after ~1 seconds, no fast completions. + assert all( + 1 <= completed_time - start_time <= 1.1 for completed_time in completed_at + ) + + # Assert that the cluster is definitely not Elasticsearch + assert t._verified_elasticsearch is False + + # See that the first request is always 'GET /' for ES check + calls = t.connection_pool.connections[0].calls + assert calls[0][0] == ("GET", "/") + + # The rest of the requests are 'GET /_search' afterwards + assert all(call[0][:2] == ("GET", "/_search") for call in calls[1:])
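
As an end-to-end sketch of the behavior this series adds, from the client's point of view (the host URL and index name below are placeholders, not part of the patch):

import warnings

from elasticsearch import Elasticsearch
from elasticsearch.exceptions import ElasticsearchWarning, NotElasticsearchError

es = Elasticsearch("http://localhost:9200")  # placeholder host

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always", ElasticsearchWarning)
    try:
        # The first request triggers a single 'GET /' probe before the
        # actual API call; subsequent requests reuse the stored result.
        es.search(index="my-index")  # placeholder index
    except NotElasticsearchError:
        # Raised on every request once the probe determines the server
        # is not Elasticsearch.
        raise
    # If the probe hit a 401/403, the client stays usable and an
    # ElasticsearchWarning is recorded in 'caught' instead.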