52
52
'savepoint' ,
53
53
]
54
54
55
+ RESULTS_NAIVE = peewee .RESULTS_NAIVE
56
+ RESULTS_MODELS = peewee .RESULTS_MODELS
57
+ RESULTS_TUPLES = peewee .RESULTS_TUPLES
58
+ RESULTS_DICTS = peewee .RESULTS_DICTS
59
+ RESULTS_AGGREGATE_MODELS = peewee .RESULTS_AGGREGATE_MODELS
60
+
55
61
56
62
#################
57
63
# Async queries #
@@ -206,24 +212,14 @@ def select(query):
206
212
("Error, trying to run select coroutine"
207
213
"with wrong query class %s" % str (query ))
208
214
209
- # Perform *real* async query
210
- query = query .clone ()
211
- cursor = yield from _execute_query_async (query )
212
-
213
- # Perform *fake* query: we only need a result wrapper
214
- # here, not the query result itself:
215
- query ._execute = lambda : None
216
- result_wrapper = query .execute ()
217
-
218
- # Fetch result
219
- result = AsyncQueryResult (result_wrapper = result_wrapper , cursor = cursor )
215
+ result = AsyncQueryWrapper (query )
216
+ cursor = yield from result .execute ()
220
217
try :
221
218
while True :
222
219
yield from result .fetchone ()
223
220
except GeneratorExit :
224
221
pass
225
222
226
- # Release cursor and return
227
223
cursor .release ()
228
224
return result
229
225
@@ -375,7 +371,7 @@ def prefetch(sq, *subqueries):
375
371
###################
376
372
377
373
378
- class AsyncQueryResult :
374
+ class AsyncQueryWrapper :
379
375
"""Async query results wrapper for async `select()`. Internally uses
380
376
results wrapper produced by sync peewee select query.
381
377
@@ -387,11 +383,12 @@ class AsyncQueryResult:
387
383
To retrieve results after async fetching just iterate over this class
388
384
instance, like you generally iterate over sync results wrapper.
389
385
"""
390
- def __init__ (self , result_wrapper = None , cursor = None ):
391
- self ._result = []
386
+ def __init__ (self , query ):
392
387
self ._initialized = False
393
- self ._result_wrapper = result_wrapper
394
- self ._cursor = cursor
388
+ self ._cursor = None
389
+ self ._query = query
390
+ self ._result = []
391
+ self ._result_wrapper = self ._get_result_wrapper (query )
395
392
396
393
def __iter__ (self ):
397
394
return iter (self ._result )
@@ -402,6 +399,28 @@ def __getitem__(self, key):
402
399
def __len__ (self ):
403
400
return len (self ._result )
404
401
402
+ @classmethod
403
+ def _get_result_wrapper (self , query ):
404
+ """Get result wrapper class.
405
+ """
406
+ if query ._tuples :
407
+ QR = query .database .get_result_wrapper (RESULTS_TUPLES )
408
+ elif query ._dicts :
409
+ QR = query .database .get_result_wrapper (RESULTS_DICTS )
410
+ elif query ._naive or not query ._joins or query .verify_naive ():
411
+ QR = query .database .get_result_wrapper (RESULTS_NAIVE )
412
+ elif query ._aggregate_rows :
413
+ QR = query .database .get_result_wrapper (RESULTS_AGGREGATE_MODELS )
414
+ else :
415
+ QR = query .database .get_result_wrapper (RESULTS_MODELS )
416
+
417
+ return QR (query .model_class , None , query .get_query_meta ())
418
+
419
+ @asyncio .coroutine
420
+ def execute (self ):
421
+ self ._cursor = yield from _execute_query_async (self ._query )
422
+ return self ._cursor
423
+
405
424
@asyncio .coroutine
406
425
def fetchone (self ):
407
426
row = yield from self ._cursor .fetchone ()
0 commit comments