Skip to content

Commit f2a3ef9

Browse files
committed
Get rid of dirty hack in select() coroutine
1 parent b161d88 commit f2a3ef9

File tree

1 file changed

+36
-17
lines changed

1 file changed

+36
-17
lines changed

peewee_async.py

+36-17
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@
5252
'savepoint',
5353
]
5454

55+
RESULTS_NAIVE = peewee.RESULTS_NAIVE
56+
RESULTS_MODELS = peewee.RESULTS_MODELS
57+
RESULTS_TUPLES = peewee.RESULTS_TUPLES
58+
RESULTS_DICTS = peewee.RESULTS_DICTS
59+
RESULTS_AGGREGATE_MODELS = peewee.RESULTS_AGGREGATE_MODELS
60+
5561

5662
#################
5763
# Async queries #
@@ -206,24 +212,14 @@ def select(query):
206212
("Error, trying to run select coroutine"
207213
"with wrong query class %s" % str(query))
208214

209-
# Perform *real* async query
210-
query = query.clone()
211-
cursor = yield from _execute_query_async(query)
212-
213-
# Perform *fake* query: we only need a result wrapper
214-
# here, not the query result itself:
215-
query._execute = lambda: None
216-
result_wrapper = query.execute()
217-
218-
# Fetch result
219-
result = AsyncQueryResult(result_wrapper=result_wrapper, cursor=cursor)
215+
result = AsyncQueryWrapper(query)
216+
cursor = yield from result.execute()
220217
try:
221218
while True:
222219
yield from result.fetchone()
223220
except GeneratorExit:
224221
pass
225222

226-
# Release cursor and return
227223
cursor.release()
228224
return result
229225

@@ -375,7 +371,7 @@ def prefetch(sq, *subqueries):
375371
###################
376372

377373

378-
class AsyncQueryResult:
374+
class AsyncQueryWrapper:
379375
"""Async query results wrapper for async `select()`. Internally uses
380376
results wrapper produced by sync peewee select query.
381377
@@ -387,11 +383,12 @@ class AsyncQueryResult:
387383
To retrieve results after async fetching just iterate over this class
388384
instance, like you generally iterate over sync results wrapper.
389385
"""
390-
def __init__(self, result_wrapper=None, cursor=None):
391-
self._result = []
386+
def __init__(self, query):
392387
self._initialized = False
393-
self._result_wrapper = result_wrapper
394-
self._cursor = cursor
388+
self._cursor = None
389+
self._query = query
390+
self._result = []
391+
self._result_wrapper = self._get_result_wrapper(query)
395392

396393
def __iter__(self):
397394
return iter(self._result)
@@ -402,6 +399,28 @@ def __getitem__(self, key):
402399
def __len__(self):
403400
return len(self._result)
404401

402+
@classmethod
403+
def _get_result_wrapper(self, query):
404+
"""Get result wrapper class.
405+
"""
406+
if query._tuples:
407+
QR = query.database.get_result_wrapper(RESULTS_TUPLES)
408+
elif query._dicts:
409+
QR = query.database.get_result_wrapper(RESULTS_DICTS)
410+
elif query._naive or not query._joins or query.verify_naive():
411+
QR = query.database.get_result_wrapper(RESULTS_NAIVE)
412+
elif query._aggregate_rows:
413+
QR = query.database.get_result_wrapper(RESULTS_AGGREGATE_MODELS)
414+
else:
415+
QR = query.database.get_result_wrapper(RESULTS_MODELS)
416+
417+
return QR(query.model_class, None, query.get_query_meta())
418+
419+
@asyncio.coroutine
420+
def execute(self):
421+
self._cursor = yield from _execute_query_async(self._query)
422+
return self._cursor
423+
405424
@asyncio.coroutine
406425
def fetchone(self):
407426
row = yield from self._cursor.fetchone()

0 commit comments

Comments
 (0)