Skip to content

Commit a979c9c

Browse files
committed
Add test suite for async API
1 parent 7f6ccbe commit a979c9c

File tree

5 files changed

+315
-1
lines changed

5 files changed

+315
-1
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Licensed to Elasticsearch B.V under one or more agreements.
2+
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
3+
# See the LICENSE file in the project root for more information
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Licensed to Elasticsearch B.V under one or more agreements.
2+
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
3+
# See the LICENSE file in the project root for more information
4+
5+
import os
6+
import pytest
7+
import asyncio
8+
import elasticsearch
9+
10+
pytestmark = pytest.mark.asyncio
11+
12+
13+
@pytest.fixture(scope="function")
14+
async def async_client():
15+
client = None
16+
try:
17+
if not hasattr(elasticsearch, "AsyncElasticsearch"):
18+
pytest.skip("test requires 'AsyncElasticsearch'")
19+
20+
kw = {
21+
"timeout": 30,
22+
"ca_certs": ".ci/certs/ca.pem",
23+
"connection_class": elasticsearch.AIOHttpConnection,
24+
}
25+
26+
client = elasticsearch.AsyncElasticsearch(
27+
[os.environ.get("ELASTICSEARCH_HOST", {})], **kw
28+
)
29+
30+
# wait for yellow status
31+
for _ in range(100):
32+
try:
33+
await client.cluster.health(wait_for_status="yellow")
34+
break
35+
except ConnectionError:
36+
await asyncio.sleep(0.1)
37+
else:
38+
# timeout
39+
pytest.skip("Elasticsearch failed to start.")
40+
41+
yield client
42+
43+
finally:
44+
if client:
45+
version = tuple(
46+
[
47+
int(x) if x.isdigit() else 999
48+
for x in (await client.info())["version"]["number"].split(".")
49+
]
50+
)
51+
52+
expand_wildcards = ["open", "closed"]
53+
if version >= (7, 7):
54+
expand_wildcards.append("hidden")
55+
56+
await client.indices.delete(
57+
index="*", ignore=404, expand_wildcards=expand_wildcards
58+
)
59+
await client.indices.delete_template(name="*", ignore=404)
60+
await client.indices.delete_index_template(name="*", ignore=404)
61+
await client.close()
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# -*- coding: utf-8 -*-
2+
# Licensed to Elasticsearch B.V under one or more agreements.
3+
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
4+
# See the LICENSE file in the project root for more information
5+
6+
from __future__ import unicode_literals
7+
import pytest
8+
9+
pytestmark = pytest.mark.asyncio
10+
11+
12+
class TestUnicode:
13+
async def test_indices_analyze(self, async_client):
14+
await async_client.indices.analyze(body='{"text": "привет"}')
15+
16+
17+
class TestBulk:
18+
async def test_bulk_works_with_string_body(self, async_client):
19+
docs = '{ "index" : { "_index" : "bulk_test_index", "_id" : "1" } }\n{"answer": 42}'
20+
response = await async_client.bulk(body=docs)
21+
22+
assert response["errors"] is False
23+
assert len(response["items"]) == 1
24+
25+
async def test_bulk_works_with_bytestring_body(self, async_client):
26+
docs = b'{ "index" : { "_index" : "bulk_test_index", "_id" : "2" } }\n{"answer": 42}'
27+
response = await async_client.bulk(body=docs)
28+
29+
assert response["errors"] is False
30+
assert len(response["items"]) == 1
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
# Licensed to Elasticsearch B.V under one or more agreements.
2+
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
3+
# See the LICENSE file in the project root for more information
4+
5+
"""
6+
Dynamically generated set of TestCases based on set of yaml files decribing
7+
some integration tests. These files are shared among all official Elasticsearch
8+
clients.
9+
"""
10+
import pytest
11+
from shutil import rmtree
12+
import warnings
13+
import inspect
14+
15+
from elasticsearch import RequestError, ElasticsearchDeprecationWarning
16+
from elasticsearch.helpers.test import _get_version
17+
from ...test_server.test_rest_api_spec import (
18+
YamlRunner,
19+
YAML_TEST_SPECS,
20+
InvalidActionType,
21+
RUN_ASYNC_REST_API_TESTS,
22+
PARAMS_RENAMES,
23+
IMPLEMENTED_FEATURES,
24+
)
25+
26+
pytestmark = pytest.mark.asyncio
27+
28+
XPACK_FEATURES = None
29+
ES_VERSION = None
30+
31+
32+
async def await_if_coro(x):
33+
if inspect.iscoroutine(x):
34+
return await x
35+
return x
36+
37+
38+
class AsyncYamlRunner(YamlRunner):
39+
async def setup(self):
40+
if self._setup_code:
41+
await self.run_code(self._setup_code)
42+
43+
async def teardown(self):
44+
if self._teardown_code:
45+
await self.run_code(self._teardown_code)
46+
47+
for repo, definition in (
48+
await self.client.snapshot.get_repository(repository="_all")
49+
).items():
50+
await self.client.snapshot.delete_repository(repository=repo)
51+
if definition["type"] == "fs":
52+
rmtree(
53+
"/tmp/%s" % definition["settings"]["location"], ignore_errors=True
54+
)
55+
56+
# stop and remove all ML stuff
57+
if await self._feature_enabled("ml"):
58+
await self.client.ml.stop_datafeed(datafeed_id="*", force=True)
59+
for feed in (await self.client.ml.get_datafeeds(datafeed_id="*"))[
60+
"datafeeds"
61+
]:
62+
await self.client.ml.delete_datafeed(datafeed_id=feed["datafeed_id"])
63+
64+
await self.client.ml.close_job(job_id="*", force=True)
65+
for job in (await self.client.ml.get_jobs(job_id="*"))["jobs"]:
66+
await self.client.ml.delete_job(
67+
job_id=job["job_id"], wait_for_completion=True, force=True
68+
)
69+
70+
# stop and remove all Rollup jobs
71+
if await self._feature_enabled("rollup"):
72+
for rollup in (await self.client.rollup.get_jobs(id="*"))["jobs"]:
73+
await self.client.rollup.stop_job(
74+
id=rollup["config"]["id"], wait_for_completion=True
75+
)
76+
await self.client.rollup.delete_job(id=rollup["config"]["id"])
77+
78+
async def es_version(self):
79+
global ES_VERSION
80+
if ES_VERSION is None:
81+
version_string = (await self.client.info())["version"]["number"]
82+
if "." not in version_string:
83+
return ()
84+
version = version_string.strip().split(".")
85+
ES_VERSION = tuple(int(v) if v.isdigit() else 999 for v in version)
86+
return ES_VERSION
87+
88+
async def run(self):
89+
try:
90+
await self.setup()
91+
await self.run_code(self._run_code)
92+
finally:
93+
await self.teardown()
94+
95+
async def run_code(self, test):
96+
""" Execute an instruction based on it's type. """
97+
print(test)
98+
for action in test:
99+
assert len(action) == 1
100+
action_type, action = list(action.items())[0]
101+
102+
if hasattr(self, "run_" + action_type):
103+
await await_if_coro(getattr(self, "run_" + action_type)(action))
104+
else:
105+
raise InvalidActionType(action_type)
106+
107+
async def run_do(self, action):
108+
api = self.client
109+
headers = action.pop("headers", None)
110+
catch = action.pop("catch", None)
111+
warn = action.pop("warnings", ())
112+
allowed_warnings = action.pop("allowed_warnings", ())
113+
assert len(action) == 1
114+
115+
method, args = list(action.items())[0]
116+
args["headers"] = headers
117+
118+
# locate api endpoint
119+
for m in method.split("."):
120+
assert hasattr(api, m)
121+
api = getattr(api, m)
122+
123+
# some parameters had to be renamed to not clash with python builtins,
124+
# compensate
125+
for k in PARAMS_RENAMES:
126+
if k in args:
127+
args[PARAMS_RENAMES[k]] = args.pop(k)
128+
129+
# resolve vars
130+
for k in args:
131+
args[k] = self._resolve(args[k])
132+
133+
warnings.simplefilter("always", category=ElasticsearchDeprecationWarning)
134+
with warnings.catch_warnings(record=True) as caught_warnings:
135+
try:
136+
self.last_response = await api(**args)
137+
except Exception as e:
138+
if not catch:
139+
raise
140+
self.run_catch(catch, e)
141+
else:
142+
if catch:
143+
raise AssertionError(
144+
"Failed to catch %r in %r." % (catch, self.last_response)
145+
)
146+
147+
# Filter out warnings raised by other components.
148+
caught_warnings = [
149+
str(w.message)
150+
for w in caught_warnings
151+
if w.category == ElasticsearchDeprecationWarning
152+
and str(w.message) not in allowed_warnings
153+
]
154+
155+
# Sorting removes the issue with order raised. We only care about
156+
# if all warnings are raised in the single API call.
157+
if warn and sorted(warn) != sorted(caught_warnings):
158+
raise AssertionError(
159+
"Expected warnings not equal to actual warnings: expected=%r actual=%r"
160+
% (warn, caught_warnings)
161+
)
162+
163+
async def run_skip(self, skip):
164+
if "features" in skip:
165+
features = skip["features"]
166+
if not isinstance(features, (tuple, list)):
167+
features = [features]
168+
for feature in features:
169+
if feature in IMPLEMENTED_FEATURES:
170+
continue
171+
pytest.skip("feature '%s' is not supported" % feature)
172+
173+
if "version" in skip:
174+
version, reason = skip["version"], skip["reason"]
175+
if version == "all":
176+
pytest.skip(reason)
177+
min_version, max_version = version.split("-")
178+
min_version = _get_version(min_version) or (0,)
179+
max_version = _get_version(max_version) or (999,)
180+
if min_version <= (await self.es_version()) <= max_version:
181+
pytest.skip(reason)
182+
183+
async def _feature_enabled(self, name):
184+
global XPACK_FEATURES
185+
if XPACK_FEATURES is None:
186+
try:
187+
xinfo = await self.client.xpack.info()
188+
XPACK_FEATURES = set(
189+
f for f in xinfo["features"] if xinfo["features"][f]["enabled"]
190+
)
191+
IMPLEMENTED_FEATURES.add("xpack")
192+
except RequestError:
193+
XPACK_FEATURES = set()
194+
IMPLEMENTED_FEATURES.add("no_xpack")
195+
return name in XPACK_FEATURES
196+
197+
198+
@pytest.fixture(scope="function")
199+
def async_runner(async_client):
200+
return AsyncYamlRunner(async_client)
201+
202+
203+
@pytest.mark.parametrize("test_spec", YAML_TEST_SPECS)
204+
async def test_rest_api_spec(test_spec, async_runner):
205+
if not RUN_ASYNC_REST_API_TESTS:
206+
pytest.skip("Skipped running async REST API tests")
207+
if test_spec.get("skip", False):
208+
pytest.skip("Manually skipped in 'SKIP_TESTS'")
209+
async_runner.use_spec(test_spec)
210+
await async_runner.run()

test_elasticsearch/test_server/test_rest_api_spec.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
some integration tests. These files are shared among all official Elasticsearch
88
clients.
99
"""
10+
import sys
1011
import re
12+
import os
1113
from os import walk, environ
1214
from os.path import exists, join, dirname, pardir, relpath
1315
import yaml
@@ -58,11 +60,17 @@
5860
"indices/put_template/10_basic[4]",
5961
# depends on order of response JSON which is random
6062
"indices/simulate_index_template/10_basic[1]",
63+
# body: null? body is {}
64+
"indices/simulate_index_template/10_basic[2]",
6165
}
6266

6367

6468
XPACK_FEATURES = None
6569
ES_VERSION = None
70+
RUN_ASYNC_REST_API_TESTS = (
71+
sys.version_info >= (3, 6)
72+
and os.environ.get("PYTHON_CONNECTION_CLASS") == "RequestsHttpConnection"
73+
)
6674

6775

6876
class YamlRunner:
@@ -78,7 +86,7 @@ def __init__(self, client):
7886
def use_spec(self, test_spec):
7987
self._setup_code = test_spec.pop("setup", None)
8088
self._run_code = test_spec.pop("run", None)
81-
self._teardown_code = test_spec.pop("teardown")
89+
self._teardown_code = test_spec.pop("teardown", None)
8290

8391
def setup(self):
8492
if self._setup_code:
@@ -417,6 +425,8 @@ def sync_runner(sync_client):
417425

418426
@pytest.mark.parametrize("test_spec", YAML_TEST_SPECS)
419427
def test_rest_api_spec(test_spec, sync_runner):
428+
if RUN_ASYNC_REST_API_TESTS:
429+
pytest.skip("Skipped running sync REST API tests")
420430
if test_spec.get("skip", False):
421431
pytest.skip("Manually skipped in 'SKIP_TESTS'")
422432
sync_runner.use_spec(test_spec)

0 commit comments

Comments
 (0)