Skip to content

Commit fcb1fe3

Browse files
Mock embedding model to address flaky vector tests (#1822) (#1823)
(cherry picked from commit b5435a8) Co-authored-by: Miguel Grinberg <[email protected]>
1 parent ab70d6f commit fcb1fe3

File tree

5 files changed

+95
-12
lines changed

5 files changed

+95
-12
lines changed

tests/async_sleep.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Licensed to Elasticsearch B.V. under one or more contributor
2+
# license agreements. See the NOTICE file distributed with
3+
# this work for additional information regarding copyright
4+
# ownership. Elasticsearch B.V. licenses this file to you under
5+
# the Apache License, Version 2.0 (the "License"); you may
6+
# not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import asyncio
19+
20+
21+
async def sleep(secs):
22+
"""Tests can use this function to sleep."""
23+
await asyncio.sleep(secs)

tests/sleep.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Licensed to Elasticsearch B.V. under one or more contributor
2+
# license agreements. See the NOTICE file distributed with
3+
# this work for additional information regarding copyright
4+
# ownership. Elasticsearch B.V. licenses this file to you under
5+
# the Apache License, Version 2.0 (the "License"); you may
6+
# not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import time
19+
20+
21+
def sleep(secs):
22+
"""Tests can use this function to sleep."""
23+
time.sleep(secs)

tests/test_integration/test_examples/_async/test_vectors.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,38 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
from hashlib import md5
1819
from unittest import SkipTest
1920

2021
import pytest
2122

22-
from ..async_examples.vectors import create, search
23+
from tests.async_sleep import sleep
24+
25+
from ..async_examples import vectors
2326

2427

2528
@pytest.mark.asyncio
26-
async def test_vector_search(async_write_client, es_version):
29+
async def test_vector_search(async_write_client, es_version, mocker):
2730
# this test only runs on Elasticsearch >= 8.11 because the example uses
28-
# a dense vector without giving them an explicit size
31+
# a dense vector without specifying an explicit size
2932
if es_version < (8, 11):
3033
raise SkipTest("This test requires Elasticsearch 8.11 or newer")
3134

32-
await create()
33-
results = await (await search("work from home")).execute()
34-
assert results[0].name == "Work From Home Policy"
35+
class MockModel:
36+
def __init__(self, model):
37+
pass
38+
39+
def encode(self, text):
40+
vector = [int(ch) for ch in md5(text.encode()).digest()]
41+
total = sum(vector)
42+
return [float(v) / total for v in vector]
43+
44+
mocker.patch.object(vectors, "SentenceTransformer", new=MockModel)
45+
46+
await vectors.create()
47+
for i in range(10):
48+
results = await (await vectors.search("Welcome to our team!")).execute()
49+
if len(results.hits) > 0:
50+
break
51+
await sleep(0.1)
52+
assert results[0].name == "New Employee Onboarding Guide"

tests/test_integration/test_examples/_sync/test_vectors.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,38 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
from hashlib import md5
1819
from unittest import SkipTest
1920

2021
import pytest
2122

22-
from ..examples.vectors import create, search
23+
from tests.sleep import sleep
24+
25+
from ..examples import vectors
2326

2427

2528
@pytest.mark.sync
26-
def test_vector_search(write_client, es_version):
29+
def test_vector_search(write_client, es_version, mocker):
2730
# this test only runs on Elasticsearch >= 8.11 because the example uses
28-
# a dense vector without giving them an explicit size
31+
# a dense vector without specifying an explicit size
2932
if es_version < (8, 11):
3033
raise SkipTest("This test requires Elasticsearch 8.11 or newer")
3134

32-
create()
33-
results = (search("work from home")).execute()
34-
assert results[0].name == "Work From Home Policy"
35+
class MockModel:
36+
def __init__(self, model):
37+
pass
38+
39+
def encode(self, text):
40+
vector = [int(ch) for ch in md5(text.encode()).digest()]
41+
total = sum(vector)
42+
return [float(v) / total for v in vector]
43+
44+
mocker.patch.object(vectors, "SentenceTransformer", new=MockModel)
45+
46+
vectors.create()
47+
for i in range(10):
48+
results = (vectors.search("Welcome to our team!")).execute()
49+
if len(results.hits) > 0:
50+
break
51+
sleep(0.1)
52+
assert results[0].name == "New Employee Onboarding Guide"

utils/run-unasync.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def main(check=False):
6969
"async_write_client": "write_client",
7070
"async_pull_request": "pull_request",
7171
"async_examples": "examples",
72+
"async_sleep": "sleep",
7273
"assert_awaited_once_with": "assert_called_once_with",
7374
"pytest_asyncio": "pytest",
7475
}

0 commit comments

Comments
 (0)