Skip to content

Commit d0ad40d

Browse files
committed
Add AsyncConnectionPool tests
1 parent 1742803 commit d0ad40d

File tree

1 file changed

+144
-0
lines changed

1 file changed

+144
-0
lines changed
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# Licensed to Elasticsearch B.V under one or more agreements.
2+
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
3+
# See the LICENSE file in the project root for more information
4+
5+
import time
6+
7+
from elasticsearch import (
8+
AsyncConnectionPool,
9+
RoundRobinSelector,
10+
AsyncDummyConnectionPool,
11+
)
12+
from elasticsearch.connection import Connection
13+
from elasticsearch.exceptions import ImproperlyConfigured
14+
15+
from ..test_cases import TestCase
16+
17+
18+
class TestConnectionPool(TestCase):
19+
def test_dummy_cp_raises_exception_on_more_connections(self):
20+
self.assertRaises(ImproperlyConfigured, AsyncDummyConnectionPool, [])
21+
self.assertRaises(
22+
ImproperlyConfigured, AsyncDummyConnectionPool, [object(), object()]
23+
)
24+
25+
def test_raises_exception_when_no_connections_defined(self):
26+
self.assertRaises(ImproperlyConfigured, AsyncConnectionPool, [])
27+
28+
def test_default_round_robin(self):
29+
pool = AsyncConnectionPool([(x, {}) for x in range(100)])
30+
31+
connections = set()
32+
for _ in range(100):
33+
connections.add(pool.get_connection())
34+
self.assertEqual(connections, set(range(100)))
35+
36+
def test_disable_shuffling(self):
37+
pool = AsyncConnectionPool([(x, {}) for x in range(100)], randomize_hosts=False)
38+
39+
connections = []
40+
for _ in range(100):
41+
connections.append(pool.get_connection())
42+
self.assertEqual(connections, list(range(100)))
43+
44+
def test_selectors_have_access_to_connection_opts(self):
45+
class MySelector(RoundRobinSelector):
46+
def select(self, connections):
47+
return self.connection_opts[
48+
super(MySelector, self).select(connections)
49+
]["actual"]
50+
51+
pool = AsyncConnectionPool(
52+
[(x, {"actual": x * x}) for x in range(100)],
53+
selector_class=MySelector,
54+
randomize_hosts=False,
55+
)
56+
57+
connections = []
58+
for _ in range(100):
59+
connections.append(pool.get_connection())
60+
self.assertEqual(connections, [x * x for x in range(100)])
61+
62+
def test_dead_nodes_are_removed_from_active_connections(self):
63+
pool = AsyncConnectionPool([(x, {}) for x in range(100)])
64+
65+
now = time.time()
66+
pool.mark_dead(42, now=now)
67+
self.assertEqual(99, len(pool.connections))
68+
self.assertEqual(1, pool.dead.qsize())
69+
self.assertEqual((now + 60, 42), pool.dead.get())
70+
71+
def test_connection_is_skipped_when_dead(self):
72+
pool = AsyncConnectionPool([(x, {}) for x in range(2)])
73+
pool.mark_dead(0)
74+
75+
self.assertEqual(
76+
[1, 1, 1],
77+
[pool.get_connection(), pool.get_connection(), pool.get_connection()],
78+
)
79+
80+
def test_new_connection_is_not_marked_dead(self):
81+
# Create 10 connections
82+
pool = AsyncConnectionPool([(Connection(), {}) for _ in range(10)])
83+
84+
# Pass in a new connection that is not in the pool to mark as dead
85+
new_connection = Connection()
86+
pool.mark_dead(new_connection)
87+
88+
# Nothing should be marked dead
89+
self.assertEqual(0, len(pool.dead_count))
90+
91+
def test_connection_is_forcibly_resurrected_when_no_live_ones_are_availible(self):
92+
pool = AsyncConnectionPool([(x, {}) for x in range(2)])
93+
pool.dead_count[0] = 1
94+
pool.mark_dead(0) # failed twice, longer timeout
95+
pool.mark_dead(1) # failed the first time, first to be resurrected
96+
97+
self.assertEqual([], pool.connections)
98+
self.assertEqual(1, pool.get_connection())
99+
self.assertEqual([1], pool.connections)
100+
101+
def test_connection_is_resurrected_after_its_timeout(self):
102+
pool = AsyncConnectionPool([(x, {}) for x in range(100)])
103+
104+
now = time.time()
105+
pool.mark_dead(42, now=now - 61)
106+
pool.get_connection()
107+
self.assertEqual(42, pool.connections[-1])
108+
self.assertEqual(100, len(pool.connections))
109+
110+
def test_force_resurrect_always_returns_a_connection(self):
111+
pool = AsyncConnectionPool([(0, {})])
112+
113+
pool.connections = []
114+
self.assertEqual(0, pool.get_connection())
115+
self.assertEqual([], pool.connections)
116+
self.assertTrue(pool.dead.empty())
117+
118+
def test_already_failed_connection_has_longer_timeout(self):
119+
pool = AsyncConnectionPool([(x, {}) for x in range(100)])
120+
now = time.time()
121+
pool.dead_count[42] = 2
122+
pool.mark_dead(42, now=now)
123+
124+
self.assertEqual(3, pool.dead_count[42])
125+
self.assertEqual((now + 4 * 60, 42), pool.dead.get())
126+
127+
def test_timeout_for_failed_connections_is_limited(self):
128+
pool = AsyncConnectionPool([(x, {}) for x in range(100)])
129+
now = time.time()
130+
pool.dead_count[42] = 245
131+
pool.mark_dead(42, now=now)
132+
133+
self.assertEqual(246, pool.dead_count[42])
134+
self.assertEqual((now + 32 * 60, 42), pool.dead.get())
135+
136+
def test_dead_count_is_wiped_clean_for_connection_if_marked_live(self):
137+
pool = AsyncConnectionPool([(x, {}) for x in range(100)])
138+
now = time.time()
139+
pool.dead_count[42] = 2
140+
pool.mark_dead(42, now=now)
141+
142+
self.assertEqual(3, pool.dead_count[42])
143+
pool.mark_live(42)
144+
self.assertNotIn(42, pool.dead_count)

0 commit comments

Comments
 (0)