diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 2a4eba86..1a4753d2 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -260,6 +260,15 @@ async def executemany(self, command: str, args, *, timeout: float=None): .. versionchanged:: 0.11.0 `timeout` became a keyword-only parameter. + + .. versionchanged:: 0.16.0 + The execution was changed to be in a implicit transaction if there + was no explicit transaction, so that it will no longer end up with + partial success. It also combined all args into one network packet + to reduce round-trip time, therefore you should make sure not to + blow up your memory with a super long iterable. If you still need + the previous behavior to progressively execute many args, please use + prepared statement instead. """ self._check_open() return await self._executemany(command, args, timeout) diff --git a/asyncpg/prepared_stmt.py b/asyncpg/prepared_stmt.py index 09a0a2ec..7a5ff8d7 100644 --- a/asyncpg/prepared_stmt.py +++ b/asyncpg/prepared_stmt.py @@ -196,6 +196,28 @@ async def fetchrow(self, *args, timeout=None): return None return data[0] + @connresource.guarded + async def executemany(self, args, *, timeout: float=None): + """Execute the statement for each sequence of arguments in *args*. + + This combines all args into one network packet, thus reduces round-trip + time than executing one by one. + + :param args: An iterable containing sequences of arguments. + :param float timeout: Optional timeout value in seconds. + :return None: This method discards the results of the operations. + + .. versionadded:: 0.16.0 + """ + protocol = self._connection._protocol + try: + return await protocol.bind_execute_many( + self._state, args, '', timeout) + except exceptions.OutdatedSchemaCacheError: + await self._connection.reload_schema_state() + self._state.mark_closed() + raise + async def __bind_execute(self, args, limit, timeout): protocol = self._connection._protocol try: diff --git a/asyncpg/protocol/coreproto.pxd b/asyncpg/protocol/coreproto.pxd index 60efa591..b9a88179 100644 --- a/asyncpg/protocol/coreproto.pxd +++ b/asyncpg/protocol/coreproto.pxd @@ -75,11 +75,6 @@ cdef class CoreProtocol: bint _skip_discard bint _discard_data - # executemany support data - object _execute_iter - str _execute_portal_name - str _execute_stmt_name - ConnectionStatus con_status ProtocolState state TransactionStatus xact_status diff --git a/asyncpg/protocol/coreproto.pyx b/asyncpg/protocol/coreproto.pyx index acfec953..f673711a 100644 --- a/asyncpg/protocol/coreproto.pyx +++ b/asyncpg/protocol/coreproto.pyx @@ -25,11 +25,6 @@ cdef class CoreProtocol: self._skip_discard = False - # executemany support data - self._execute_iter = None - self._execute_portal_name = None - self._execute_stmt_name = None - self._reset_result() cdef _write(self, buf): @@ -256,22 +251,7 @@ cdef class CoreProtocol: elif mtype == b'Z': # ReadyForQuery self._parse_msg_ready_for_query() - if self.result_type == RESULT_FAILED: - self._push_result() - else: - try: - buf = next(self._execute_iter) - except StopIteration: - self._push_result() - except Exception as e: - self.result_type = RESULT_FAILED - self.result = e - self._push_result() - else: - # Next iteration over the executemany() arg sequence - self._send_bind_message( - self._execute_portal_name, self._execute_stmt_name, - buf, 0) + self._push_result() elif mtype == b'I': # EmptyQueryResponse @@ -799,27 +779,42 @@ cdef class CoreProtocol: cdef _bind_execute_many(self, str portal_name, str stmt_name, object bind_data): - cdef WriteBuffer buf + cdef: + WriteBuffer packet + WriteBuffer buf self._ensure_connected() self._set_state(PROTOCOL_BIND_EXECUTE_MANY) + packet = WriteBuffer.new() + self.result = None self._discard_data = True - self._execute_iter = bind_data - self._execute_portal_name = portal_name - self._execute_stmt_name = stmt_name - try: - buf = next(bind_data) - except StopIteration: - self._push_result() - except Exception as e: - self.result_type = RESULT_FAILED - self.result = e - self._push_result() - else: - self._send_bind_message(portal_name, stmt_name, buf, 0) + while True: + try: + buf = next(bind_data) + except StopIteration: + if packet.len() > 0: + packet.write_bytes(SYNC_MESSAGE) + self.transport.write(memoryview(packet)) + else: + self._push_result() + break + except Exception as e: + self.result_type = RESULT_FAILED + self.result = e + self._push_result() + break + else: + buf = self._build_bind_message(portal_name, stmt_name, buf) + packet.write_buffer(buf) + + buf = WriteBuffer.new_message(b'E') + buf.write_str(portal_name, self.encoding) # name of the portal + buf.write_int32(0) # number of rows to return; 0 - all + buf.end_message() + packet.write_buffer(buf) cdef _execute(self, str portal_name, int32_t limit): cdef WriteBuffer buf diff --git a/tests/test_execute.py b/tests/test_execute.py index ccde0993..ee02486b 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -151,3 +151,25 @@ async def test_execute_many_2(self): ''', good_data) finally: await self.con.execute('DROP TABLE exmany') + + async def test_execute_many_atomic(self): + from asyncpg.exceptions import UniqueViolationError + + await self.con.execute('CREATE TEMP TABLE exmany ' + '(a text, b int PRIMARY KEY)') + + try: + with self.assertRaises(UniqueViolationError): + await self.con.executemany(''' + INSERT INTO exmany VALUES($1, $2) + ''', [ + ('a', 1), ('b', 2), ('c', 2), ('d', 4) + ]) + + result = await self.con.fetch(''' + SELECT * FROM exmany + ''') + + self.assertEqual(result, []) + finally: + await self.con.execute('DROP TABLE exmany') diff --git a/tests/test_prepare.py b/tests/test_prepare.py index 8fc06e3e..d1753697 100644 --- a/tests/test_prepare.py +++ b/tests/test_prepare.py @@ -600,3 +600,38 @@ async def test_prepare_does_not_use_cache(self): # prepare with disabled cache await self.con.prepare('select 1') self.assertEqual(len(cache), 0) + + async def test_prepare_executemany(self): + await self.con.execute('CREATE TEMP TABLE exmany (a text, b int)') + + try: + stmt = await self.con.prepare(''' + INSERT INTO exmany VALUES($1, $2) + ''') + + result = await stmt.executemany([ + ('a', 1), ('b', 2), ('c', 3), ('d', 4) + ]) + + self.assertIsNone(result) + + result = await self.con.fetch(''' + SELECT * FROM exmany + ''') + + self.assertEqual(result, [ + ('a', 1), ('b', 2), ('c', 3), ('d', 4) + ]) + + # Empty set + await stmt.executemany(()) + + result = await self.con.fetch(''' + SELECT * FROM exmany + ''') + + self.assertEqual(result, [ + ('a', 1), ('b', 2), ('c', 3), ('d', 4) + ]) + finally: + await self.con.execute('DROP TABLE exmany')