Skip to content

Commit aaed0d4

Browse files
committed
Add tests for AIOHttpConnection
1 parent e7228fb commit aaed0d4

File tree

5 files changed

+323
-7
lines changed

5 files changed

+323
-7
lines changed

elasticsearch/_async/client/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,9 @@ async def __aenter__(self):
229229
return self
230230

231231
async def __aexit__(self, *_):
232+
await self.close()
233+
234+
async def close(self):
232235
await self.transport.close()
233236

234237
# AUTO-GENERATED-API-DEFINITIONS #

elasticsearch/_async/http_aiohttp.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import asyncio
66
import ssl
77
import os
8+
import urllib3
89
import warnings
910

1011
import aiohttp
@@ -25,6 +26,7 @@
2526
# This is used to detect if a user is passing in a value
2627
# for SSL kwargs if also using an SSLContext.
2728
VERIFY_CERTS_DEFAULT = object()
29+
SSL_SHOW_WARN_DEFAULT = object()
2830

2931
CA_CERTS = None
3032

@@ -43,7 +45,8 @@ def __init__(
4345
port=None,
4446
http_auth=None,
4547
use_ssl=False,
46-
verify_certs=True,
48+
verify_certs=VERIFY_CERTS_DEFAULT,
49+
ssl_show_warn=SSL_SHOW_WARN_DEFAULT,
4750
ca_certs=None,
4851
client_cert=None,
4952
client_key=None,
@@ -74,15 +77,14 @@ def __init__(
7477
)
7578

7679
if http_auth is not None:
77-
if isinstance(http_auth, str):
78-
http_auth = tuple(http_auth.split(":", 1))
79-
8080
if isinstance(http_auth, (tuple, list)):
81-
http_auth = aiohttp.BasicAuth(*http_auth)
81+
http_auth = ":".join(http_auth)
82+
self.headers.update(urllib3.make_headers(basic_auth=http_auth))
8283

8384
# if providing an SSL context, raise error if any other SSL related flag is used
8485
if ssl_context and (
8586
(verify_certs is not VERIFY_CERTS_DEFAULT)
87+
or (ssl_show_warn is not SSL_SHOW_WARN_DEFAULT)
8688
or ca_certs
8789
or client_cert
8890
or client_key
@@ -100,6 +102,8 @@ def __init__(
100102
# values if not using an SSLContext.
101103
if verify_certs is VERIFY_CERTS_DEFAULT:
102104
verify_certs = True
105+
if ssl_show_warn is SSL_SHOW_WARN_DEFAULT:
106+
ssl_show_warn = True
103107

104108
ca_certs = CA_CERTS if ca_certs is None else ca_certs
105109
if verify_certs:
@@ -109,6 +113,13 @@ def __init__(
109113
"validation. Either pass them in using the ca_certs parameter or "
110114
"install certifi to use it automatically."
111115
)
116+
else:
117+
if ssl_show_warn:
118+
warnings.warn(
119+
"Connecting to %s using SSL with verify_certs=False is insecure."
120+
% self.host
121+
)
122+
112123
if os.path.isfile(ca_certs):
113124
ssl_context.load_verify_locations(cafile=ca_certs)
114125
elif os.path.isdir(ca_certs):
@@ -180,7 +191,7 @@ async def perform_request(
180191
await response.release()
181192
raw_data = ""
182193
else:
183-
raw_data = await response.text()
194+
raw_data = (await response.read()).decode("utf-8", "surrogatepass")
184195
duration = self.loop.time() - start
185196

186197
# We want to reraise a cancellation.
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Licensed to Elasticsearch B.V under one or more agreements.
2+
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
3+
# See the LICENSE file in the project root for more information
4+
Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
# -*- coding: utf-8 -*-
2+
# Licensed to Elasticsearch B.V under one or more agreements.
3+
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
4+
# See the LICENSE file in the project root for more information
5+
6+
import re
7+
import ssl
8+
from mock import Mock, patch
9+
import warnings
10+
from platform import python_version
11+
12+
from elasticsearch.exceptions import (
13+
TransportError,
14+
ConflictError,
15+
RequestError,
16+
NotFoundError,
17+
)
18+
from elasticsearch import AIOHttpConnection
19+
from elasticsearch import __versionstr__
20+
from ..test_cases import TestCase, SkipTest
21+
22+
23+
class TestAIOHttpConnection(TestCase):
24+
async def _get_mock_connection(self, connection_params={}, response_body=b"{}"):
25+
con = AIOHttpConnection(**connection_params)
26+
27+
async def _dummy_request(*args, **kwargs):
28+
async def read():
29+
return response_body
30+
31+
dummy_response = Mock()
32+
dummy_response.headers = {}
33+
dummy_response.status_code = 200
34+
dummy_response.read = read
35+
_dummy_request.call_args = (args, kwargs)
36+
return dummy_response
37+
38+
con.session.request = _dummy_request
39+
return con
40+
41+
async def test_ssl_context(self):
42+
try:
43+
context = ssl.create_default_context()
44+
except AttributeError:
45+
# if create_default_context raises an AttributeError Exception
46+
# it means SSLContext is not available for that version of python
47+
# and we should skip this test.
48+
raise SkipTest(
49+
"Test test_ssl_context is skipped cause SSLContext is not available for this version of ptyhon"
50+
)
51+
52+
con = AIOHttpConnection(use_ssl=True, ssl_context=context)
53+
await con._create_aiohttp_session()
54+
self.assertTrue(con.use_ssl)
55+
self.assertEqual(con.session.ssl_context, context)
56+
57+
def test_opaque_id(self):
58+
con = AIOHttpConnection(opaque_id="app-1")
59+
self.assertEqual(con.headers["x-opaque-id"], "app-1")
60+
61+
def test_http_cloud_id(self):
62+
con = AIOHttpConnection(
63+
cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng=="
64+
)
65+
self.assertTrue(con.use_ssl)
66+
self.assertEqual(
67+
con.host, "https://4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io"
68+
)
69+
self.assertEqual(con.port, None)
70+
self.assertEqual(
71+
con.hostname, "4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io"
72+
)
73+
self.assertTrue(con.http_compress)
74+
75+
con = AIOHttpConnection(
76+
cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==",
77+
port=9243,
78+
)
79+
self.assertEqual(
80+
con.host,
81+
"https://4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io:9243",
82+
)
83+
self.assertEqual(con.port, 9243)
84+
self.assertEqual(
85+
con.hostname, "4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io"
86+
)
87+
88+
def test_api_key_auth(self):
89+
# test with tuple
90+
con = AIOHttpConnection(
91+
cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==",
92+
api_key=("elastic", "changeme1"),
93+
)
94+
self.assertEqual(
95+
con.headers["authorization"], "ApiKey ZWxhc3RpYzpjaGFuZ2VtZTE="
96+
)
97+
self.assertEqual(
98+
con.host, "https://4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io"
99+
)
100+
101+
# test with base64 encoded string
102+
con = AIOHttpConnection(
103+
cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==",
104+
api_key="ZWxhc3RpYzpjaGFuZ2VtZTI=",
105+
)
106+
self.assertEqual(
107+
con.headers["authorization"], "ApiKey ZWxhc3RpYzpjaGFuZ2VtZTI="
108+
)
109+
self.assertEqual(
110+
con.host, "https://4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io"
111+
)
112+
113+
async def test_no_http_compression(self):
114+
con = await self._get_mock_connection()
115+
self.assertFalse(con.http_compress)
116+
self.assertNotIn("accept-encoding", con.headers)
117+
118+
await con.perform_request("GET", "/")
119+
120+
(_, _, req_body), kwargs = con.pool.urlopen.call_args
121+
122+
self.assertFalse(req_body)
123+
self.assertNotIn("accept-encoding", kwargs["headers"])
124+
self.assertNotIn("content-encoding", kwargs["headers"])
125+
126+
async def test_http_compression(self):
127+
con = await self._get_mock_connection({"http_compress": True})
128+
self.assertTrue(con.http_compress)
129+
self.assertEqual(con.headers["accept-encoding"], "gzip,deflate")
130+
131+
# 'content-encoding' shouldn't be set at a connection level.
132+
# Should be applied only if the request is sent with a body.
133+
self.assertNotIn("content-encoding", con.headers)
134+
135+
await con.perform_request("GET", "/", body=b"{}")
136+
137+
(_, _, req_body), kwargs = con.pool.urlopen.call_args
138+
139+
self.assertEqual(gzip_decompress(req_body), b"{}")
140+
self.assertEqual(kwargs["headers"]["accept-encoding"], "gzip,deflate")
141+
self.assertEqual(kwargs["headers"]["content-encoding"], "gzip")
142+
143+
await con.perform_request("GET", "/")
144+
145+
(_, _, req_body), kwargs = con.pool.urlopen.call_args
146+
147+
self.assertFalse(req_body)
148+
self.assertEqual(kwargs["headers"]["accept-encoding"], "gzip,deflate")
149+
self.assertNotIn("content-encoding", kwargs["headers"])
150+
151+
def test_cloud_id_http_compress_override(self):
152+
# 'http_compress' will be 'True' by default for connections with
153+
# 'cloud_id' set but should prioritize user-defined values.
154+
con = AIOHttpConnection(
155+
cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==",
156+
)
157+
self.assertEqual(con.http_compress, True)
158+
159+
con = AIOHttpConnection(
160+
cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==",
161+
http_compress=False,
162+
)
163+
self.assertEqual(con.http_compress, False)
164+
165+
con = AIOHttpConnection(
166+
cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==",
167+
http_compress=True,
168+
)
169+
self.assertEqual(con.http_compress, True)
170+
171+
def test_default_user_agent(self):
172+
con = AIOHttpConnection()
173+
self.assertEqual(
174+
con._get_default_user_agent(),
175+
"elasticsearch-py/%s (Python %s)" % (__versionstr__, python_version()),
176+
)
177+
178+
def test_timeout_set(self):
179+
con = AIOHttpConnection(timeout=42)
180+
self.assertEqual(42, con.timeout)
181+
182+
def test_keep_alive_is_on_by_default(self):
183+
con = AIOHttpConnection()
184+
self.assertEqual(
185+
{
186+
"connection": "keep-alive",
187+
"content-type": "application/json",
188+
"user-agent": con._get_default_user_agent(),
189+
},
190+
con.headers,
191+
)
192+
193+
def test_http_auth(self):
194+
con = AIOHttpConnection(http_auth="username:secret")
195+
self.assertEqual(
196+
{
197+
"authorization": "Basic dXNlcm5hbWU6c2VjcmV0",
198+
"connection": "keep-alive",
199+
"content-type": "application/json",
200+
"user-agent": con._get_default_user_agent(),
201+
},
202+
con.headers,
203+
)
204+
205+
def test_http_auth_tuple(self):
206+
con = AIOHttpConnection(http_auth=("username", "secret"))
207+
self.assertEqual(
208+
{
209+
"authorization": "Basic dXNlcm5hbWU6c2VjcmV0",
210+
"content-type": "application/json",
211+
"connection": "keep-alive",
212+
"user-agent": con._get_default_user_agent(),
213+
},
214+
con.headers,
215+
)
216+
217+
def test_http_auth_list(self):
218+
con = AIOHttpConnection(http_auth=["username", "secret"])
219+
self.assertEqual(
220+
{
221+
"authorization": "Basic dXNlcm5hbWU6c2VjcmV0",
222+
"content-type": "application/json",
223+
"connection": "keep-alive",
224+
"user-agent": con._get_default_user_agent(),
225+
},
226+
con.headers,
227+
)
228+
229+
def test_uses_https_if_verify_certs_is_off(self):
230+
with warnings.catch_warnings(record=True) as w:
231+
con = AIOHttpConnection(use_ssl=True, verify_certs=False)
232+
self.assertEqual(1, len(w))
233+
self.assertEqual(
234+
"Connecting to https://localhost:9200 using SSL with verify_certs=False is insecure.",
235+
str(w[0].message),
236+
)
237+
238+
self.assertTrue(con.use_ssl)
239+
self.assertEqual(con.scheme, "https")
240+
self.assertEqual(con.host, "https://localhost:9200")
241+
242+
def nowarn_when_test_uses_https_if_verify_certs_is_off(self):
243+
with warnings.catch_warnings(record=True) as w:
244+
con = Urllib3HttpConnection(
245+
use_ssl=True, verify_certs=False, ssl_show_warn=False
246+
)
247+
self.assertEqual(0, len(w))
248+
249+
self.assertIsInstance(con.pool, urllib3.HTTPSConnectionPool)
250+
251+
def test_doesnt_use_https_if_not_specified(self):
252+
con = AIOHttpConnection()
253+
self.assertFalse(con.use_ssl)
254+
255+
def test_no_warning_when_using_ssl_context(self):
256+
ctx = ssl.create_default_context()
257+
with warnings.catch_warnings(record=True) as w:
258+
AIOHttpConnection(ssl_context=ctx)
259+
self.assertEqual(0, len(w), str([x.message for x in w]))
260+
261+
def test_warns_if_using_non_default_ssl_kwargs_with_ssl_context(self):
262+
for kwargs in (
263+
{"ssl_show_warn": False},
264+
{"ssl_show_warn": True},
265+
{"verify_certs": True},
266+
{"verify_certs": False},
267+
{"ca_certs": "/path/to/certs"},
268+
{"ssl_show_warn": True, "ca_certs": "/path/to/certs"},
269+
):
270+
kwargs["ssl_context"] = ssl.create_default_context()
271+
272+
with warnings.catch_warnings(record=True) as w:
273+
warnings.simplefilter("always")
274+
275+
AIOHttpConnection(**kwargs)
276+
277+
self.assertEqual(1, len(w))
278+
self.assertEqual(
279+
"When using `ssl_context`, all other SSL related kwargs are ignored",
280+
str(w[0].message),
281+
)
282+
283+
@patch("elasticsearch.connection.base.logger")
284+
async def test_uncompressed_body_logged(self, logger):
285+
con = await self._get_mock_connection(connection_params={"http_compress": True})
286+
await con.perform_request("GET", "/", body=b'{"example": "body"}')
287+
288+
self.assertEqual(2, logger.debug.call_count)
289+
req, resp = logger.debug.call_args_list
290+
291+
self.assertEqual('> {"example": "body"}', req[0][0] % req[0][1:])
292+
self.assertEqual("< {}", resp[0][0] % resp[0][1:])
293+
294+
async def test_surrogatepass_into_bytes(self):
295+
buf = b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa"
296+
con = await self._get_mock_connection(response_body=buf)
297+
status, headers, data = await con.perform_request("GET", "/")
298+
self.assertEqual(u"你好\uda6a", data)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{% extends "base" %}
22
{% block request %}
3-
return self.transport.perform_request("{{ api.method }}", "/_cluster/stats" if node_id in SKIP_IN_PATH else _make_path("_cluster/stats/nodes", node_id), params=params, headers=headers)
3+
return self.transport.perform_request("{{ api.method }}", "/_cluster/stats" if node_id in SKIP_IN_PATH else _make_path("_cluster", "stats", "nodes", node_id), params=params, headers=headers)
44
{% endblock%}
55

0 commit comments

Comments
 (0)