Skip to content

Implement the Connection.executemany() method #45

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 16, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,27 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
True, timeout)
return status.decode()

async def executemany(self, command: str, args, timeout: float=None):
"""Execute an SQL *command* for each sequence of arguments in *args*.

Example:

.. code-block:: pycon

>>> await con.executemany('''
... INSERT INTO mytab (a) VALUES ($1, $2, $3);
... ''', [(1, 2, 3), (4, 5, 6)])

:param command: Command to execute.
:args: An iterable containing sequences of arguments.
:param float timeout: Optional timeout value in seconds.
:return None: This method discards the results of the operations.

.. versionadded:: 0.7.0
"""
stmt = await self._get_statement(command, timeout)
return await self._protocol.bind_execute_many(stmt, args, '', timeout)

async def _get_statement(self, query, timeout):
cache = self._stmt_cache_max_size > 0

Expand Down
20 changes: 16 additions & 4 deletions asyncpg/protocol/coreproto.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ cdef enum ProtocolState:
PROTOCOL_AUTH = 10
PROTOCOL_PREPARE = 11
PROTOCOL_BIND_EXECUTE = 12
PROTOCOL_CLOSE_STMT_PORTAL = 13
PROTOCOL_SIMPLE_QUERY = 14
PROTOCOL_EXECUTE = 15
PROTOCOL_BIND = 16
PROTOCOL_BIND_EXECUTE_MANY = 13
PROTOCOL_CLOSE_STMT_PORTAL = 14
PROTOCOL_SIMPLE_QUERY = 15
PROTOCOL_EXECUTE = 16
PROTOCOL_BIND = 17


cdef enum AuthenticationMessage:
Expand Down Expand Up @@ -67,6 +68,12 @@ cdef class CoreProtocol:
cdef:
ReadBuffer buffer
bint _skip_discard
bint _discard_data

# executemany support data
object _execute_iter
str _execute_portal_name
str _execute_stmt_name

ConnectionStatus con_status
ProtocolState state
Expand Down Expand Up @@ -95,6 +102,7 @@ cdef class CoreProtocol:
cdef _process__auth(self, char mtype)
cdef _process__prepare(self, char mtype)
cdef _process__bind_execute(self, char mtype)
cdef _process__bind_execute_many(self, char mtype)
cdef _process__close_stmt_portal(self, char mtype)
cdef _process__simple_query(self, char mtype)
cdef _process__bind(self, char mtype)
Expand Down Expand Up @@ -129,8 +137,12 @@ cdef class CoreProtocol:

cdef _connect(self)
cdef _prepare(self, str stmt_name, str query)
cdef _send_bind_message(self, str portal_name, str stmt_name,
WriteBuffer bind_data, int32_t limit)
cdef _bind_execute(self, str portal_name, str stmt_name,
WriteBuffer bind_data, int32_t limit)
cdef _bind_execute_many(self, str portal_name, str stmt_name,
object bind_data)
cdef _bind(self, str portal_name, str stmt_name,
WriteBuffer bind_data)
cdef _execute(self, str portal_name, int32_t limit)
Expand Down
102 changes: 95 additions & 7 deletions asyncpg/protocol/coreproto.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ cdef class CoreProtocol:

self._skip_discard = False

# executemany support data
self._execute_iter = None
self._execute_portal_name = None
self._execute_stmt_name = None

self._reset_result()

cdef _write(self, buf):
Expand Down Expand Up @@ -60,6 +65,9 @@ cdef class CoreProtocol:
elif state == PROTOCOL_BIND_EXECUTE:
self._process__bind_execute(mtype)

elif state == PROTOCOL_BIND_EXECUTE_MANY:
self._process__bind_execute_many(mtype)

elif state == PROTOCOL_EXECUTE:
self._process__bind_execute(mtype)

Expand Down Expand Up @@ -194,6 +202,49 @@ cdef class CoreProtocol:
# EmptyQueryResponse
self.buffer.consume_message()

cdef _process__bind_execute_many(self, char mtype):
cdef WriteBuffer buf

if mtype == b'D':
# DataRow
self._parse_data_msgs()

elif mtype == b's':
# PortalSuspended
self.buffer.consume_message()

elif mtype == b'C':
# CommandComplete
self._parse_msg_command_complete()

elif mtype == b'E':
# ErrorResponse
self._parse_msg_error_response(True)

elif mtype == b'2':
# BindComplete
self.buffer.consume_message()

elif mtype == b'Z':
# ReadyForQuery
self._parse_msg_ready_for_query()
if self.result_type == RESULT_FAILED:
self._push_result()
else:
try:
buf = <WriteBuffer>next(self._execute_iter)
except StopIteration:
self._push_result()
else:
# Next iteration over the executemany() arg sequence
self._send_bind_message(
self._execute_portal_name, self._execute_stmt_name,
buf, 0)

elif mtype == b'I':
# EmptyQueryResponse
self.buffer.consume_message()

cdef _process__bind(self, char mtype):
if mtype == b'E':
# ErrorResponse
Expand Down Expand Up @@ -275,6 +326,14 @@ cdef class CoreProtocol:
raise RuntimeError(
'_parse_data_msgs: first message is not "D"')

if self._discard_data:
while True:
buf.consume_message()
if not buf.has_message() or buf.get_message_type() != b'D':
self._skip_discard = True
return

if ASYNCPG_DEBUG:
if type(self.result) is not list:
raise RuntimeError(
'_parse_data_msgs: result is not a list, but {!r}'.
Expand Down Expand Up @@ -424,6 +483,7 @@ cdef class CoreProtocol:
self.result_row_desc = None
self.result_status_msg = None
self.result_execute_completed = False
self._discard_data = False

cdef _set_state(self, ProtocolState new_state):
if new_state == PROTOCOL_IDLE:
Expand Down Expand Up @@ -537,16 +597,11 @@ cdef class CoreProtocol:

self.transport.write(memoryview(packet))

cdef _bind_execute(self, str portal_name, str stmt_name,
WriteBuffer bind_data, int32_t limit):
cdef _send_bind_message(self, str portal_name, str stmt_name,
WriteBuffer bind_data, int32_t limit):

cdef WriteBuffer buf

self._ensure_connected()
self._set_state(PROTOCOL_BIND_EXECUTE)

self.result = []

buf = self._build_bind_message(portal_name, stmt_name, bind_data)
self._write(buf)

Expand All @@ -558,6 +613,39 @@ cdef class CoreProtocol:

self._write_sync_message()

cdef _bind_execute(self, str portal_name, str stmt_name,
WriteBuffer bind_data, int32_t limit):

cdef WriteBuffer buf

self._ensure_connected()
self._set_state(PROTOCOL_BIND_EXECUTE)

self.result = []

self._send_bind_message(portal_name, stmt_name, bind_data, limit)

cdef _bind_execute_many(self, str portal_name, str stmt_name,
object bind_data):

cdef WriteBuffer buf

self._ensure_connected()
self._set_state(PROTOCOL_BIND_EXECUTE_MANY)

self.result = None
self._discard_data = True
self._execute_iter = bind_data
self._execute_portal_name = portal_name
self._execute_stmt_name = stmt_name

try:
buf = <WriteBuffer>next(bind_data)
except StopIteration:
self._push_result()
else:
self._send_bind_message(portal_name, stmt_name, buf, 0)

cdef _execute(self, str portal_name, int32_t limit):
cdef WriteBuffer buf

Expand Down
34 changes: 34 additions & 0 deletions asyncpg/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,37 @@ cdef class BaseProtocol(CoreProtocol):

return await self._new_waiter(timeout)

async def bind_execute_many(self, PreparedStatementState state, args,
str portal_name, timeout):

if self.cancel_waiter is not None:
await self.cancel_waiter
if self.cancel_sent_waiter is not None:
await self.cancel_sent_waiter
self.cancel_sent_waiter = None

self._ensure_clear_state()

# Make sure the argument sequence is encoded lazily with
# this generator expression to keep the memory pressure under
# control.
data_gen = (state._encode_bind_msg(b) for b in args)
arg_bufs = iter(data_gen)

waiter = self._new_waiter(timeout)

self._bind_execute_many(
portal_name,
state.name,
arg_bufs)

self.last_query = state.query
self.statement = state
self.return_extra = False
self.queries_count += 1

return await waiter

async def bind(self, PreparedStatementState state, args,
str portal_name, timeout):

Expand Down Expand Up @@ -405,6 +436,9 @@ cdef class BaseProtocol(CoreProtocol):
elif self.state == PROTOCOL_BIND_EXECUTE:
self._on_result__bind_and_exec(waiter)

elif self.state == PROTOCOL_BIND_EXECUTE_MANY:
self._on_result__bind_and_exec(waiter)

elif self.state == PROTOCOL_EXECUTE:
self._on_result__bind_and_exec(waiter)

Expand Down
35 changes: 35 additions & 0 deletions tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,38 @@ async def test_execute_script_interrupted_terminate(self):
await fut

self.con.terminate()

async def test_execute_many_1(self):
await self.con.execute('CREATE TEMP TABLE exmany (a text, b int)')

try:
result = await self.con.executemany('''
INSERT INTO exmany VALUES($1, $2)
''', [
('a', 1), ('b', 2), ('c', 3), ('d', 4)
])

self.assertIsNone(result)

result = await self.con.fetch('''
SELECT * FROM exmany
''')

self.assertEqual(result, [
('a', 1), ('b', 2), ('c', 3), ('d', 4)
])

# Empty set
result = await self.con.executemany('''
INSERT INTO exmany VALUES($1, $2)
''', ())

result = await self.con.fetch('''
SELECT * FROM exmany
''')

self.assertEqual(result, [
('a', 1), ('b', 2), ('c', 3), ('d', 4)
])
finally:
await self.con.execute('DROP TABLE exmany')