Skip to content

Commit 5f4736e

Browse files
committed
Rewrite executemany() again in a much more readable way.
`Bind` and `Execute` pairs are now batched into 4 x 32KB buffers to take advantage of `writelines()`. A single `Sync` is sent at the end, so that all args are executed in the same transaction. Closes: MagicStack#289
1 parent 43a7b21 commit 5f4736e

File tree

7 files changed

+276
-106
lines changed

7 files changed

+276
-106
lines changed

asyncpg/connection.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,13 @@ async def executemany(self, command: str, args, *, timeout: float=None):
301301
302302
.. versionchanged:: 0.11.0
303303
`timeout` became a keyword-only parameter.
304+
305+
.. versionchanged:: 0.19.0
306+
The execution was changed to be in an implicit transaction if there
307+
was no explicit transaction, so that it will no longer end up with
308+
partial success. If you still need the previous behavior to
309+
progressively execute many args, please use a loop with prepared
310+
statement instead.
304311
"""
305312
self._check_open()
306313
return await self._executemany(command, args, timeout)

asyncpg/prepared_stmt.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,11 +196,24 @@ async def fetchrow(self, *args, timeout=None):
196196
return None
197197
return data[0]
198198

199-
async def __bind_execute(self, args, limit, timeout):
199+
@connresource.guarded
200+
async def executemany(self, args, *, timeout: float=None):
201+
"""Execute the statement for each sequence of arguments in *args*.
202+
203+
:param args: An iterable containing sequences of arguments.
204+
:param float timeout: Optional timeout value in seconds.
205+
:return None: This method discards the results of the operations.
206+
207+
.. versionadded:: 0.19.0
208+
"""
209+
return await self.__do_execute(
210+
lambda protocol: protocol.bind_execute_many(
211+
self._state, args, '', timeout))
212+
213+
async def __do_execute(self, executor):
200214
protocol = self._connection._protocol
201215
try:
202-
data, status, _ = await protocol.bind_execute(
203-
self._state, args, '', limit, True, timeout)
216+
return await executor(protocol)
204217
except exceptions.OutdatedSchemaCacheError:
205218
await self._connection.reload_schema_state()
206219
# We can not find all manually created prepared statements, so just
@@ -209,6 +222,11 @@ async def __bind_execute(self, args, limit, timeout):
209222
# invalidate themselves (unfortunately, clearing caches again).
210223
self._state.mark_closed()
211224
raise
225+
226+
async def __bind_execute(self, args, limit, timeout):
227+
data, status, _ = await self.__do_execute(
228+
lambda protocol: protocol.bind_execute(
229+
self._state, args, '', limit, True, timeout))
212230
self._last_status = status
213231
return data
214232

asyncpg/protocol/consts.pxi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@
88
DEF _MAXINT32 = 2**31 - 1
99
DEF _COPY_BUFFER_SIZE = 524288
1010
DEF _COPY_SIGNATURE = b"PGCOPY\n\377\r\n\0"
11+
DEF _EXECUTE_MANY_BUF_NUM = 4
12+
DEF _EXECUTE_MANY_BUF_SIZE = 32768

asyncpg/protocol/coreproto.pxd

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,6 @@ cdef class CoreProtocol:
7676
bint _skip_discard
7777
bint _discard_data
7878

79-
# executemany support data
80-
object _execute_iter
81-
str _execute_portal_name
82-
str _execute_stmt_name
83-
8479
ConnectionStatus con_status
8580
ProtocolState state
8681
TransactionStatus xact_status
@@ -105,6 +100,7 @@ cdef class CoreProtocol:
105100
# True - completed, False - suspended
106101
bint result_execute_completed
107102

103+
cpdef is_in_transaction(self)
108104
cdef _process__auth(self, char mtype)
109105
cdef _process__prepare(self, char mtype)
110106
cdef _process__bind_execute(self, char mtype)
@@ -135,6 +131,7 @@ cdef class CoreProtocol:
135131
cdef _auth_password_message_md5(self, bytes salt)
136132

137133
cdef _write(self, buf)
134+
cdef _writelines(self, list buffers)
138135

139136
cdef _read_server_messages(self)
140137

@@ -147,6 +144,8 @@ cdef class CoreProtocol:
147144
cdef WriteBuffer _build_bind_message(self, str portal_name,
148145
str stmt_name,
149146
WriteBuffer bind_data)
147+
cdef WriteBuffer _build_execute_message(self, str portal_name,
148+
int32_t limit)
150149

151150

152151
cdef _connect(self)
@@ -155,8 +154,11 @@ cdef class CoreProtocol:
155154
WriteBuffer bind_data, int32_t limit)
156155
cdef _bind_execute(self, str portal_name, str stmt_name,
157156
WriteBuffer bind_data, int32_t limit)
158-
cdef _bind_execute_many(self, str portal_name, str stmt_name,
159-
object bind_data)
157+
cdef _execute_many_init(self)
158+
cdef _execute_many_writelines(self, str portal_name, str stmt_name,
159+
object bind_data)
160+
cdef _execute_many_done(self, bint data_sent)
161+
cdef _execute_many_fail(self, object error)
160162
cdef _bind(self, str portal_name, str stmt_name,
161163
WriteBuffer bind_data)
162164
cdef _execute(self, str portal_name, int32_t limit)

asyncpg/protocol/coreproto.pyx

Lines changed: 79 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@ cdef class CoreProtocol:
2222
self.xact_status = PQTRANS_IDLE
2323
self.encoding = 'utf-8'
2424

25-
# executemany support data
26-
self._execute_iter = None
27-
self._execute_portal_name = None
28-
self._execute_stmt_name = None
29-
3025
self._reset_result()
3126

27+
cpdef is_in_transaction(self):
28+
# PQTRANS_INTRANS = idle, within transaction block
29+
# PQTRANS_INERROR = idle, within failed transaction
30+
return self.xact_status in (PQTRANS_INTRANS, PQTRANS_INERROR)
31+
3232
cdef _read_server_messages(self):
3333
cdef:
3434
char mtype
@@ -253,22 +253,7 @@ cdef class CoreProtocol:
253253
elif mtype == b'Z':
254254
# ReadyForQuery
255255
self._parse_msg_ready_for_query()
256-
if self.result_type == RESULT_FAILED:
257-
self._push_result()
258-
else:
259-
try:
260-
buf = <WriteBuffer>next(self._execute_iter)
261-
except StopIteration:
262-
self._push_result()
263-
except Exception as e:
264-
self.result_type = RESULT_FAILED
265-
self.result = e
266-
self._push_result()
267-
else:
268-
# Next iteration over the executemany() arg sequence
269-
self._send_bind_message(
270-
self._execute_portal_name, self._execute_stmt_name,
271-
buf, 0)
256+
self._push_result()
272257

273258
elif mtype == b'I':
274259
# EmptyQueryResponse
@@ -702,6 +687,17 @@ cdef class CoreProtocol:
702687
buf.end_message()
703688
return buf
704689

690+
cdef WriteBuffer _build_execute_message(self, str portal_name,
691+
int32_t limit):
692+
cdef WriteBuffer buf
693+
694+
buf = WriteBuffer.new_message(b'E')
695+
buf.write_str(portal_name, self.encoding) # name of the portal
696+
buf.write_int32(limit) # number of rows to return; 0 - all
697+
698+
buf.end_message()
699+
return buf
700+
705701
# API for subclasses
706702

707703
cdef _connect(self):
@@ -779,10 +775,7 @@ cdef class CoreProtocol:
779775
buf = self._build_bind_message(portal_name, stmt_name, bind_data)
780776
packet = buf
781777

782-
buf = WriteBuffer.new_message(b'E')
783-
buf.write_str(portal_name, self.encoding) # name of the portal
784-
buf.write_int32(limit) # number of rows to return; 0 - all
785-
buf.end_message()
778+
buf = self._build_execute_message(portal_name, limit)
786779
packet.write_buffer(buf)
787780

788781
packet.write_bytes(SYNC_MESSAGE)
@@ -801,30 +794,71 @@ cdef class CoreProtocol:
801794

802795
self._send_bind_message(portal_name, stmt_name, bind_data, limit)
803796

804-
cdef _bind_execute_many(self, str portal_name, str stmt_name,
805-
object bind_data):
806-
807-
cdef WriteBuffer buf
808-
797+
cdef _execute_many_init(self):
809798
self._ensure_connected()
810799
self._set_state(PROTOCOL_BIND_EXECUTE_MANY)
811800

812801
self.result = None
813802
self._discard_data = True
814-
self._execute_iter = bind_data
815-
self._execute_portal_name = portal_name
816-
self._execute_stmt_name = stmt_name
817803

818-
try:
819-
buf = <WriteBuffer>next(bind_data)
820-
except StopIteration:
821-
self._push_result()
822-
except Exception as e:
823-
self.result_type = RESULT_FAILED
824-
self.result = e
804+
cdef _execute_many_writelines(self, str portal_name, str stmt_name,
805+
object bind_data):
806+
cdef:
807+
WriteBuffer packet
808+
WriteBuffer buf
809+
list buffers = []
810+
811+
if self.result_type == RESULT_FAILED:
812+
raise StopIteration(False)
813+
814+
while len(buffers) < _EXECUTE_MANY_BUF_NUM:
815+
packet = WriteBuffer.new()
816+
817+
while packet.len() < _EXECUTE_MANY_BUF_SIZE:
818+
try:
819+
buf = <WriteBuffer>next(bind_data)
820+
except StopIteration:
821+
if packet.len() > 0:
822+
buffers.append(packet)
823+
if len(buffers) > 0:
824+
self._writelines(buffers)
825+
raise StopIteration(True)
826+
else:
827+
raise StopIteration(False)
828+
except Exception as ex:
829+
raise StopIteration(ex)
830+
buf = self._build_bind_message(portal_name, stmt_name, buf)
831+
packet.write_buffer(buf)
832+
buf = self._build_execute_message(portal_name, 0)
833+
packet.write_buffer(buf)
834+
buffers.append(packet)
835+
self._writelines(buffers)
836+
837+
cdef _execute_many_done(self, bint data_sent):
838+
if data_sent:
839+
self._write(SYNC_MESSAGE)
840+
else:
825841
self._push_result()
842+
843+
cdef _execute_many_fail(self, object error):
844+
cdef WriteBuffer buf
845+
846+
self.result_type = RESULT_FAILED
847+
self.result = error
848+
849+
# We shall rollback in an implicit transaction to prevent partial
850+
# commit, while do nothing in an explicit transaction and leaving the
851+
# error to the user
852+
if self.is_in_transaction():
853+
self._execute_many_done(True)
826854
else:
827-
self._send_bind_message(portal_name, stmt_name, buf, 0)
855+
buf = WriteBuffer.new_message(b'Q')
856+
# ROLLBACK here won't cause server to send RowDescription,
857+
# CopyInResponse or CopyOutResponse which we are not expecting, but
858+
# the server will send ReadyForQuery which finishes executemany()
859+
buf.write_str('ROLLBACK', self.encoding)
860+
buf.end_message()
861+
self._write(buf)
828862

829863
cdef _execute(self, str portal_name, int32_t limit):
830864
cdef WriteBuffer buf
@@ -834,10 +868,7 @@ cdef class CoreProtocol:
834868

835869
self.result = []
836870

837-
buf = WriteBuffer.new_message(b'E')
838-
buf.write_str(portal_name, self.encoding) # name of the portal
839-
buf.write_int32(limit) # number of rows to return; 0 - all
840-
buf.end_message()
871+
buf = self._build_execute_message(portal_name, limit)
841872

842873
buf.write_bytes(SYNC_MESSAGE)
843874

@@ -920,6 +951,9 @@ cdef class CoreProtocol:
920951
cdef _write(self, buf):
921952
raise NotImplementedError
922953

954+
cdef _writelines(self, list buffers):
955+
raise NotImplementedError
956+
923957
cdef _decode_row(self, const char* buf, ssize_t buf_len):
924958
pass
925959

asyncpg/protocol/protocol.pyx

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,6 @@ cdef class BaseProtocol(CoreProtocol):
122122
def get_settings(self):
123123
return self.settings
124124

125-
def is_in_transaction(self):
126-
# PQTRANS_INTRANS = idle, within transaction block
127-
# PQTRANS_INERROR = idle, within failed transaction
128-
return self.xact_status in (PQTRANS_INTRANS, PQTRANS_INERROR)
129-
130125
cdef inline resume_reading(self):
131126
if not self.is_reading:
132127
self.is_reading = True
@@ -216,15 +211,27 @@ cdef class BaseProtocol(CoreProtocol):
216211

217212
waiter = self._new_waiter(timeout)
218213
try:
219-
self._bind_execute_many(
220-
portal_name,
221-
state.name,
222-
arg_bufs) # network op
223-
214+
self._execute_many_init()
224215
self.last_query = state.query
225216
self.statement = state
226217
self.return_extra = False
227218
self.queries_count += 1
219+
220+
data_sent = False
221+
while True:
222+
self._execute_many_writelines(
223+
portal_name,
224+
state.name,
225+
arg_bufs) # network op
226+
data_sent = True
227+
await self.writing_allowed.wait()
228+
except StopIteration as ex:
229+
if ex.value is True:
230+
self._execute_many_done(True) # network op
231+
elif ex.value is False:
232+
self._execute_many_done(data_sent) # network op
233+
else:
234+
self._execute_many_fail(ex.value) # network op
228235
except Exception as ex:
229236
waiter.set_exception(ex)
230237
self._coreproto_error()
@@ -880,6 +887,9 @@ cdef class BaseProtocol(CoreProtocol):
880887
cdef _write(self, buf):
881888
self.transport.write(memoryview(buf))
882889

890+
cdef _writelines(self, list buffers):
891+
self.transport.writelines(buffers)
892+
883893
# asyncio callbacks:
884894

885895
def data_received(self, data):

0 commit comments

Comments
 (0)