Skip to content

Commit 6cdb7e8

Browse files
handle the connections with proper event loop parsing
1 parent 1784e5a commit 6cdb7e8

File tree

3 files changed

+93
-62
lines changed

3 files changed

+93
-62
lines changed

redisvl/index/index.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -97,19 +97,19 @@ def wrapper(self, *args, **kwargs):
9797
return decorator
9898

9999

100-
def setup_async_redis():
101-
def decorator(func):
102-
@wraps(func)
103-
def wrapper(self, *args, **kwargs):
104-
result = func(self, *args, **kwargs)
105-
RedisConnectionFactory.validate_async_redis(
106-
self._redis_client, self._lib_name
107-
)
108-
return result
100+
# def setup_async_redis():
101+
# def decorator(func):
102+
# @wraps(func)
103+
# def wrapper(self, *args, **kwargs):
104+
# result = func(self, *args, **kwargs)
105+
# RedisConnectionFactory.validate_async_redis(
106+
# self._redis_client, self._lib_name
107+
# )
108+
# return result
109109

110-
return wrapper
110+
# return wrapper
111111

112-
return decorator
112+
# return decorator
113113

114114

115115
def check_index_exists():
@@ -741,7 +741,7 @@ def connect(self, redis_url: Optional[str] = None, **kwargs):
741741
)
742742
return self.set_client(client)
743743

744-
@setup_async_redis()
744+
@setup_redis()
745745
def set_client(self, client: aredis.Redis):
746746
"""Manually set the Redis client to use with the search index.
747747

redisvl/redis/connection.py

Lines changed: 80 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
import asyncio
12
import os
2-
from typing import Any, Dict, List, Optional, Type
3+
from typing import Any, Dict, List, Optional, Type, Union
34

45
from redis import Redis
56
from redis.asyncio import Redis as AsyncRedis
@@ -122,59 +123,105 @@ def get_async_redis_connection(url: Optional[str] = None, **kwargs) -> AsyncRedi
122123
# fallback to env var REDIS_URL
123124
return AsyncRedis.from_url(get_address_from_env(), **kwargs)
124125

125-
@staticmethod
126126
def validate_redis(
127-
client: Redis,
127+
client: Union[Redis, AsyncRedis],
128128
lib_name: Optional[str] = None,
129129
redis_required_modules: Optional[List[Dict[str, Any]]] = None,
130130
) -> None:
131-
"""Validates if the required Redis modules are installed.
131+
"""Validates the Redis connection.
132132
133133
Args:
134-
client (Redis): Synchronous Redis client.
134+
client (Redis or AsyncRedis): Redis client.
135+
lib_name (str): Library name to set on the Redis client.
136+
redis_required_modules (List[Dict[str, Any]]): List of required modules and their versions.
135137
136138
Raises:
137139
ValueError: If required Redis modules are not installed.
138140
"""
139-
# set client library name
141+
if isinstance(client, AsyncRedis):
142+
print("VALIDATING ASYNC CLIENT", flush=True)
143+
RedisConnectionFactory._run_async(
144+
RedisConnectionFactory._validate_async_redis,
145+
client,
146+
lib_name,
147+
redis_required_modules,
148+
)
149+
else:
150+
RedisConnectionFactory._validate_sync_redis(
151+
client, lib_name, redis_required_modules
152+
)
153+
154+
@staticmethod
155+
def _validate_sync_redis(
156+
client: Redis,
157+
lib_name: Optional[str],
158+
redis_required_modules: Optional[List[Dict[str, Any]]],
159+
) -> None:
160+
"""Validates the sync client."""
161+
# Set client library name
140162
client.client_setinfo("LIB-NAME", make_lib_name(lib_name))
141163

142-
# validate available modules
143-
RedisConnectionFactory._validate_modules(
144-
convert_bytes(client.module_list()), redis_required_modules
145-
)
164+
# Get list of modules
165+
modules_list = convert_bytes(client.module_list())
166+
167+
# Validate available modules
168+
RedisConnectionFactory._validate_modules(modules_list, redis_required_modules)
146169

147170
@staticmethod
148-
def validate_async_redis(
171+
async def _validate_async_redis(
149172
client: AsyncRedis,
150-
lib_name: Optional[str] = None,
151-
redis_required_modules: Optional[List[Dict[str, Any]]] = None,
173+
lib_name: Optional[str],
174+
redis_required_modules: Optional[List[Dict[str, Any]]],
152175
) -> None:
176+
"""Validates the async client."""
177+
# Set client library name
178+
res = await client.client_setinfo("LIB-NAME", make_lib_name(lib_name))
179+
print("SET ASYNC CLIENT NAME", res, flush=True)
180+
181+
# Get list of modules
182+
modules_list = convert_bytes(await client.module_list())
183+
184+
# Validate available modules
185+
RedisConnectionFactory._validate_modules(modules_list, redis_required_modules)
186+
187+
@staticmethod
188+
def _run_async(coro, *args, **kwargs):
153189
"""
154-
Validates if the required Redis modules are installed.
190+
Runs an asynchronous function in the appropriate event loop context.
191+
192+
This method checks if there is an existing event loop running. If there is,
193+
it schedules the coroutine to be run within the current loop using `asyncio.ensure_future`.
194+
If no event loop is running, it creates a new event loop, runs the coroutine,
195+
and then closes the loop to avoid resource leaks.
155196
156197
Args:
157-
client (AsyncRedis): Asynchronous Redis client.
198+
coro (coroutine): The coroutine function to be run.
199+
*args: Positional arguments to pass to the coroutine function.
200+
**kwargs: Keyword arguments to pass to the coroutine function.
158201
159-
Raises:
160-
ValueError: If required Redis modules are not installed.
202+
Returns:
203+
The result of the coroutine if a new event loop is created,
204+
otherwise a task object representing the coroutine execution.
161205
"""
162-
# pick the right connection class
163-
connection_class: Type[AbstractConnection] = (
164-
SSLConnection
165-
if client.connection_pool.connection_class == ASSLConnection
166-
else Connection
167-
)
168-
# set up a temp sync client
169-
temp_client = Redis(
170-
connection_pool=ConnectionPool(
171-
connection_class=connection_class,
172-
**client.connection_pool.connection_kwargs,
173-
)
174-
)
175-
RedisConnectionFactory.validate_redis(
176-
temp_client, lib_name, redis_required_modules
177-
)
206+
try:
207+
# Try to get the current running event loop
208+
loop = asyncio.get_running_loop()
209+
except RuntimeError: # No running event loop
210+
loop = None
211+
212+
if loop and loop.is_running():
213+
# If an event loop is running, schedule the coroutine to run in the existing loop
214+
return asyncio.ensure_future(coro(*args, **kwargs))
215+
else:
216+
# No event loop is running, create a new event loop
217+
loop = asyncio.new_event_loop()
218+
asyncio.set_event_loop(loop)
219+
try:
220+
# Run the coroutine in the new event loop and wait for it to complete
221+
return loop.run_until_complete(coro(*args, **kwargs))
222+
finally:
223+
# Close the event loop to release resources
224+
loop.close()
178225

179226
@staticmethod
180227
def _validate_modules(

tests/integration/test_connection.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -58,23 +58,7 @@ def test_validate_redis(client):
5858
assert lib_name["lib-name"] == EXPECTED_LIB_NAME
5959

6060

61-
@pytest.mark.asyncio
62-
async def test_validate_async_redis(async_client):
63-
client = await async_client
64-
RedisConnectionFactory.validate_async_redis(client)
65-
lib_name = await client.client_info()
66-
assert lib_name["lib-name"] == EXPECTED_LIB_NAME
67-
68-
69-
def test_custom_lib_name(client):
61+
def test_validate_redis_custom_lib_name(client):
7062
RedisConnectionFactory.validate_redis(client, "langchain_v0.1.0")
7163
lib_name = client.client_info()
7264
assert lib_name["lib-name"] == f"redis-py(redisvl_v{__version__};langchain_v0.1.0)"
73-
74-
75-
@pytest.mark.asyncio
76-
async def test_async_custom_lib_name(async_client):
77-
client = await async_client
78-
RedisConnectionFactory.validate_async_redis(client, "langchain_v0.1.0")
79-
lib_name = await client.client_info()
80-
assert lib_name["lib-name"] == f"redis-py(redisvl_v{__version__};langchain_v0.1.0)"

0 commit comments

Comments
 (0)