Skip to content

Commit 4997888

Browse files
committed
Refs MagicStack#289, rewrite bind_execute_many()
1 parent 9a55db5 commit 4997888

File tree

6 files changed

+310
-106
lines changed

6 files changed

+310
-106
lines changed

asyncpg/connection.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,14 +252,24 @@ async def executemany(self, command: str, args, *, timeout: float=None):
252252
... ''', [(1, 2, 3), (4, 5, 6)])
253253
254254
:param command: Command to execute.
255-
:param args: An iterable containing sequences of arguments.
255+
:param args: An (async) iterable containing sequences of arguments.
256256
:param float timeout: Optional timeout value in seconds.
257257
:return None: This method discards the results of the operations.
258258
259259
.. versionadded:: 0.7.0
260260
261261
.. versionchanged:: 0.11.0
262262
`timeout` became a keyword-only parameter.
263+
264+
.. versionchanged:: 0.16.0
265+
The statements are now executed in an implicit transaction if there
266+
is no explicit transaction, so that the call can no longer end up with
267+
partial success. If you still need the previous behavior of
268+
progressively executing many args, please use a loop with a prepared
269+
statement instead.
270+
271+
.. versionchanged:: 0.16.0
272+
``args`` may now also be an async iterable.
263273
"""
264274
self._check_open()
265275
return await self._executemany(command, args, timeout)

asyncpg/protocol/consts.pxi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,5 @@ DEF _MAXINT32 = 2**31 - 1
1414
DEF _COPY_BUFFER_SIZE = 524288
1515
DEF _COPY_SIGNATURE = b"PGCOPY\n\377\r\n\0"
1616
DEF _NUMERIC_DECODER_SMALLBUF_SIZE = 256
17+
DEF _EXECUTE_MANY_BUF_NUM = 4
18+
DEF _EXECUTE_MANY_BUF_SIZE = 32768

asyncpg/protocol/coreproto.pxd

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ cdef enum ProtocolState:
3131
PROTOCOL_COPY_OUT_DONE = 20
3232
PROTOCOL_COPY_IN = 21
3333
PROTOCOL_COPY_IN_DATA = 22
34+
PROTOCOL_BIND_EXECUTE_MANY_ROLLBACK = 23
3435

3536

3637
cdef enum AuthenticationMessage:
@@ -75,11 +76,6 @@ cdef class CoreProtocol:
7576
bint _skip_discard
7677
bint _discard_data
7778

78-
# executemany support data
79-
object _execute_iter
80-
str _execute_portal_name
81-
str _execute_stmt_name
82-
8379
ConnectionStatus con_status
8480
ProtocolState state
8581
TransactionStatus xact_status
@@ -158,8 +154,10 @@ cdef class CoreProtocol:
158154
WriteBuffer bind_data, int32_t limit)
159155
cdef _bind_execute(self, str portal_name, str stmt_name,
160156
WriteBuffer bind_data, int32_t limit)
161-
cdef _bind_execute_many(self, str portal_name, str stmt_name,
162-
object bind_data)
157+
cdef _execute_many_init(self)
158+
cdef _execute_many_send(self, object buffers)
159+
cdef _execute_many_rollback(self, object error)
160+
cdef _execute_many_end(self, object sync)
163161
cdef _bind(self, str portal_name, str stmt_name,
164162
WriteBuffer bind_data)
165163
cdef _execute(self, str portal_name, int32_t limit)

asyncpg/protocol/coreproto.pyx

Lines changed: 50 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,6 @@ cdef class CoreProtocol:
2525

2626
self._skip_discard = False
2727

28-
# executemany support data
29-
self._execute_iter = None
30-
self._execute_portal_name = None
31-
self._execute_stmt_name = None
32-
3328
self._reset_result()
3429

3530
cdef _write(self, buf):
@@ -70,7 +65,8 @@ cdef class CoreProtocol:
7065
elif state == PROTOCOL_BIND_EXECUTE:
7166
self._process__bind_execute(mtype)
7267

73-
elif state == PROTOCOL_BIND_EXECUTE_MANY:
68+
elif (state == PROTOCOL_BIND_EXECUTE_MANY or
69+
state == PROTOCOL_BIND_EXECUTE_MANY_ROLLBACK):
7470
self._process__bind_execute_many(mtype)
7571

7672
elif state == PROTOCOL_EXECUTE:
@@ -247,7 +243,9 @@ cdef class CoreProtocol:
247243

248244
elif mtype == b'E':
249245
# ErrorResponse
250-
self._parse_msg_error_response(True)
246+
# Ignoring expected rollback error
247+
self._parse_msg_error_response(
248+
self.state != PROTOCOL_BIND_EXECUTE_MANY_ROLLBACK)
251249

252250
elif mtype == b'2':
253251
# BindComplete
@@ -256,22 +254,7 @@ cdef class CoreProtocol:
256254
elif mtype == b'Z':
257255
# ReadyForQuery
258256
self._parse_msg_ready_for_query()
259-
if self.result_type == RESULT_FAILED:
260-
self._push_result()
261-
else:
262-
try:
263-
buf = <WriteBuffer>next(self._execute_iter)
264-
except StopIteration:
265-
self._push_result()
266-
except Exception as e:
267-
self.result_type = RESULT_FAILED
268-
self.result = e
269-
self._push_result()
270-
else:
271-
# Next iteration over the executemany() arg sequence
272-
self._send_bind_message(
273-
self._execute_portal_name, self._execute_stmt_name,
274-
buf, 0)
257+
self._push_result()
275258

276259
elif mtype == b'I':
277260
# EmptyQueryResponse
@@ -670,6 +653,10 @@ cdef class CoreProtocol:
670653
new_state == PROTOCOL_COPY_IN_DATA):
671654
self.state = new_state
672655

656+
elif (self.state == PROTOCOL_BIND_EXECUTE_MANY and
657+
new_state == PROTOCOL_BIND_EXECUTE_MANY_ROLLBACK):
658+
self.state = new_state
659+
673660
elif self.state == PROTOCOL_FAILED:
674661
raise apg_exc.InternalClientError(
675662
'cannot switch to state {}; '
@@ -811,30 +798,55 @@ cdef class CoreProtocol:
811798

812799
self._send_bind_message(portal_name, stmt_name, bind_data, limit)
813800

814-
cdef _bind_execute_many(self, str portal_name, str stmt_name,
815-
object bind_data):
816-
817-
cdef WriteBuffer buf
818-
801+
cdef _execute_many_init(self):
819802
self._ensure_connected()
820803
self._set_state(PROTOCOL_BIND_EXECUTE_MANY)
821804

822805
self.result = None
823806
self._discard_data = True
824-
self._execute_iter = bind_data
825-
self._execute_portal_name = portal_name
826-
self._execute_stmt_name = stmt_name
827807

828-
try:
829-
buf = <WriteBuffer>next(bind_data)
830-
except StopIteration:
808+
cdef _execute_many_send(self, object buffers):
809+
self._ensure_connected()
810+
self.transport.writelines(buffers)
811+
812+
cdef _execute_many_rollback(self, object error):
813+
cdef:
814+
WriteBuffer packet
815+
WriteBuffer buf
816+
817+
self._ensure_connected()
818+
if self.state != PROTOCOL_CANCELLED:
819+
self._set_state(PROTOCOL_BIND_EXECUTE_MANY_ROLLBACK)
820+
packet = WriteBuffer.new()
821+
822+
# We have no idea if we are in an explicit transaction or not,
823+
# therefore we raise an exception in the database to mark current
824+
# transaction as failed anyway to prevent partial commit
825+
buf = self._build_parse_message('',
826+
'DO language plpgsql $$ BEGIN '
827+
'RAISE transaction_rollback; '
828+
'END$$;')
829+
packet.write_buffer(buf)
830+
buf = self._build_bind_message('', '', WriteBuffer.new())
831+
packet.write_buffer(buf)
832+
buf = self._build_execute_message('', 0)
833+
packet.write_buffer(buf)
834+
packet.write_bytes(SYNC_MESSAGE)
835+
836+
self.transport.write(memoryview(packet))
837+
838+
self.result_type = RESULT_FAILED
839+
self.result = error
840+
841+
cdef _execute_many_end(self, object sync):
842+
if sync is True:
843+
self._write_sync_message()
844+
elif sync is False:
831845
self._push_result()
832-
except Exception as e:
846+
else:
833847
self.result_type = RESULT_FAILED
834-
self.result = e
848+
self.result = sync
835849
self._push_result()
836-
else:
837-
self._send_bind_message(portal_name, stmt_name, buf, 0)
838850

839851
cdef _execute(self, str portal_name, int32_t limit):
840852
cdef WriteBuffer buf

asyncpg/protocol/protocol.pyx

Lines changed: 121 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,9 @@ cdef class BaseProtocol(CoreProtocol):
208208
@cython.iterable_coroutine
209209
async def bind_execute_many(self, PreparedStatementState state, args,
210210
str portal_name, timeout):
211+
cdef:
212+
WriteBuffer buf
213+
WriteBuffer packet
211214

212215
if self.cancel_waiter is not None:
213216
await self.cancel_waiter
@@ -217,29 +220,128 @@ cdef class BaseProtocol(CoreProtocol):
217220

218221
self._check_state()
219222
timeout = self._get_timeout_impl(timeout)
223+
timer = Timer(timeout)
220224

221-
# Make sure the argument sequence is encoded lazily with
222-
# this generator expression to keep the memory pressure under
223-
# control.
224-
data_gen = (state._encode_bind_msg(b) for b in args)
225-
arg_bufs = iter(data_gen)
225+
try:
226+
aiter = args.__aiter__
227+
except AttributeError:
228+
iterator = args.__iter__()
229+
async_iterator = False
230+
else:
231+
iterator = aiter()
232+
async_iterator = True
233+
234+
waiter = self._new_waiter(timer.get_remaining_budget())
235+
236+
self._execute_many_init()
237+
self.last_query = state.query
238+
self.statement = state
239+
self.return_extra = False
240+
self.queries_count += 1
241+
242+
more = True # for marking end of input
243+
should_commit = False # for marking empty input
244+
rollback_on_error = False # for input fails before send to DB
226245

227-
waiter = self._new_waiter(timeout)
228246
try:
229-
self._bind_execute_many(
230-
portal_name,
231-
state.name,
232-
arg_bufs) # network op
247+
# Send all input data until end of iterable reached
248+
while more:
233249

234-
self.last_query = state.query
235-
self.statement = state
236-
self.return_extra = False
237-
self.queries_count += 1
238-
except Exception as ex:
239-
waiter.set_exception(ex)
240-
self._coreproto_error()
241-
finally:
242-
return await waiter
250+
# Fill up 4 buffers to take advantage of uvloop `writelines()`
251+
buffers = []
252+
while more and len(buffers) < _EXECUTE_MANY_BUF_NUM:
253+
254+
# A new buffer
255+
packet = WriteBuffer.new()
256+
257+
try:
258+
# Fill up this buffer up to 32KB or a bit more
259+
while more and packet.len() < _EXECUTE_MANY_BUF_SIZE:
260+
261+
# Grab the next item from input
262+
if async_iterator:
263+
with timer:
264+
args_item = await asyncio.wait_for(
265+
iterator.__anext__(),
266+
timeout=timer.get_remaining_budget(),
267+
loop=self.loop)
268+
269+
# If server issues an error in parallel, we
270+
# abort with SYNC now
271+
if self.result_type == RESULT_FAILED:
272+
packet.write_bytes(SYNC_MESSAGE)
273+
more = False
274+
break
275+
else:
276+
args_item = iterator.__next__()
277+
278+
# Encode the arguments
279+
args_buf = state._encode_bind_msg(args_item)
280+
281+
# And write into the buffer on protocol
282+
buf = self._build_bind_message(
283+
portal_name, state.name, args_buf)
284+
packet.write_buffer(buf)
285+
buf = self._build_execute_message(portal_name, 0)
286+
packet.write_buffer(buf)
287+
288+
# Mark that we should send SYNC in the end
289+
should_commit = True
290+
291+
except (builtins.StopIteration,
292+
builtins.StopAsyncIteration):
293+
294+
if should_commit:
295+
# End of input reached, append SYNC in current
296+
# buffer and mark the end to escape loops
297+
packet.write_bytes(SYNC_MESSAGE)
298+
more = False
299+
300+
else:
301+
# Break loops and handle later to unify results
302+
raise
303+
304+
# One buffer ready, 3/2/1/no more to go
305+
buffers.append(packet)
306+
307+
# All 4 buffers are ready now, send them once wire has capacity
308+
with timer:
309+
await self.writing_allowed.wait()
310+
311+
if self.result_type == RESULT_FAILED:
312+
# In any case if server returns an error, we should stop
313+
# sending more data, and abort with a SYNC message
314+
self._execute_many_end(True) # network op
315+
more = False
316+
317+
else:
318+
# This uses `loop.writelines()` for its possible efficiency
319+
self._execute_many_send(buffers) # network op
320+
321+
# If input fails, we need to rollback from now on
322+
rollback_on_error = True
323+
324+
except (builtins.StopIteration, builtins.StopAsyncIteration):
325+
# Input is empty, just reset the state
326+
self._execute_many_end(False)
327+
328+
except Exception as e:
329+
# It fails for:
330+
# 1. input iteration failure;
331+
# 2. timeout;
332+
# 3. other protocol code error.
333+
334+
if rollback_on_error:
335+
# We sent some data previously without SYNC, for atomicity we
336+
# should rollback, wait for ReadyForQuery then re-raise
337+
self._execute_many_rollback(e) # network op
338+
339+
else:
340+
# If we never sent anything to DB, it is safe to reset the
341+
# state and re-raise
342+
self._execute_many_end(e)
343+
344+
return await waiter
243345

244346
@cython.iterable_coroutine
245347
async def bind(self, PreparedStatementState state, args,

0 commit comments

Comments
 (0)