diff --git a/graphene_sqlalchemy/batching.py b/graphene_sqlalchemy/batching.py index 275d590..0800d0e 100644 --- a/graphene_sqlalchemy/batching.py +++ b/graphene_sqlalchemy/batching.py @@ -2,7 +2,6 @@ from asyncio import get_event_loop from typing import Any, Dict -import aiodataloader import sqlalchemy from sqlalchemy.orm import Session, strategies from sqlalchemy.orm.query import QueryContext @@ -10,7 +9,21 @@ from .utils import is_graphene_version_less_than, is_sqlalchemy_version_less_than -class RelationshipLoader(aiodataloader.DataLoader): +def get_data_loader_impl() -> Any: # pragma: no cover + """Graphene >= 3.1.1 ships a copy of aiodataloader with minor fixes. To preserve backward-compatibility, + aiodataloader is used in conjunction with older versions of graphene""" + if is_graphene_version_less_than("3.1.1"): + from aiodataloader import DataLoader + else: + from graphene.utils.dataloader import DataLoader + + return DataLoader + + +DataLoader = get_data_loader_impl() + + +class RelationshipLoader(DataLoader): cache = False def __init__(self, relationship_prop, selectin_loader): @@ -92,20 +105,6 @@ async def batch_load_fn(self, parents): ] = {} -def get_data_loader_impl() -> Any: # pragma: no cover - """Graphene >= 3.1.1 ships a copy of aiodataloader with minor fixes. To preserve backward-compatibility, - aiodataloader is used in conjunction with older versions of graphene""" - if is_graphene_version_less_than("3.1.1"): - from aiodataloader import DataLoader - else: - from graphene.utils.dataloader import DataLoader - - return DataLoader - - -DataLoader = get_data_loader_impl() - - def get_batch_resolver(relationship_prop): """Get the resolve function for the given relationship."""