Skip to content

Commit b843fe3

Browse files
committed
Merge pull request pytest-dev#16 from cypreess/master
Implement async `prefetch()`, close pytest-dev#11. Thanks @cypreess 👍 !
2 parents e657f7f + d05bc60 commit b843fe3

File tree

3 files changed

+113
-7
lines changed

3 files changed

+113
-7
lines changed

docs/peewee_async/api.rst

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Select, update, delete
1111
.. autofunction:: peewee_async.create_object
1212
.. autofunction:: peewee_async.delete_object
1313
.. autofunction:: peewee_async.update_object
14+
.. autofunction:: peewee_async.prefetch
1415

1516
Transactions
1617
------------

peewee_async.py

+41
Original file line numberDiff line numberDiff line change
@@ -720,3 +720,44 @@ def _compose_dsn(dbname, **kwargs):
720720
if v:
721721
dsn += ' %s=%s' % (k, v)
722722
return dsn, kwargs
723+
724+
725+
@asyncio.coroutine
726+
def prefetch(sq, *subqueries):
727+
"""Asynchronous version of the prefetch function from peewee.
728+
729+
Returns Query that has already cached data.
730+
"""
731+
732+
# This code is copied from peewee.prefetch and adopted to use async execute
733+
734+
if not subqueries:
735+
return sq
736+
fixed_queries = peewee.prefetch_add_subquery(sq, subqueries)
737+
738+
deps = {}
739+
rel_map = {}
740+
for prefetch_result in reversed(fixed_queries):
741+
query_model = prefetch_result.model
742+
if prefetch_result.fields:
743+
for rel_model in prefetch_result.rel_models:
744+
rel_map.setdefault(rel_model, [])
745+
rel_map[rel_model].append(prefetch_result)
746+
747+
deps[query_model] = {}
748+
id_map = deps[query_model]
749+
has_relations = bool(rel_map.get(query_model))
750+
751+
# This is hack, because peewee async execute do a copy of query and do not change state of query
752+
# comparing to what real peewee is doing when execute method is called
753+
prefetch_result.query._qr = yield from execute(prefetch_result.query)
754+
prefetch_result.query._dirty = False
755+
756+
for instance in prefetch_result.query._qr:
757+
if prefetch_result.fields:
758+
prefetch_result.store_instance(instance, id_map)
759+
if has_relations:
760+
for rel in rel_map[query_model]:
761+
rel.populate_instance(instance, deps[rel.model])
762+
763+
return prefetch_result.query

tests/__init__.py

+71-7
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,29 @@ class Meta:
108108
database = database
109109

110110

111+
class TestModelAlpha(peewee.Model):
112+
text = peewee.CharField()
113+
114+
class Meta:
115+
database = database
116+
117+
118+
class TestModelBeta(peewee.Model):
119+
alpha = peewee.ForeignKeyField(TestModelAlpha, related_name='betas')
120+
text = peewee.CharField()
121+
122+
class Meta:
123+
database = database
124+
125+
126+
class TestModelGamma(peewee.Model):
127+
text = peewee.CharField()
128+
beta = peewee.ForeignKeyField(TestModelBeta, related_name='gammas')
129+
130+
class Meta:
131+
database = database
132+
133+
111134
class UUIDTestModel(peewee.Model):
112135
id = peewee.UUIDField(primary_key=True, default=uuid.uuid4)
113136
text = peewee.CharField()
@@ -117,6 +140,8 @@ class Meta:
117140

118141

119142
class BaseAsyncPostgresTestCase(unittest.TestCase):
143+
db_tables = [TestModel, UUIDTestModel, TestModelAlpha, TestModelBeta, TestModelGamma]
144+
120145
@classmethod
121146
def setUpClass(cls, *args, **kwargs):
122147
# Sync connect
@@ -129,22 +154,38 @@ def test():
129154
yield from database.connect_async(loop=cls.loop)
130155
cls.loop.run_until_complete(test())
131156

132-
# Clean up after possible errors
133-
TestModel.drop_table(True)
134-
UUIDTestModel.drop_table(True)
157+
for table in reversed(cls.db_tables):
158+
# Clean up after possible errors
159+
table.drop_table(True, cascade=True)
135160

136-
# Create table with sync connection
137-
TestModel.create_table()
138-
UUIDTestModel.create_table()
161+
for table in cls.db_tables:
162+
# Create table with sync connection
163+
table.create_table()
139164

140165
# Create at least one object per model
141166
cls.obj = TestModel.create(text='[sync] Hello!')
142167
cls.uuid_obj = UUIDTestModel.create(text='[sync] Hello!')
143168

169+
cls.alpha_1 = TestModelAlpha.create(text='Alpha 1')
170+
cls.alpha_2 = TestModelAlpha.create(text='Alpha 2')
171+
172+
cls.beta_11 = TestModelBeta.create(text='Beta 1', alpha=cls.alpha_1)
173+
cls.beta_12 = TestModelBeta.create(text='Beta 2', alpha=cls.alpha_1)
174+
175+
cls.beta_21 = TestModelBeta.create(text='Beta 1', alpha=cls.alpha_2)
176+
cls.beta_22 = TestModelBeta.create(text='Beta 2', alpha=cls.alpha_2)
177+
178+
cls.gamma_111 = TestModelGamma.create(text='Gamma 1', beta=cls.beta_11)
179+
cls.gamma_112 = TestModelGamma.create(text='Gamma 2', beta=cls.beta_11)
180+
181+
cls.gamma_121 = TestModelGamma.create(text='Gamma 1', beta=cls.beta_12)
182+
183+
144184
@classmethod
145185
def tearDownClass(cls, *args, **kwargs):
186+
for table in reversed(cls.db_tables):
146187
# Finally, clean up
147-
TestModel.drop_table()
188+
table.drop_table()
148189

149190
# Close database
150191
database.close()
@@ -351,6 +392,29 @@ def test():
351392

352393
self.run_until_complete(test())
353394

395+
def test_prefetch(self):
396+
# Async prefetch
397+
@asyncio.coroutine
398+
def test():
399+
with sync_unwanted(database):
400+
result = yield from peewee_async.prefetch(TestModelAlpha.select(), TestModelBeta.select(),
401+
TestModelGamma.select())
402+
403+
result = list(result) # this should NOT fire any call (will read it from query cache)
404+
405+
# Check if we have here both alpha items in specific order
406+
self.assertEqual(result, [self.alpha_1, self.alpha_2])
407+
408+
alpha_1 = result[0]
409+
self.assertEqual(alpha_1.betas_prefetch, [self.beta_11, self.beta_12])
410+
411+
beta_11 = alpha_1.betas_prefetch[0]
412+
self.assertEqual(beta_11, self.beta_11)
413+
414+
self.assertEqual(beta_11.gammas_prefetch, [self.gamma_111, self.gamma_112])
415+
416+
self.run_until_complete(test())
417+
354418

355419
if sys.version_info >= (3, 5):
356420
from .tests_py35 import *

0 commit comments

Comments
 (0)