Skip to content

Commit 3bbfdf7

Browse files
committed
Create collections with async context manager
With the current version of pytest-asyncio we're using, there's an issue with using async fixtures cached at different scopes when they need the same event loop scope. See: pytest-dev/pytest-asyncio#871 An API breaking change that fixes this is available in 0.24, but fixing this with a context manager here to avoid increasing the blast radius.
1 parent 515b3a3 commit 3bbfdf7

File tree

1 file changed

+68
-57
lines changed

1 file changed

+68
-57
lines changed

python/tests/integration/connectors/memory/test_postgres.py

Lines changed: 68 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import uuid
44
from collections.abc import AsyncGenerator
5+
from contextlib import asynccontextmanager
56
from typing import Annotated, Any
67

78
import pandas as pd
@@ -10,12 +11,10 @@
1011
from pydantic import BaseModel
1112

1213
from semantic_kernel.connectors.memory.postgres import PostgresStore
13-
from semantic_kernel.connectors.memory.postgres.postgres_collection import PostgresCollection
1414
from semantic_kernel.connectors.memory.postgres.postgres_settings import PostgresSettings
1515
from semantic_kernel.data.const import DistanceFunction, IndexKind
1616
from semantic_kernel.data.vector_store_model_decorator import vectorstoremodel
1717
from semantic_kernel.data.vector_store_model_definition import VectorStoreRecordDefinition
18-
from semantic_kernel.data.vector_store_record_collection import VectorStoreRecordCollection
1918
from semantic_kernel.data.vector_store_record_fields import (
2019
VectorStoreRecordDataField,
2120
VectorStoreRecordKeyField,
@@ -85,14 +84,22 @@ async def vector_store() -> AsyncGenerator[PostgresStore, None]:
8584
yield PostgresStore(connection_pool=pool)
8685

8786

88-
@pytest_asyncio.fixture(scope="function")
89-
async def simple_collection(vector_store: PostgresStore):
87+
@asynccontextmanager
88+
async def create_simple_collection(vector_store: PostgresStore):
89+
"""Returns a collection with a unique name that is deleted after the context.
90+
91+
This can be moved to use a fixture with scope=function and loop_scope=session
92+
after upgrade to pytest-asyncio 0.24. With the current version, the fixture
93+
would both cache and use the event loop of the declared scope.
94+
"""
9095
suffix = str(uuid.uuid4()).replace("-", "")[:8]
9196
collection_id = f"test_collection_{suffix}"
9297
collection = vector_store.get_collection(collection_id, SimpleDataModel)
9398
await collection.create_collection()
94-
yield collection
95-
await collection.delete_collection()
99+
try:
100+
yield collection
101+
finally:
102+
await collection.delete_collection()
96103

97104

98105
def test_create_store(vector_store):
@@ -118,37 +125,40 @@ async def test_create_does_collection_exist_and_delete(vector_store: PostgresSto
118125

119126

120127
@pytest.mark.asyncio(scope="session")
121-
async def test_list_collection_names(vector_store, simple_collection):
122-
simple_collection_id = simple_collection.collection_name
123-
result = await vector_store.list_collection_names()
124-
assert simple_collection_id in result
128+
async def test_list_collection_names(vector_store):
129+
async with create_simple_collection(vector_store) as simple_collection:
130+
simple_collection_id = simple_collection.collection_name
131+
result = await vector_store.list_collection_names()
132+
assert simple_collection_id in result
125133

126134

127135
@pytest.mark.asyncio(scope="session")
128-
async def test_upsert_get_and_delete(simple_collection: PostgresCollection):
136+
async def test_upsert_get_and_delete(vector_store: PostgresStore):
129137
record = SimpleDataModel(id=1, embedding=[1.1, 2.2, 3.3], data={"key": "value"})
138+
async with create_simple_collection(vector_store) as simple_collection:
139+
result_before_upsert = await simple_collection.get(1)
140+
assert result_before_upsert is None
130141

131-
result_before_upsert = await simple_collection.get(1)
132-
assert result_before_upsert is None
133-
134-
await simple_collection.upsert(record)
135-
result = await simple_collection.get(1)
136-
assert result is not None
137-
assert result.id == record.id
138-
assert result.embedding == record.embedding
139-
assert result.data == record.data
140-
141-
# Check that the table has an index
142-
connection_pool = simple_collection.connection_pool
143-
async with connection_pool.connection() as conn, conn.cursor() as cur:
144-
await cur.execute("SELECT indexname FROM pg_indexes WHERE tablename = %s", (simple_collection.collection_name,))
145-
rows = await cur.fetchall()
146-
index_names = [index[0] for index in rows]
147-
assert any("embedding_idx" in index_name for index_name in index_names)
148-
149-
await simple_collection.delete(1)
150-
result_after_delete = await simple_collection.get(1)
151-
assert result_after_delete is None
142+
await simple_collection.upsert(record)
143+
result = await simple_collection.get(1)
144+
assert result is not None
145+
assert result.id == record.id
146+
assert result.embedding == record.embedding
147+
assert result.data == record.data
148+
149+
# Check that the table has an index
150+
connection_pool = simple_collection.connection_pool
151+
async with connection_pool.connection() as conn, conn.cursor() as cur:
152+
await cur.execute(
153+
"SELECT indexname FROM pg_indexes WHERE tablename = %s", (simple_collection.collection_name,)
154+
)
155+
rows = await cur.fetchall()
156+
index_names = [index[0] for index in rows]
157+
assert any("embedding_idx" in index_name for index_name in index_names)
158+
159+
await simple_collection.delete(1)
160+
result_after_delete = await simple_collection.get(1)
161+
assert result_after_delete is None
152162

153163

154164
@pytest.mark.asyncio(scope="session")
@@ -182,28 +192,29 @@ async def test_upsert_get_and_delete_pandas(vector_store):
182192

183193

184194
@pytest.mark.asyncio(scope="session")
185-
async def test_upsert_get_and_delete_batch(simple_collection: VectorStoreRecordCollection):
186-
record1 = SimpleDataModel(id=1, embedding=[1.1, 2.2, 3.3], data={"key": "value"})
187-
record2 = SimpleDataModel(id=2, embedding=[4.4, 5.5, 6.6], data={"key": "value"})
188-
189-
result_before_upsert = await simple_collection.get_batch([1, 2])
190-
assert result_before_upsert is None
191-
192-
await simple_collection.upsert_batch([record1, record2])
193-
# Test get_batch for the two existing keys and one non-existing key;
194-
# this should return only the two existing records.
195-
result = await simple_collection.get_batch([1, 2, 3])
196-
assert result is not None
197-
assert len(result) == 2
198-
assert result[0] is not None
199-
assert result[0].id == record1.id
200-
assert result[0].embedding == record1.embedding
201-
assert result[0].data == record1.data
202-
assert result[1] is not None
203-
assert result[1].id == record2.id
204-
assert result[1].embedding == record2.embedding
205-
assert result[1].data == record2.data
206-
207-
await simple_collection.delete_batch([1, 2])
208-
result_after_delete = await simple_collection.get_batch([1, 2])
209-
assert result_after_delete is None
195+
async def test_upsert_get_and_delete_batch(vector_store: PostgresStore):
196+
async with create_simple_collection(vector_store) as simple_collection:
197+
record1 = SimpleDataModel(id=1, embedding=[1.1, 2.2, 3.3], data={"key": "value"})
198+
record2 = SimpleDataModel(id=2, embedding=[4.4, 5.5, 6.6], data={"key": "value"})
199+
200+
result_before_upsert = await simple_collection.get_batch([1, 2])
201+
assert result_before_upsert is None
202+
203+
await simple_collection.upsert_batch([record1, record2])
204+
# Test get_batch for the two existing keys and one non-existing key;
205+
# this should return only the two existing records.
206+
result = await simple_collection.get_batch([1, 2, 3])
207+
assert result is not None
208+
assert len(result) == 2
209+
assert result[0] is not None
210+
assert result[0].id == record1.id
211+
assert result[0].embedding == record1.embedding
212+
assert result[0].data == record1.data
213+
assert result[1] is not None
214+
assert result[1].id == record2.id
215+
assert result[1].embedding == record2.embedding
216+
assert result[1].data == record2.data
217+
218+
await simple_collection.delete_batch([1, 2])
219+
result_after_delete = await simple_collection.get_batch([1, 2])
220+
assert result_after_delete is None

0 commit comments

Comments
 (0)