Skip to content

Commit effe078

Browse files
committed
Fix pytest-dev#30: allow_sync() context manager didn't work with peewee.Proxy; also database's .allow_sync converted to context manager; database's .allow_sync setter marked as deprecated
1 parent 70269ab commit effe078

File tree

2 files changed

+124
-99
lines changed

2 files changed

+124
-99
lines changed

peewee_async.py

+90-53
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,23 @@ def __init__(self, database=None, *, loop=None):
107107
("Error, database must be provided via "
108108
"argument or class member.")
109109

110-
self.loop = loop or asyncio.get_event_loop()
111110
self.database = database or self.database
111+
self._loop = loop
112+
112113
attach_callback = getattr(self.database, 'attach_callback', None)
113114
if attach_callback:
114-
attach_callback(lambda db: db.set_event_loop(self.loop))
115+
attach_callback(lambda db: setattr(db, '_loop', loop))
115116
else:
116-
self.database.set_event_loop(self.loop)
117+
self.database._loop = loop
118+
119+
@property
120+
def loop(self):
121+
"""Get the event loop.
122+
123+
If no event loop is provided explicitly on creating
124+
the instance, just return the current event loop.
125+
"""
126+
return self._loop or asyncio.get_event_loop()
117127

118128
@property
119129
def is_connected(self):
@@ -123,7 +133,7 @@ def is_connected(self):
123133

124134
@asyncio.coroutine
125135
def get(self, source, *args, **kwargs):
126-
"""Get model instance.
136+
"""Get the model instance.
127137
128138
:param source: model or base query for lookup
129139
@@ -159,7 +169,7 @@ async def my_async_func():
159169

160170
@asyncio.coroutine
161171
def create(self, model, **data):
162-
"""Create new object saved to database.
172+
"""Create a new object saved to database.
163173
"""
164174
inst = model(**data)
165175
query = model.insert(**dict(inst._data))
@@ -174,7 +184,7 @@ def create(self, model, **data):
174184

175185
@asyncio.coroutine
176186
def get_or_create(self, model, defaults=None, **kwargs):
177-
"""Try to get object or create it with specified defaults.
187+
"""Try to get an object or create it with the specified defaults.
178188
179189
Return 2-tuple containing the model instance and a boolean
180190
indicating whether the instance was created.
@@ -189,8 +199,8 @@ def get_or_create(self, model, defaults=None, **kwargs):
189199

190200
@asyncio.coroutine
191201
def update(self, obj, only=None):
192-
"""Update object in database. Optionally, update only specified
193-
fields. For creating new object use :meth:`.create()`
202+
"""Update the object in the database. Optionally, update only
203+
the specified fields. For creating a new object use :meth:`.create()`
194204
195205
:param only: (optional) the list/tuple of fields or
196206
field names to update
@@ -322,24 +332,16 @@ def savepoint(self, sid=None):
322332
"""
323333
return savepoint(self.database, sid=sid)
324334

325-
@contextlib.contextmanager
326335
def allow_sync(self):
327-
"""Allow sync queries within context. Close sync
328-
connection on exit if connected.
336+
"""Allow sync queries within context. Close the sync
337+
database connection on exit if connected.
329338
330339
Example::
331340
332341
with objects.allow_sync():
333342
PageBlock.create_table(True)
334343
"""
335-
old_allow = self.database.allow_sync
336-
self.database.allow_sync = True
337-
yield
338-
try:
339-
self.database.close()
340-
except self.database.Error:
341-
pass # already closed
342-
self.database.allow_sync = old_allow
344+
return self.database.allow_sync()
343345

344346
def _swap_database(self, query):
345347
"""Swap database for query if swappable. Return **new query**
@@ -817,27 +819,30 @@ def _get_result_wrapper(self, query):
817819
############
818820

819821
class AsyncDatabase:
820-
allow_sync = True # whether sync queries allowed
821-
loop = None # asyncio event loop
822+
_loop = None # asyncio event loop
823+
_allow_sync = True # whether sync queries are allowed
822824
_async_conn = None # async connection
823825
_async_wait = None # connection waiter
824-
_task_data = None # task context data
826+
_task_data = None # asyncio per-task data
827+
828+
def __setattr__(self, name, value):
829+
if name == 'allow_sync':
830+
warnings.warn(
831+
"`.allow_sync` setter is deprecated, use either the "
832+
"`.allow_sync()` context manager or `.set_allow_sync()` "
833+
"method.", DeprecationWarning)
834+
self._allow_sync = value
835+
else:
836+
super().__setattr__(name, value)
837+
838+
@property
839+
def loop(self):
840+
"""Get the event loop.
825841
826-
def set_event_loop(self, loop):
827-
"""Set event loop for the database. Usually, you don't need to
828-
call this directly. It's called from `Manager.connect()` or
829-
`.connect_async()` methods.
842+
If no event loop is provided explicitly on creating
843+
the instance, just return the current event loop.
830844
"""
831-
# These checks are not very pythonic, but I believe it's OK to be
832-
# a little paranoid about mismatching of asyncio event loops,
833-
# because such errors won't show clear traceback and could be
834-
# tricky to debug.
835-
loop = loop or asyncio.get_event_loop()
836-
if not self.loop:
837-
self.loop = loop
838-
elif self.loop != loop:
839-
raise RuntimeError("Error, the event loop is already set before. "
840-
"Make sure you're using the same event loop!")
845+
return self._loop or asyncio.get_event_loop()
841846

842847
@asyncio.coroutine
843848
def connect_async(self, loop=None, timeout=None):
@@ -853,12 +858,12 @@ def connect_async(self, loop=None, timeout=None):
853858
elif self._async_wait:
854859
yield from self._async_wait
855860
else:
856-
self.set_event_loop(loop)
857-
self._async_wait = asyncio.Future(loop=self.loop)
861+
self._loop = loop
862+
self._async_wait = asyncio.Future(loop=self._loop)
858863

859864
conn = self._async_conn_cls(
860865
database=self.database,
861-
loop=self.loop,
866+
loop=self._loop,
862867
timeout=timeout,
863868
**self.connect_kwargs_async)
864869

@@ -869,15 +874,15 @@ def connect_async(self, loop=None, timeout=None):
869874
self._async_wait = None
870875
raise
871876
else:
872-
self._task_data = TaskLocals(loop=self.loop)
877+
self._task_data = TaskLocals(loop=self._loop)
873878
self._async_conn = conn
874879
self._async_wait.set_result(True)
875880

876881
@asyncio.coroutine
877882
def cursor_async(self):
878883
"""Acquire async cursor.
879884
"""
880-
yield from self.connect_async(loop=self.loop)
885+
yield from self.connect_async(loop=self._loop)
881886

882887
if self.transaction_depth_async() > 0:
883888
conn = self.transaction_conn_async()
@@ -956,15 +961,47 @@ def savepoint_async(self, sid=None):
956961
"""
957962
return savepoint(self, sid=sid)
958963

964+
def set_allow_sync(self, value):
965+
"""Allow or forbid sync queries for the database. See also
966+
the :meth:`.allow_sync()` context manager.
967+
"""
968+
self._allow_sync = value
969+
970+
@contextlib.contextmanager
971+
def allow_sync(self):
972+
"""Allow sync queries within context. Close sync
973+
connection on exit if connected.
974+
975+
Example::
976+
977+
with database.allow_sync():
978+
PageBlock.create_table(True)
979+
"""
980+
old_allow_sync = self._allow_sync
981+
self._allow_sync = True
982+
983+
try:
984+
yield
985+
except:
986+
raise
987+
finally:
988+
try:
989+
self.close()
990+
except self.Error:
991+
pass # already closed
992+
993+
self._allow_sync = old_allow_sync
994+
959995
def execute_sql(self, *args, **kwargs):
960996
"""Sync execute SQL query, `allow_sync` must be set to True.
961997
"""
962-
assert self.allow_sync, ("Error, sync query is not allowed: "
963-
"allow_sync is False")
964-
if self.allow_sync in (logging.ERROR, logging.WARNING):
965-
logging.log(self.allow_sync,
966-
"Error, sync query is not allowed: %s %s" %
967-
str(args), str(kwargs))
998+
assert self._allow_sync, (
999+
"Error, sync query is not allowed! Call the `.set_allow_sync()` "
1000+
"or use the `.allow_sync()` context manager.")
1001+
if self._allow_sync in (logging.ERROR, logging.WARNING):
1002+
logging.log(self._allow_sync,
1003+
"Error, sync query is not allowed: %s %s" %
1004+
str(args), str(kwargs))
9681005
return super().execute_sql(*args, **kwargs)
9691006

9701007

@@ -1268,17 +1305,17 @@ def sync_unwanted(database):
12681305
`UnwantedSyncQueryError` exception will raise on such query.
12691306
12701307
NOTE: sync_unwanted() context manager is **deprecated**, use
1271-
database `allow_sync` property directly or via `Manager.allow_sync()`
1308+
database's `.allow_sync()` context manager or `Manager.allow_sync()`
12721309
context manager.
12731310
"""
12741311
warnings.warn("sync_unwanted() context manager is deprecated, "
1275-
"use database `allow_sync` property directly or "
1276-
"via Manager `allow_sync()` context manager. ",
1312+
"use database's `.allow_sync()` context manager or "
1313+
"`Manager.allow_sync()` context manager. ",
12771314
DeprecationWarning)
1278-
old_allow_sync = database.allow_sync
1279-
database.allow_sync = False
1315+
old_allow_sync = database._allow_sync
1316+
database._allow_sync = False
12801317
yield
1281-
database.allow_sync = old_allow_sync
1318+
database._allow_sync = old_allow_sync
12821319

12831320

12841321
class UnwantedSyncQueryError(Exception):

0 commit comments

Comments
 (0)