Skip to content

Commit a0376e1

Browse files
committed
Allow executemany to return rows
1 parent c2c8d20 commit a0376e1

File tree

7 files changed

+92
-15
lines changed

7 files changed

+92
-15
lines changed

asyncpg/connection.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,8 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
359359
)
360360
return status.decode()
361361

362-
async def executemany(self, command: str, args, *, timeout: float=None):
362+
async def executemany(self, command: str, args, *, timeout: float=None,
363+
return_rows: bool=False):
363364
"""Execute an SQL *command* for each sequence of arguments in *args*.
364365
365366
Example:
@@ -373,7 +374,13 @@ async def executemany(self, command: str, args, *, timeout: float=None):
373374
:param command: Command to execute.
374375
:param args: An iterable containing sequences of arguments.
375376
:param float timeout: Optional timeout value in seconds.
376-
:return None: This method discards the results of the operations.
377+
:param bool return_rows:
378+
If ``True``, the resulting rows of each command will be
379+
returned as a list of :class:`~asyncpg.Record`
380+
(defaults to ``False``).
381+
:return:
382+
None, or a list of :class:`~asyncpg.Record` instances
383+
if `return_rows` is true.
377384
378385
.. versionadded:: 0.7.0
379386
@@ -386,9 +393,13 @@ async def executemany(self, command: str, args, *, timeout: float=None):
386393
to prior versions, where the effect of already-processed iterations
387394
would remain in place when an error has occurred, unless
388395
``executemany()`` was called in a transaction.
396+
397+
.. versionchanged:: 0.30.0
398+
Added `return_rows` keyword-only parameter.
389399
"""
390400
self._check_open()
391-
return await self._executemany(command, args, timeout)
401+
return await self._executemany(
402+
command, args, timeout, return_rows=return_rows)
392403

393404
async def _get_statement(
394405
self,
@@ -1898,12 +1909,13 @@ async def __execute(
18981909
)
18991910
return result, stmt
19001911

1901-
async def _executemany(self, query, args, timeout):
1912+
async def _executemany(self, query, args, timeout, return_rows):
19021913
executor = lambda stmt, timeout: self._protocol.bind_execute_many(
19031914
state=stmt,
19041915
args=args,
19051916
portal_name='',
19061917
timeout=timeout,
1918+
return_rows=return_rows,
19071919
)
19081920
timeout = self._protocol._get_timeout(timeout)
19091921
with self._stmt_exclusive_section:

asyncpg/pool.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,8 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
538538
async with self.acquire() as con:
539539
return await con.execute(query, *args, timeout=timeout)
540540

541-
async def executemany(self, command: str, args, *, timeout: float=None):
541+
async def executemany(self, command: str, args, *, timeout: float=None,
542+
return_rows: bool=False):
542543
"""Execute an SQL *command* for each sequence of arguments in *args*.
543544
544545
Pool performs this operation using one of its connections. Other than
@@ -549,7 +550,8 @@ async def executemany(self, command: str, args, *, timeout: float=None):
549550
.. versionadded:: 0.10.0
550551
"""
551552
async with self.acquire() as con:
552-
return await con.executemany(command, args, timeout=timeout)
553+
return await con.executemany(
554+
command, args, timeout=timeout, return_rows=return_rows)
553555

554556
async def fetch(
555557
self,

asyncpg/prepared_stmt.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -211,18 +211,28 @@ async def fetchrow(self, *args, timeout=None):
211211
return data[0]
212212

213213
@connresource.guarded
214-
async def executemany(self, args, *, timeout: float=None):
214+
async def executemany(self, args, *, timeout: float=None,
215+
return_rows: bool=False):
215216
"""Execute the statement for each sequence of arguments in *args*.
216217
217218
:param args: An iterable containing sequences of arguments.
218219
:param float timeout: Optional timeout value in seconds.
219-
:return None: This method discards the results of the operations.
220+
:param bool return_rows:
221+
If ``True``, the resulting rows of each command will be
222+
returned as a list of :class:`~asyncpg.Record`
223+
(defaults to ``False``).
224+
:return:
225+
None, or a list of :class:`~asyncpg.Record` instances
226+
if `return_rows` is true.
220227
221228
.. versionadded:: 0.22.0
229+
230+
.. versionchanged:: 0.30.0
231+
Added `return_rows` keyword-only parameter.
222232
"""
223233
return await self.__do_execute(
224234
lambda protocol: protocol.bind_execute_many(
225-
self._state, args, '', timeout))
235+
self._state, args, '', timeout, return_rows=return_rows))
226236

227237
async def __do_execute(self, executor):
228238
protocol = self._connection._protocol

asyncpg/protocol/coreproto.pxd

+1-1
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ cdef class CoreProtocol:
174174
cdef _bind_execute(self, str portal_name, str stmt_name,
175175
WriteBuffer bind_data, int32_t limit)
176176
cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
177-
object bind_data)
177+
object bind_data, bint return_rows)
178178
cdef bint _bind_execute_many_more(self, bint first=*)
179179
cdef _bind_execute_many_fail(self, object error, bint first=*)
180180
cdef _bind(self, str portal_name, str stmt_name,

asyncpg/protocol/coreproto.pyx

+3-3
Original file line numberDiff line numberDiff line change
@@ -940,12 +940,12 @@ cdef class CoreProtocol:
940940
self._send_bind_message(portal_name, stmt_name, bind_data, limit)
941941

942942
cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
943-
object bind_data):
943+
object bind_data, bint return_rows):
944944
self._ensure_connected()
945945
self._set_state(PROTOCOL_BIND_EXECUTE_MANY)
946946

947-
self.result = None
948-
self._discard_data = True
947+
self.result = [] if return_rows else None
948+
self._discard_data = not return_rows
949949
self._execute_iter = bind_data
950950
self._execute_portal_name = portal_name
951951
self._execute_stmt_name = stmt_name

asyncpg/protocol/protocol.pyx

+3-1
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ cdef class BaseProtocol(CoreProtocol):
213213
args,
214214
portal_name: str,
215215
timeout,
216+
return_rows: bool,
216217
):
217218
if self.cancel_waiter is not None:
218219
await self.cancel_waiter
@@ -238,7 +239,8 @@ cdef class BaseProtocol(CoreProtocol):
238239
more = self._bind_execute_many(
239240
portal_name,
240241
state.name,
241-
arg_bufs) # network op
242+
arg_bufs,
243+
return_rows) # network op
242244

243245
self.last_query = state.query
244246
self.statement = state

tests/test_execute.py

+52-1
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,45 @@ async def test_executemany_basic(self):
139139
('a', 1), ('b', 2), ('c', 3), ('d', 4)
140140
])
141141

142+
async def test_executemany_returning(self):
143+
result = await self.con.executemany('''
144+
INSERT INTO exmany VALUES($1, $2) RETURNING a, b
145+
''', [
146+
('a', 1), ('b', 2), ('c', 3), ('d', 4)
147+
], return_rows=True)
148+
self.assertEqual(result, [
149+
('a', 1), ('b', 2), ('c', 3), ('d', 4)
150+
])
151+
result = await self.con.fetch('''
152+
SELECT * FROM exmany
153+
''')
154+
self.assertEqual(result, [
155+
('a', 1), ('b', 2), ('c', 3), ('d', 4)
156+
])
157+
158+
# Empty set
159+
await self.con.executemany('''
160+
INSERT INTO exmany VALUES($1, $2) RETURNING a, b
161+
''', (), return_rows=True)
162+
result = await self.con.fetch('''
163+
SELECT * FROM exmany
164+
''')
165+
self.assertEqual(result, [
166+
('a', 1), ('b', 2), ('c', 3), ('d', 4)
167+
])
168+
169+
# Without "RETURNING"
170+
result = await self.con.executemany('''
171+
INSERT INTO exmany VALUES($1, $2)
172+
''', [('e', 5), ('f', 6)], return_rows=True)
173+
self.assertEqual(result, [])
174+
result = await self.con.fetch('''
175+
SELECT * FROM exmany
176+
''')
177+
self.assertEqual(result, [
178+
('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5), ('f', 6)
179+
])
180+
142181
async def test_executemany_bad_input(self):
143182
with self.assertRaisesRegex(
144183
exceptions.DataError,
@@ -288,11 +327,13 @@ async def test_executemany_client_server_failure_conflict(self):
288327

289328
async def test_executemany_prepare(self):
290329
stmt = await self.con.prepare('''
291-
INSERT INTO exmany VALUES($1, $2)
330+
INSERT INTO exmany VALUES($1, $2) RETURNING a, b
292331
''')
293332
result = await stmt.executemany([
294333
('a', 1), ('b', 2), ('c', 3), ('d', 4)
295334
])
335+
# While the query contains a "RETURNING" clause, by default
336+
# `executemany` does not return anything
296337
self.assertIsNone(result)
297338
result = await self.con.fetch('''
298339
SELECT * FROM exmany
@@ -308,3 +349,13 @@ async def test_executemany_prepare(self):
308349
self.assertEqual(result, [
309350
('a', 1), ('b', 2), ('c', 3), ('d', 4)
310351
])
352+
# Now with `return_rows=True`, we should retrieve the tuples
353+
# from the "RETURNING" clause.
354+
result = await stmt.executemany([('e', 5), ('f', 6)], return_rows=True)
355+
self.assertEqual(result, [('e', 5), ('f', 6)])
356+
result = await self.con.fetch('''
357+
SELECT * FROM exmany
358+
''')
359+
self.assertEqual(result, [
360+
('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5), ('f', 6)
361+
])

0 commit comments

Comments
 (0)