Skip to content

Commit 4862717

Browse files
committed
Add test suite for async API
1 parent 7f6ccbe commit 4862717

File tree

5 files changed

+316
-1
lines changed

5 files changed

+316
-1
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Licensed to Elasticsearch B.V under one or more agreements.
2+
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
3+
# See the LICENSE file in the project root for more information
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Licensed to Elasticsearch B.V under one or more agreements.
2+
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
3+
# See the LICENSE file in the project root for more information
4+
5+
import os
6+
import pytest
7+
import asyncio
8+
import elasticsearch
9+
10+
pytestmark = pytest.mark.asyncio
11+
12+
13+
@pytest.fixture(scope="function")
14+
async def async_client():
15+
client = None
16+
try:
17+
if not hasattr(elasticsearch, "AsyncElasticsearch"):
18+
pytest.skip("test requires 'AsyncElasticsearch'")
19+
20+
kw = {
21+
"timeout": 30,
22+
"ca_certs": ".ci/certs/ca.pem",
23+
"connection_class": elasticsearch.AIOHttpConnection,
24+
}
25+
26+
client = elasticsearch.AsyncElasticsearch(
27+
[os.environ.get("ELASTICSEARCH_HOST", {})], **kw
28+
)
29+
30+
# wait for yellow status
31+
for _ in range(100):
32+
try:
33+
await client.cluster.health(wait_for_status="yellow")
34+
break
35+
except ConnectionError:
36+
await asyncio.sleep(0.1)
37+
else:
38+
# timeout
39+
pytest.skip("Elasticsearch failed to start.")
40+
41+
yield client
42+
43+
finally:
44+
if client:
45+
version = tuple(
46+
[
47+
int(x) if x.isdigit() else 999
48+
for x in (await client.info())["version"]["number"].split(".")
49+
]
50+
)
51+
52+
expand_wildcards = ["open", "closed"]
53+
if version >= (7, 7):
54+
expand_wildcards.append("hidden")
55+
56+
await client.indices.delete(
57+
index="*", ignore=404, expand_wildcards=expand_wildcards
58+
)
59+
await client.indices.delete_template(name="*", ignore=404)
60+
await client.indices.delete_index_template(name="*", ignore=404)
61+
await client.close()
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# -*- coding: utf-8 -*-
2+
# Licensed to Elasticsearch B.V under one or more agreements.
3+
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
4+
# See the LICENSE file in the project root for more information
5+
6+
from __future__ import unicode_literals
7+
import pytest
8+
9+
pytestmark = pytest.mark.asyncio
10+
11+
12+
class TestUnicode:
13+
async def test_indices_analyze(self, async_client):
14+
await async_client.indices.analyze(body='{"text": "привет"}')
15+
16+
17+
class TestBulk:
18+
async def test_bulk_works_with_string_body(self, async_client):
19+
docs = '{ "index" : { "_index" : "bulk_test_index", "_id" : "1" } }\n{"answer": 42}'
20+
response = await async_client.bulk(body=docs)
21+
22+
assert response["errors"] is False
23+
assert len(response["items"]) == 1
24+
25+
async def test_bulk_works_with_bytestring_body(self, async_client):
26+
docs = b'{ "index" : { "_index" : "bulk_test_index", "_id" : "2" } }\n{"answer": 42}'
27+
response = await async_client.bulk(body=docs)
28+
29+
assert response["errors"] is False
30+
assert len(response["items"]) == 1
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
# Licensed to Elasticsearch B.V under one or more agreements.
2+
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
3+
# See the LICENSE file in the project root for more information
4+
5+
"""
6+
Dynamically generated set of TestCases based on set of yaml files decribing
7+
some integration tests. These files are shared among all official Elasticsearch
8+
clients.
9+
"""
10+
import pytest
11+
from shutil import rmtree
12+
import warnings
13+
import inspect
14+
15+
from elasticsearch import RequestError, ElasticsearchDeprecationWarning
16+
from elasticsearch.helpers.test import _get_version
17+
from ...test_server.test_rest_api_spec import (
18+
YamlRunner,
19+
YAML_TEST_SPECS,
20+
InvalidActionType,
21+
RUN_ASYNC_REST_API_TESTS,
22+
PARAMS_RENAMES,
23+
)
24+
25+
pytestmark = pytest.mark.asyncio
26+
27+
XPACK_FEATURES = None
28+
ES_VERSION = None
29+
30+
31+
async def await_if_coro(x):
32+
if inspect.iscoroutine(x):
33+
return await x
34+
return x
35+
36+
37+
class AsyncYamlRunner(YamlRunner):
38+
async def setup(self):
39+
if self._setup_code:
40+
await self.run_code(self._setup_code)
41+
42+
async def teardown(self):
43+
if self._teardown_code:
44+
await self.run_code(self._teardown_code)
45+
46+
for repo, definition in (
47+
await self.client.snapshot.get_repository(repository="_all")
48+
).items():
49+
await self.client.snapshot.delete_repository(repository=repo)
50+
if definition["type"] == "fs":
51+
rmtree(
52+
"/tmp/%s" % definition["settings"]["location"], ignore_errors=True
53+
)
54+
55+
# stop and remove all ML stuff
56+
if await self._feature_enabled("ml"):
57+
await self.client.ml.stop_datafeed(datafeed_id="*", force=True)
58+
for feed in (await self.client.ml.get_datafeeds(datafeed_id="*"))[
59+
"datafeeds"
60+
]:
61+
await self.client.ml.delete_datafeed(datafeed_id=feed["datafeed_id"])
62+
63+
await self.client.ml.close_job(job_id="*", force=True)
64+
for job in (await self.client.ml.get_jobs(job_id="*"))["jobs"]:
65+
await self.client.ml.delete_job(
66+
job_id=job["job_id"], wait_for_completion=True, force=True
67+
)
68+
69+
# stop and remove all Rollup jobs
70+
if await self._feature_enabled("rollup"):
71+
for rollup in (await self.client.rollup.get_jobs(id="*"))["jobs"]:
72+
await self.client.rollup.stop_job(
73+
id=rollup["config"]["id"], wait_for_completion=True
74+
)
75+
await self.client.rollup.delete_job(id=rollup["config"]["id"])
76+
77+
async def es_version(self):
78+
global ES_VERSION
79+
if ES_VERSION is None:
80+
version_string = (await self.client.info())["version"]["number"]
81+
if "." not in version_string:
82+
return ()
83+
version = version_string.strip().split(".")
84+
ES_VERSION = tuple(int(v) if v.isdigit() else 999 for v in version)
85+
return ES_VERSION
86+
87+
async def run(self):
88+
try:
89+
await self.setup()
90+
await self.run_code(self._run_code)
91+
finally:
92+
await self.teardown()
93+
94+
async def run_code(self, test):
95+
""" Execute an instruction based on it's type. """
96+
print(test)
97+
for action in test:
98+
assert len(action) == 1
99+
action_type, action = list(action.items())[0]
100+
101+
if hasattr(self, "run_" + action_type):
102+
await await_if_coro(getattr(self, "run_" + action_type)(action))
103+
else:
104+
raise InvalidActionType(action_type)
105+
106+
async def run_do(self, action):
107+
api = self.client
108+
headers = action.pop("headers", None)
109+
catch = action.pop("catch", None)
110+
warn = action.pop("warnings", ())
111+
allowed_warnings = action.pop("allowed_warnings", ())
112+
assert len(action) == 1
113+
114+
method, args = list(action.items())[0]
115+
args["headers"] = headers
116+
117+
# locate api endpoint
118+
for m in method.split("."):
119+
assert hasattr(api, m)
120+
api = getattr(api, m)
121+
122+
# some parameters had to be renamed to not clash with python builtins,
123+
# compensate
124+
for k in PARAMS_RENAMES:
125+
if k in args:
126+
args[PARAMS_RENAMES[k]] = args.pop(k)
127+
128+
# resolve vars
129+
for k in args:
130+
args[k] = self._resolve(args[k])
131+
132+
warnings.simplefilter("always", category=ElasticsearchDeprecationWarning)
133+
with warnings.catch_warnings(record=True) as caught_warnings:
134+
try:
135+
self.last_response = await api(**args)
136+
except Exception as e:
137+
if not catch:
138+
raise
139+
self.run_catch(catch, e)
140+
else:
141+
if catch:
142+
raise AssertionError(
143+
"Failed to catch %r in %r." % (catch, self.last_response)
144+
)
145+
146+
# Filter out warnings raised by other components.
147+
caught_warnings = [
148+
str(w.message)
149+
for w in caught_warnings
150+
if w.category == ElasticsearchDeprecationWarning
151+
and str(w.message) not in allowed_warnings
152+
]
153+
154+
# Sorting removes the issue with order raised. We only care about
155+
# if all warnings are raised in the single API call.
156+
if warn and sorted(warn) != sorted(caught_warnings):
157+
raise AssertionError(
158+
"Expected warnings not equal to actual warnings: expected=%r actual=%r"
159+
% (warn, caught_warnings)
160+
)
161+
162+
async def run_skip(self, skip):
163+
global IMPLEMENTED_FEATURES
164+
165+
if "features" in skip:
166+
features = skip["features"]
167+
if not isinstance(features, (tuple, list)):
168+
features = [features]
169+
for feature in features:
170+
if feature in IMPLEMENTED_FEATURES:
171+
continue
172+
pytest.skip("feature '%s' is not supported" % feature)
173+
174+
if "version" in skip:
175+
version, reason = skip["version"], skip["reason"]
176+
if version == "all":
177+
pytest.skip(reason)
178+
min_version, max_version = version.split("-")
179+
min_version = _get_version(min_version) or (0,)
180+
max_version = _get_version(max_version) or (999,)
181+
if min_version <= (await self.es_version()) <= max_version:
182+
pytest.skip(reason)
183+
184+
async def _feature_enabled(self, name):
185+
global XPACK_FEATURES, IMPLEMENTED_FEATURES
186+
if XPACK_FEATURES is None:
187+
try:
188+
xinfo = await self.client.xpack.info()
189+
XPACK_FEATURES = set(
190+
f for f in xinfo["features"] if xinfo["features"][f]["enabled"]
191+
)
192+
IMPLEMENTED_FEATURES.add("xpack")
193+
except RequestError:
194+
XPACK_FEATURES = set()
195+
IMPLEMENTED_FEATURES.add("no_xpack")
196+
return name in XPACK_FEATURES
197+
198+
199+
@pytest.fixture(scope="function")
200+
def async_runner(async_client):
201+
return AsyncYamlRunner(async_client)
202+
203+
204+
@pytest.mark.parametrize("test_spec", YAML_TEST_SPECS)
205+
async def test_rest_api_spec(test_spec, async_runner):
206+
if not RUN_ASYNC_REST_API_TESTS:
207+
pytest.skip("Skipped running async REST API tests")
208+
if test_spec.get("skip", False):
209+
pytest.skip("Manually skipped in 'SKIP_TESTS'")
210+
async_runner.use_spec(test_spec)
211+
await async_runner.run()

test_elasticsearch/test_server/test_rest_api_spec.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
some integration tests. These files are shared among all official Elasticsearch
88
clients.
99
"""
10+
import sys
1011
import re
12+
import os
1113
from os import walk, environ
1214
from os.path import exists, join, dirname, pardir, relpath
1315
import yaml
@@ -58,11 +60,17 @@
5860
"indices/put_template/10_basic[4]",
5961
# depends on order of response JSON which is random
6062
"indices/simulate_index_template/10_basic[1]",
63+
# body: null? body is {}
64+
"indices/simulate_index_template/10_basic[2]",
6165
}
6266

6367

6468
XPACK_FEATURES = None
6569
ES_VERSION = None
70+
RUN_ASYNC_REST_API_TESTS = (
71+
sys.version_info >= (3, 6)
72+
and os.environ.get("PYTHON_CONNECTION_CLASS") == "RequestsHttpConnection"
73+
)
6674

6775

6876
class YamlRunner:
@@ -78,7 +86,7 @@ def __init__(self, client):
7886
def use_spec(self, test_spec):
7987
self._setup_code = test_spec.pop("setup", None)
8088
self._run_code = test_spec.pop("run", None)
81-
self._teardown_code = test_spec.pop("teardown")
89+
self._teardown_code = test_spec.pop("teardown", None)
8290

8391
def setup(self):
8492
if self._setup_code:
@@ -417,6 +425,8 @@ def sync_runner(sync_client):
417425

418426
@pytest.mark.parametrize("test_spec", YAML_TEST_SPECS)
419427
def test_rest_api_spec(test_spec, sync_runner):
428+
if RUN_ASYNC_REST_API_TESTS:
429+
pytest.skip("Skipped running sync REST API tests")
420430
if test_spec.get("skip", False):
421431
pytest.skip("Manually skipped in 'SKIP_TESTS'")
422432
sync_runner.use_spec(test_spec)

0 commit comments

Comments
 (0)