diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 8910b43b..604c4a0f 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -148,6 +148,27 @@ async def execute(self, query: str, *args, timeout: float=None) -> str: True, timeout) return status.decode() + async def executemany(self, command: str, args, timeout: float=None): + """Execute an SQL *command* for each sequence of arguments in *args*. + + Example: + + .. code-block:: pycon + + >>> await con.executemany(''' + ... INSERT INTO mytab (a) VALUES ($1, $2, $3); + ... ''', [(1, 2, 3), (4, 5, 6)]) + + :param command: Command to execute. + :args: An iterable containing sequences of arguments. + :param float timeout: Optional timeout value in seconds. + :return None: This method discards the results of the operations. + + .. versionadded:: 0.7.0 + """ + stmt = await self._get_statement(command, timeout) + return await self._protocol.bind_execute_many(stmt, args, '', timeout) + async def _get_statement(self, query, timeout): cache = self._stmt_cache_max_size > 0 diff --git a/asyncpg/protocol/coreproto.pxd b/asyncpg/protocol/coreproto.pxd index 8496e014..b6845a23 100644 --- a/asyncpg/protocol/coreproto.pxd +++ b/asyncpg/protocol/coreproto.pxd @@ -21,10 +21,11 @@ cdef enum ProtocolState: PROTOCOL_AUTH = 10 PROTOCOL_PREPARE = 11 PROTOCOL_BIND_EXECUTE = 12 - PROTOCOL_CLOSE_STMT_PORTAL = 13 - PROTOCOL_SIMPLE_QUERY = 14 - PROTOCOL_EXECUTE = 15 - PROTOCOL_BIND = 16 + PROTOCOL_BIND_EXECUTE_MANY = 13 + PROTOCOL_CLOSE_STMT_PORTAL = 14 + PROTOCOL_SIMPLE_QUERY = 15 + PROTOCOL_EXECUTE = 16 + PROTOCOL_BIND = 17 cdef enum AuthenticationMessage: @@ -67,6 +68,12 @@ cdef class CoreProtocol: cdef: ReadBuffer buffer bint _skip_discard + bint _discard_data + + # executemany support data + object _execute_iter + str _execute_portal_name + str _execute_stmt_name ConnectionStatus con_status ProtocolState state @@ -95,6 +102,7 @@ cdef class CoreProtocol: cdef _process__auth(self, char mtype) cdef _process__prepare(self, char mtype) cdef _process__bind_execute(self, char mtype) + cdef _process__bind_execute_many(self, char mtype) cdef _process__close_stmt_portal(self, char mtype) cdef _process__simple_query(self, char mtype) cdef _process__bind(self, char mtype) @@ -129,8 +137,12 @@ cdef class CoreProtocol: cdef _connect(self) cdef _prepare(self, str stmt_name, str query) + cdef _send_bind_message(self, str portal_name, str stmt_name, + WriteBuffer bind_data, int32_t limit) cdef _bind_execute(self, str portal_name, str stmt_name, WriteBuffer bind_data, int32_t limit) + cdef _bind_execute_many(self, str portal_name, str stmt_name, + object bind_data) cdef _bind(self, str portal_name, str stmt_name, WriteBuffer bind_data) cdef _execute(self, str portal_name, int32_t limit) diff --git a/asyncpg/protocol/coreproto.pyx b/asyncpg/protocol/coreproto.pyx index ccdc2c52..4ea36b23 100644 --- a/asyncpg/protocol/coreproto.pyx +++ b/asyncpg/protocol/coreproto.pyx @@ -24,6 +24,11 @@ cdef class CoreProtocol: self._skip_discard = False + # executemany support data + self._execute_iter = None + self._execute_portal_name = None + self._execute_stmt_name = None + self._reset_result() cdef _write(self, buf): @@ -60,6 +65,9 @@ cdef class CoreProtocol: elif state == PROTOCOL_BIND_EXECUTE: self._process__bind_execute(mtype) + elif state == PROTOCOL_BIND_EXECUTE_MANY: + self._process__bind_execute_many(mtype) + elif state == PROTOCOL_EXECUTE: self._process__bind_execute(mtype) @@ -194,6 +202,49 @@ cdef class CoreProtocol: # EmptyQueryResponse self.buffer.consume_message() + cdef _process__bind_execute_many(self, char mtype): + cdef WriteBuffer buf + + if mtype == b'D': + # DataRow + self._parse_data_msgs() + + elif mtype == b's': + # PortalSuspended + self.buffer.consume_message() + + elif mtype == b'C': + # CommandComplete + self._parse_msg_command_complete() + + elif mtype == b'E': + # ErrorResponse + self._parse_msg_error_response(True) + + elif mtype == b'2': + # BindComplete + self.buffer.consume_message() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + if self.result_type == RESULT_FAILED: + self._push_result() + else: + try: + buf = next(self._execute_iter) + except StopIteration: + self._push_result() + else: + # Next iteration over the executemany() arg sequence + self._send_bind_message( + self._execute_portal_name, self._execute_stmt_name, + buf, 0) + + elif mtype == b'I': + # EmptyQueryResponse + self.buffer.consume_message() + cdef _process__bind(self, char mtype): if mtype == b'E': # ErrorResponse @@ -275,6 +326,14 @@ cdef class CoreProtocol: raise RuntimeError( '_parse_data_msgs: first message is not "D"') + if self._discard_data: + while True: + buf.consume_message() + if not buf.has_message() or buf.get_message_type() != b'D': + self._skip_discard = True + return + + if ASYNCPG_DEBUG: if type(self.result) is not list: raise RuntimeError( '_parse_data_msgs: result is not a list, but {!r}'. @@ -424,6 +483,7 @@ cdef class CoreProtocol: self.result_row_desc = None self.result_status_msg = None self.result_execute_completed = False + self._discard_data = False cdef _set_state(self, ProtocolState new_state): if new_state == PROTOCOL_IDLE: @@ -537,16 +597,11 @@ cdef class CoreProtocol: self.transport.write(memoryview(packet)) - cdef _bind_execute(self, str portal_name, str stmt_name, - WriteBuffer bind_data, int32_t limit): + cdef _send_bind_message(self, str portal_name, str stmt_name, + WriteBuffer bind_data, int32_t limit): cdef WriteBuffer buf - self._ensure_connected() - self._set_state(PROTOCOL_BIND_EXECUTE) - - self.result = [] - buf = self._build_bind_message(portal_name, stmt_name, bind_data) self._write(buf) @@ -558,6 +613,39 @@ cdef class CoreProtocol: self._write_sync_message() + cdef _bind_execute(self, str portal_name, str stmt_name, + WriteBuffer bind_data, int32_t limit): + + cdef WriteBuffer buf + + self._ensure_connected() + self._set_state(PROTOCOL_BIND_EXECUTE) + + self.result = [] + + self._send_bind_message(portal_name, stmt_name, bind_data, limit) + + cdef _bind_execute_many(self, str portal_name, str stmt_name, + object bind_data): + + cdef WriteBuffer buf + + self._ensure_connected() + self._set_state(PROTOCOL_BIND_EXECUTE_MANY) + + self.result = None + self._discard_data = True + self._execute_iter = bind_data + self._execute_portal_name = portal_name + self._execute_stmt_name = stmt_name + + try: + buf = next(bind_data) + except StopIteration: + self._push_result() + else: + self._send_bind_message(portal_name, stmt_name, buf, 0) + cdef _execute(self, str portal_name, int32_t limit): cdef WriteBuffer buf diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index adebd891..9ae9bc23 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -162,6 +162,37 @@ cdef class BaseProtocol(CoreProtocol): return await self._new_waiter(timeout) + async def bind_execute_many(self, PreparedStatementState state, args, + str portal_name, timeout): + + if self.cancel_waiter is not None: + await self.cancel_waiter + if self.cancel_sent_waiter is not None: + await self.cancel_sent_waiter + self.cancel_sent_waiter = None + + self._ensure_clear_state() + + # Make sure the argument sequence is encoded lazily with + # this generator expression to keep the memory pressure under + # control. + data_gen = (state._encode_bind_msg(b) for b in args) + arg_bufs = iter(data_gen) + + waiter = self._new_waiter(timeout) + + self._bind_execute_many( + portal_name, + state.name, + arg_bufs) + + self.last_query = state.query + self.statement = state + self.return_extra = False + self.queries_count += 1 + + return await waiter + async def bind(self, PreparedStatementState state, args, str portal_name, timeout): @@ -405,6 +436,9 @@ cdef class BaseProtocol(CoreProtocol): elif self.state == PROTOCOL_BIND_EXECUTE: self._on_result__bind_and_exec(waiter) + elif self.state == PROTOCOL_BIND_EXECUTE_MANY: + self._on_result__bind_and_exec(waiter) + elif self.state == PROTOCOL_EXECUTE: self._on_result__bind_and_exec(waiter) diff --git a/tests/test_execute.py b/tests/test_execute.py index 1db6773f..c4710157 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -97,3 +97,38 @@ async def test_execute_script_interrupted_terminate(self): await fut self.con.terminate() + + async def test_execute_many_1(self): + await self.con.execute('CREATE TEMP TABLE exmany (a text, b int)') + + try: + result = await self.con.executemany(''' + INSERT INTO exmany VALUES($1, $2) + ''', [ + ('a', 1), ('b', 2), ('c', 3), ('d', 4) + ]) + + self.assertIsNone(result) + + result = await self.con.fetch(''' + SELECT * FROM exmany + ''') + + self.assertEqual(result, [ + ('a', 1), ('b', 2), ('c', 3), ('d', 4) + ]) + + # Empty set + result = await self.con.executemany(''' + INSERT INTO exmany VALUES($1, $2) + ''', ()) + + result = await self.con.fetch(''' + SELECT * FROM exmany + ''') + + self.assertEqual(result, [ + ('a', 1), ('b', 2), ('c', 3), ('d', 4) + ]) + finally: + await self.con.execute('DROP TABLE exmany')