diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 810227c7..6400f8e3 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -866,7 +866,7 @@ async def copy_to_table(self, table_name, *, source, delimiter=None, null=None, header=None, quote=None, escape=None, force_quote=None, force_not_null=None, force_null=None, - encoding=None): + encoding=None, where=None): """Copy data to the specified table. :param str table_name: @@ -885,6 +885,15 @@ async def copy_to_table(self, table_name, *, source, :param str schema_name: An optional schema name to qualify the table. + :param str where: + An optional SQL expression used to filter rows when copying. + + .. note:: + + Usage of this parameter requires support for the + ``COPY FROM ... WHERE`` syntax, introduced in + PostgreSQL version 12. + :param float timeout: Optional timeout value in seconds. @@ -912,6 +921,9 @@ async def copy_to_table(self, table_name, *, source, https://www.postgresql.org/docs/current/static/sql-copy.html .. versionadded:: 0.11.0 + + .. versionadded:: 0.29.0 + Added the *where* parameter. """ tabname = utils._quote_ident(table_name) if schema_name: @@ -923,6 +935,7 @@ async def copy_to_table(self, table_name, *, source, else: cols = '' + cond = self._format_copy_where(where) opts = self._format_copy_opts( format=format, oids=oids, freeze=freeze, delimiter=delimiter, null=null, header=header, quote=quote, escape=escape, @@ -930,14 +943,14 @@ async def copy_to_table(self, table_name, *, source, encoding=encoding ) - copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format( - tab=tabname, cols=cols, opts=opts) + copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond}'.format( + tab=tabname, cols=cols, opts=opts, cond=cond) return await self._copy_in(copy_stmt, source, timeout) async def copy_records_to_table(self, table_name, *, records, columns=None, schema_name=None, - timeout=None): + timeout=None, where=None): """Copy a list of records to the specified table using binary COPY. :param str table_name: @@ -954,6 +967,16 @@ async def copy_records_to_table(self, table_name, *, records, :param str schema_name: An optional schema name to qualify the table. + :param str where: + An optional SQL expression used to filter rows when copying. + + .. note:: + + Usage of this parameter requires support for the + ``COPY FROM ... WHERE`` syntax, introduced in + PostgreSQL version 12. + + :param float timeout: Optional timeout value in seconds. @@ -998,6 +1021,9 @@ async def copy_records_to_table(self, table_name, *, records, .. versionchanged:: 0.24.0 The ``records`` argument may be an asynchronous iterable. + + .. versionadded:: 0.29.0 + Added the *where* parameter. """ tabname = utils._quote_ident(table_name) if schema_name: @@ -1015,14 +1041,27 @@ async def copy_records_to_table(self, table_name, *, records, intro_ps = await self._prepare(intro_query, use_cache=True) + cond = self._format_copy_where(where) opts = '(FORMAT binary)' - copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format( - tab=tabname, cols=cols, opts=opts) + copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond}'.format( + tab=tabname, cols=cols, opts=opts, cond=cond) return await self._protocol.copy_in( copy_stmt, None, None, records, intro_ps._state, timeout) + def _format_copy_where(self, where): + if where and not self._server_caps.sql_copy_from_where: + raise exceptions.UnsupportedServerFeatureError( + 'the `where` parameter requires PostgreSQL 12 or later') + + if where: + where_clause = 'WHERE ' + where + else: + where_clause = '' + + return where_clause + def _format_copy_opts(self, *, format=None, oids=None, freeze=None, delimiter=None, null=None, header=None, quote=None, escape=None, force_quote=None, force_not_null=None, @@ -2404,7 +2443,7 @@ class _ConnectionProxy: ServerCapabilities = collections.namedtuple( 'ServerCapabilities', ['advisory_locks', 'notifications', 'plpgsql', 'sql_reset', - 'sql_close_all', 'jit']) + 'sql_close_all', 'sql_copy_from_where', 'jit']) ServerCapabilities.__doc__ = 'PostgreSQL server capabilities.' @@ -2417,6 +2456,7 @@ def _detect_server_capabilities(server_version, connection_settings): sql_reset = True sql_close_all = False jit = False + sql_copy_from_where = False elif hasattr(connection_settings, 'crdb_version'): # CockroachDB detected. advisory_locks = False @@ -2425,6 +2465,7 @@ def _detect_server_capabilities(server_version, connection_settings): sql_reset = False sql_close_all = False jit = False + sql_copy_from_where = False elif hasattr(connection_settings, 'crate_version'): # CrateDB detected. advisory_locks = False @@ -2433,6 +2474,7 @@ def _detect_server_capabilities(server_version, connection_settings): sql_reset = False sql_close_all = False jit = False + sql_copy_from_where = False else: # Standard PostgreSQL server assumed. advisory_locks = True @@ -2441,6 +2483,7 @@ def _detect_server_capabilities(server_version, connection_settings): sql_reset = True sql_close_all = True jit = server_version >= (11, 0) + sql_copy_from_where = server_version.major >= 12 return ServerCapabilities( advisory_locks=advisory_locks, @@ -2448,6 +2491,7 @@ def _detect_server_capabilities(server_version, connection_settings): plpgsql=plpgsql, sql_reset=sql_reset, sql_close_all=sql_close_all, + sql_copy_from_where=sql_copy_from_where, jit=jit, ) diff --git a/asyncpg/exceptions/_base.py b/asyncpg/exceptions/_base.py index e2da6bd8..00e9699a 100644 --- a/asyncpg/exceptions/_base.py +++ b/asyncpg/exceptions/_base.py @@ -12,9 +12,10 @@ __all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError', 'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage', + 'ClientConfigurationError', 'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError', 'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched', - 'ClientConfigurationError') + 'UnsupportedServerFeatureError') def _is_asyncpg_class(cls): @@ -233,6 +234,10 @@ class UnsupportedClientFeatureError(InterfaceError): """Requested feature is unsupported by asyncpg.""" +class UnsupportedServerFeatureError(InterfaceError): + """Requested feature is unsupported by PostgreSQL server.""" + + class InterfaceWarning(InterfaceMessage, UserWarning): """A warning caused by an improper use of asyncpg API.""" diff --git a/asyncpg/pool.py b/asyncpg/pool.py index b02fe597..06e698df 100644 --- a/asyncpg/pool.py +++ b/asyncpg/pool.py @@ -711,7 +711,8 @@ async def copy_to_table( force_quote=None, force_not_null=None, force_null=None, - encoding=None + encoding=None, + where=None ): """Copy data to the specified table. @@ -740,7 +741,8 @@ async def copy_to_table( force_quote=force_quote, force_not_null=force_not_null, force_null=force_null, - encoding=encoding + encoding=encoding, + where=where ) async def copy_records_to_table( @@ -750,7 +752,8 @@ async def copy_records_to_table( records, columns=None, schema_name=None, - timeout=None + timeout=None, + where=None ): """Copy a list of records to the specified table using binary COPY. @@ -767,7 +770,8 @@ async def copy_records_to_table( records=records, columns=columns, schema_name=schema_name, - timeout=timeout + timeout=timeout, + where=where ) def acquire(self, *, timeout=None): diff --git a/tests/test_copy.py b/tests/test_copy.py index 70c9388e..be2aabaf 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -10,6 +10,7 @@ import io import os import tempfile +import unittest import asyncpg from asyncpg import _testbase as tb @@ -414,7 +415,7 @@ async def test_copy_to_table_basics(self): '*a4*|b4', '*a5*|b5', '*!**|*n-u-l-l*', - 'n-u-l-l|bb' + 'n-u-l-l|bb', ]).encode('utf-8') ) f.seek(0) @@ -644,6 +645,35 @@ async def test_copy_records_to_table_1(self): finally: await self.con.execute('DROP TABLE copytab') + async def test_copy_records_to_table_where(self): + if not self.con._server_caps.sql_copy_from_where: + raise unittest.SkipTest( + 'COPY WHERE not supported on server') + + await self.con.execute(''' + CREATE TABLE copytab_where(a text, b int, c timestamptz); + ''') + + try: + date = datetime.datetime.now(tz=datetime.timezone.utc) + delta = datetime.timedelta(days=1) + + records = [ + ('a-{}'.format(i), i, date + delta) + for i in range(100) + ] + + records.append(('a-100', None, None)) + records.append(('b-999', None, None)) + + res = await self.con.copy_records_to_table( + 'copytab_where', records=records, where='a <> \'b-999\'') + + self.assertEqual(res, 'COPY 101') + + finally: + await self.con.execute('DROP TABLE copytab_where') + async def test_copy_records_to_table_async(self): await self.con.execute(''' CREATE TABLE copytab_async(a text, b int, c timestamptz); @@ -660,7 +690,8 @@ async def record_generator(): yield ('a-100', None, None) res = await self.con.copy_records_to_table( - 'copytab_async', records=record_generator()) + 'copytab_async', records=record_generator(), + ) self.assertEqual(res, 'COPY 101')