@@ -107,13 +107,23 @@ def __init__(self, database=None, *, loop=None):
107
107
("Error, database must be provided via "
108
108
"argument or class member." )
109
109
110
- self .loop = loop or asyncio .get_event_loop ()
111
110
self .database = database or self .database
111
+ self ._loop = loop
112
+
112
113
attach_callback = getattr (self .database , 'attach_callback' , None )
113
114
if attach_callback :
114
- attach_callback (lambda db : db . set_event_loop ( self . loop ))
115
+ attach_callback (lambda db : setattr ( db , '_loop' , loop ))
115
116
else :
116
- self .database .set_event_loop (self .loop )
117
+ self .database ._loop = loop
118
+
119
+ @property
120
+ def loop (self ):
121
+ """Get the event loop.
122
+
123
+ If no event loop is provided explicitly on creating
124
+ the instance, just return the current event loop.
125
+ """
126
+ return self ._loop or asyncio .get_event_loop ()
117
127
118
128
@property
119
129
def is_connected (self ):
@@ -123,7 +133,7 @@ def is_connected(self):
123
133
124
134
@asyncio .coroutine
125
135
def get (self , source , * args , ** kwargs ):
126
- """Get model instance.
136
+ """Get the model instance.
127
137
128
138
:param source: model or base query for lookup
129
139
@@ -159,7 +169,7 @@ async def my_async_func():
159
169
160
170
@asyncio .coroutine
161
171
def create (self , model , ** data ):
162
- """Create new object saved to database.
172
+ """Create a new object saved to database.
163
173
"""
164
174
inst = model (** data )
165
175
query = model .insert (** dict (inst ._data ))
@@ -174,7 +184,7 @@ def create(self, model, **data):
174
184
175
185
@asyncio .coroutine
176
186
def get_or_create (self , model , defaults = None , ** kwargs ):
177
- """Try to get object or create it with specified defaults.
187
+ """Try to get an object or create it with the specified defaults.
178
188
179
189
Return 2-tuple containing the model instance and a boolean
180
190
indicating whether the instance was created.
@@ -189,8 +199,8 @@ def get_or_create(self, model, defaults=None, **kwargs):
189
199
190
200
@asyncio .coroutine
191
201
def update (self , obj , only = None ):
192
- """Update object in database. Optionally, update only specified
193
- fields. For creating new object use :meth:`.create()`
202
+ """Update the object in the database. Optionally, update only
203
+ the specified fields. For creating a new object use :meth:`.create()`
194
204
195
205
:param only: (optional) the list/tuple of fields or
196
206
field names to update
@@ -322,24 +332,16 @@ def savepoint(self, sid=None):
322
332
"""
323
333
return savepoint (self .database , sid = sid )
324
334
325
- @contextlib .contextmanager
326
335
def allow_sync (self ):
327
- """Allow sync queries within context. Close sync
328
- connection on exit if connected.
336
+ """Allow sync queries within context. Close the sync
337
+ database connection on exit if connected.
329
338
330
339
Example::
331
340
332
341
with objects.allow_sync():
333
342
PageBlock.create_table(True)
334
343
"""
335
- old_allow = self .database .allow_sync
336
- self .database .allow_sync = True
337
- yield
338
- try :
339
- self .database .close ()
340
- except self .database .Error :
341
- pass # already closed
342
- self .database .allow_sync = old_allow
344
+ return self .database .allow_sync ()
343
345
344
346
def _swap_database (self , query ):
345
347
"""Swap database for query if swappable. Return **new query**
@@ -817,27 +819,30 @@ def _get_result_wrapper(self, query):
817
819
############
818
820
819
821
class AsyncDatabase :
820
- allow_sync = True # whether sync queries allowed
821
- loop = None # asyncio event loop
822
+ _loop = None # asyncio event loop
823
+ _allow_sync = True # whether sync queries are allowed
822
824
_async_conn = None # async connection
823
825
_async_wait = None # connection waiter
824
- _task_data = None # task context data
826
+ _task_data = None # asyncio per-task data
827
+
828
+ def __setattr__ (self , name , value ):
829
+ if name == 'allow_sync' :
830
+ warnings .warn (
831
+ "`.allow_sync` setter is deprecated, use either the "
832
+ "`.allow_sync()` context manager or `.set_allow_sync()` "
833
+ "method." , DeprecationWarning )
834
+ self ._allow_sync = value
835
+ else :
836
+ super ().__setattr__ (name , value )
837
+
838
+ @property
839
+ def loop (self ):
840
+ """Get the event loop.
825
841
826
- def set_event_loop (self , loop ):
827
- """Set event loop for the database. Usually, you don't need to
828
- call this directly. It's called from `Manager.connect()` or
829
- `.connect_async()` methods.
842
+ If no event loop is provided explicitly on creating
843
+ the instance, just return the current event loop.
830
844
"""
831
- # These checks are not very pythonic, but I believe it's OK to be
832
- # a little paranoid about mismatching of asyncio event loops,
833
- # because such errors won't show clear traceback and could be
834
- # tricky to debug.
835
- loop = loop or asyncio .get_event_loop ()
836
- if not self .loop :
837
- self .loop = loop
838
- elif self .loop != loop :
839
- raise RuntimeError ("Error, the event loop is already set before. "
840
- "Make sure you're using the same event loop!" )
845
+ return self ._loop or asyncio .get_event_loop ()
841
846
842
847
@asyncio .coroutine
843
848
def connect_async (self , loop = None , timeout = None ):
@@ -853,12 +858,12 @@ def connect_async(self, loop=None, timeout=None):
853
858
elif self ._async_wait :
854
859
yield from self ._async_wait
855
860
else :
856
- self .set_event_loop ( loop )
857
- self ._async_wait = asyncio .Future (loop = self .loop )
861
+ self ._loop = loop
862
+ self ._async_wait = asyncio .Future (loop = self ._loop )
858
863
859
864
conn = self ._async_conn_cls (
860
865
database = self .database ,
861
- loop = self .loop ,
866
+ loop = self ._loop ,
862
867
timeout = timeout ,
863
868
** self .connect_kwargs_async )
864
869
@@ -869,15 +874,15 @@ def connect_async(self, loop=None, timeout=None):
869
874
self ._async_wait = None
870
875
raise
871
876
else :
872
- self ._task_data = TaskLocals (loop = self .loop )
877
+ self ._task_data = TaskLocals (loop = self ._loop )
873
878
self ._async_conn = conn
874
879
self ._async_wait .set_result (True )
875
880
876
881
@asyncio .coroutine
877
882
def cursor_async (self ):
878
883
"""Acquire async cursor.
879
884
"""
880
- yield from self .connect_async (loop = self .loop )
885
+ yield from self .connect_async (loop = self ._loop )
881
886
882
887
if self .transaction_depth_async () > 0 :
883
888
conn = self .transaction_conn_async ()
@@ -956,15 +961,47 @@ def savepoint_async(self, sid=None):
956
961
"""
957
962
return savepoint (self , sid = sid )
958
963
964
+ def set_allow_sync (self , value ):
965
+ """Allow or forbid sync queries for the database. See also
966
+ the :meth:`.allow_sync()` context manager.
967
+ """
968
+ self ._allow_sync = value
969
+
970
+ @contextlib .contextmanager
971
+ def allow_sync (self ):
972
+ """Allow sync queries within context. Close sync
973
+ connection on exit if connected.
974
+
975
+ Example::
976
+
977
+ with database.allow_sync():
978
+ PageBlock.create_table(True)
979
+ """
980
+ old_allow_sync = self ._allow_sync
981
+ self ._allow_sync = True
982
+
983
+ try :
984
+ yield
985
+ except :
986
+ raise
987
+ finally :
988
+ try :
989
+ self .close ()
990
+ except self .Error :
991
+ pass # already closed
992
+
993
+ self ._allow_sync = old_allow_sync
994
+
959
995
def execute_sql (self , * args , ** kwargs ):
960
996
"""Sync execute SQL query, `allow_sync` must be set to True.
961
997
"""
962
- assert self .allow_sync , ("Error, sync query is not allowed: "
963
- "allow_sync is False" )
964
- if self .allow_sync in (logging .ERROR , logging .WARNING ):
965
- logging .log (self .allow_sync ,
966
- "Error, sync query is not allowed: %s %s" %
967
- str (args ), str (kwargs ))
998
+ assert self ._allow_sync , (
999
+ "Error, sync query is not allowed! Call the `.set_allow_sync()` "
1000
+ "or use the `.allow_sync()` context manager." )
1001
+ if self ._allow_sync in (logging .ERROR , logging .WARNING ):
1002
+ logging .log (self ._allow_sync ,
1003
+ "Error, sync query is not allowed: %s %s" %
1004
+ str (args ), str (kwargs ))
968
1005
return super ().execute_sql (* args , ** kwargs )
969
1006
970
1007
@@ -1268,17 +1305,17 @@ def sync_unwanted(database):
1268
1305
`UnwantedSyncQueryError` exception will raise on such query.
1269
1306
1270
1307
NOTE: sync_unwanted() context manager is **deprecated**, use
1271
- database ` allow_sync` property directly or via `Manager.allow_sync()`
1308
+ database's `. allow_sync()` context manager or `Manager.allow_sync()`
1272
1309
context manager.
1273
1310
"""
1274
1311
warnings .warn ("sync_unwanted() context manager is deprecated, "
1275
- "use database ` allow_sync` property directly or "
1276
- "via Manager ` allow_sync()` context manager. " ,
1312
+ "use database's `. allow_sync()` context manager or "
1313
+ "`Manager. allow_sync()` context manager. " ,
1277
1314
DeprecationWarning )
1278
- old_allow_sync = database .allow_sync
1279
- database .allow_sync = False
1315
+ old_allow_sync = database ._allow_sync
1316
+ database ._allow_sync = False
1280
1317
yield
1281
- database .allow_sync = old_allow_sync
1318
+ database ._allow_sync = old_allow_sync
1282
1319
1283
1320
1284
1321
class UnwantedSyncQueryError (Exception ):
0 commit comments