diff --git a/asyncpg/__init__.py b/asyncpg/__init__.py index c74bbe2b..f0d89be3 100644 --- a/asyncpg/__init__.py +++ b/asyncpg/__init__.py @@ -5,11 +5,12 @@ # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 -from .connection import connect # NOQA +from .connection import connect, Connection # NOQA from .exceptions import * # NOQA from .pool import create_pool # NOQA from .protocol import Record # NOQA from .types import * # NOQA -__all__ = ('connect', 'create_pool', 'Record') + exceptions.__all__ # NOQA +__all__ = ('connect', 'create_pool', 'Record', 'Connection') + \ + exceptions.__all__ # NOQA diff --git a/asyncpg/_testbase.py b/asyncpg/_testbase.py index 8d38b258..effe02d1 100644 --- a/asyncpg/_testbase.py +++ b/asyncpg/_testbase.py @@ -18,6 +18,7 @@ from asyncpg import cluster as pg_cluster +from asyncpg import connection as pg_connection from asyncpg import pool as pg_pool @@ -162,12 +163,14 @@ def create_pool(dsn=None, *, init=None, loop=None, pool_class=pg_pool.Pool, + connection_class=pg_connection.Connection, **connect_kwargs): return pool_class( dsn, min_size=min_size, max_size=max_size, max_queries=max_queries, loop=loop, setup=setup, init=init, max_inactive_connection_lifetime=max_inactive_connection_lifetime, + connection_class=connection_class, **connect_kwargs) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py new file mode 100644 index 00000000..92ea1566 --- /dev/null +++ b/asyncpg/connect_utils.py @@ -0,0 +1,387 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import asyncio +import collections +import getpass +import os +import socket +import struct +import time +import urllib.parse + +from . import exceptions +from . 
import protocol + + +_ConnectionParameters = collections.namedtuple( + 'ConnectionParameters', + [ + 'user', + 'password', + 'database', + 'ssl', + 'connect_timeout', + 'server_settings', + ]) + + +_ClientConfiguration = collections.namedtuple( + 'ConnectionConfiguration', + [ + 'command_timeout', + 'statement_cache_size', + 'max_cached_statement_lifetime', + 'max_cacheable_statement_size', + ]) + + +def _parse_connect_dsn_and_args(*, dsn, host, port, user, + password, database, ssl, connect_timeout, + server_settings): + if host is not None and not isinstance(host, str): + raise TypeError( + 'host argument is expected to be str, got {!r}'.format( + type(host))) + + if dsn: + parsed = urllib.parse.urlparse(dsn) + + if parsed.scheme not in {'postgresql', 'postgres'}: + raise ValueError( + 'invalid DSN: scheme is expected to be either of ' + '"postgresql" or "postgres", got {!r}'.format(parsed.scheme)) + + if parsed.port and port is None: + port = int(parsed.port) + + if parsed.hostname and host is None: + host = parsed.hostname + + if parsed.path and database is None: + database = parsed.path + if database.startswith('/'): + database = database[1:] + + if parsed.username and user is None: + user = parsed.username + + if parsed.password and password is None: + password = parsed.password + + if parsed.query: + query = urllib.parse.parse_qs(parsed.query, strict_parsing=True) + for key, val in query.items(): + if isinstance(val, list): + query[key] = val[-1] + + if 'host' in query: + val = query.pop('host') + if host is None: + host = val + + if 'port' in query: + val = int(query.pop('port')) + if port is None: + port = val + + if 'dbname' in query: + val = query.pop('dbname') + if database is None: + database = val + + if 'database' in query: + val = query.pop('database') + if database is None: + database = val + + if 'user' in query: + val = query.pop('user') + if user is None: + user = val + + if 'password' in query: + val = query.pop('password') + if password is None: + password = val + + if query: + if server_settings is None: + server_settings = query + else: + server_settings = {**query, **server_settings} + + # On env-var -> connection parameter conversion read here: + # https://www.postgresql.org/docs/current/static/libpq-envars.html + # Note that env values may be an empty string in cases when + # the variable is "unset" by setting it to an empty value + # + if host is None: + host = os.getenv('PGHOST') + if not host: + host = ['/tmp', '/private/tmp', + '/var/pgsql_socket', '/run/postgresql', + 'localhost'] + if not isinstance(host, list): + host = [host] + + if port is None: + port = os.getenv('PGPORT') + if port: + port = int(port) + else: + port = 5432 + else: + port = int(port) + + if user is None: + user = os.getenv('PGUSER') + if not user: + user = getpass.getuser() + + if password is None: + password = os.getenv('PGPASSWORD') + + if database is None: + database = os.getenv('PGDATABASE') + + if database is None: + database = user + + if user is None: + raise exceptions.InterfaceError( + 'could not determine user name to connect with') + + if database is None: + raise exceptions.InterfaceError( + 'could not determine database name to connect to') + + addrs = [] + for h in host: + if h.startswith('/'): + # UNIX socket name + if '.s.PGSQL.' 
not in h: + h = os.path.join(h, '.s.PGSQL.{}'.format(port)) + addrs.append(h) + else: + # TCP host/port + addrs.append((h, port)) + + if not addrs: + raise ValueError( + 'could not determine the database address to connect to') + + if ssl: + for addr in addrs: + if isinstance(addr, str): + # UNIX socket + raise exceptions.InterfaceError( + '`ssl` parameter can only be enabled for TCP addresses, ' + 'got a UNIX socket path: {!r}'.format(addr)) + + if server_settings is not None and ( + not isinstance(server_settings, dict) or + not all(isinstance(k, str) for k in server_settings) or + not all(isinstance(v, str) for v in server_settings.values())): + raise ValueError( + 'server_settings is expected to be None or ' + 'a Dict[str, str]') + + params = _ConnectionParameters( + user=user, password=password, database=database, ssl=ssl, + connect_timeout=connect_timeout, server_settings=server_settings) + + return addrs, params + + +def _parse_connect_arguments(*, dsn, host, port, user, password, database, + timeout, command_timeout, statement_cache_size, + max_cached_statement_lifetime, + max_cacheable_statement_size, + ssl, server_settings): + + local_vars = locals() + for var_name in {'max_cacheable_statement_size', + 'max_cached_statement_lifetime', + 'statement_cache_size'}: + var_val = local_vars[var_name] + if var_val is None or isinstance(var_val, bool) or var_val < 0: + raise ValueError( + '{} is expected to be greater ' + 'or equal to 0, got {!r}'.format(var_name, var_val)) + + if command_timeout is not None: + try: + if isinstance(command_timeout, bool): + raise ValueError + command_timeout = float(command_timeout) + if command_timeout <= 0: + raise ValueError + except ValueError: + raise ValueError( + 'invalid command_timeout value: ' + 'expected greater than 0 float (got {!r})'.format( + command_timeout)) from None + + addrs, params = _parse_connect_dsn_and_args( + dsn=dsn, host=host, port=port, user=user, + password=password, ssl=ssl, + database=database, connect_timeout=timeout, + server_settings=server_settings) + + config = _ClientConfiguration( + command_timeout=command_timeout, + statement_cache_size=statement_cache_size, + max_cached_statement_lifetime=max_cached_statement_lifetime, + max_cacheable_statement_size=max_cacheable_statement_size,) + + return addrs, params, config + + +async def _connect_addr(*, addr, loop, timeout, params, config, + connection_class): + assert loop is not None + + if timeout <= 0: + raise asyncio.TimeoutError + + connected = _create_future(loop) + proto_factory = lambda: protocol.Protocol( + addr, connected, params, loop) + + if isinstance(addr, str): + # UNIX socket + assert params.ssl is None + connector = loop.create_unix_connection(proto_factory, addr) + elif params.ssl: + connector = _create_ssl_connection( + proto_factory, *addr, loop=loop, ssl_context=params.ssl) + else: + connector = loop.create_connection(proto_factory, *addr) + + before = time.monotonic() + tr, pr = await asyncio.wait_for( + connector, timeout=timeout, loop=loop) + timeout -= time.monotonic() - before + + try: + if timeout <= 0: + raise asyncio.TimeoutError + await asyncio.wait_for(connected, loop=loop, timeout=timeout) + except Exception: + tr.close() + raise + + con = connection_class(pr, tr, loop, addr, config, params) + pr.set_connection(con) + return con + + +async def _connect(*, loop, timeout, connection_class, **kwargs): + if loop is None: + loop = asyncio.get_event_loop() + + addrs, params, config = _parse_connect_arguments(timeout=timeout, **kwargs) + + 
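
Note: the precedence implemented by `_parse_connect_dsn_and_args()` above (explicit keyword arguments win over DSN components, which win over `PG*` environment variables, with unrecognized query parameters folded into `server_settings`) can be exercised directly. A minimal sketch; the helper is private API whose signature is taken from this diff, and the DSN values below are made up:

```python
# Illustration only: _parse_connect_dsn_and_args() is a private helper; the
# call signature is taken from the diff above, and the DSN is a made-up value.
from asyncpg import connect_utils

addrs, params = connect_utils._parse_connect_dsn_and_args(
    dsn='postgresql://app:secret@db.example.com:5433/appdb?application_name=demo',
    host=None, port=None, user=None, password=None, database=None,
    ssl=None, connect_timeout=None, server_settings=None)

print(addrs)                   # [('db.example.com', 5433)]
print(params.user)             # 'app'
print(params.database)         # 'appdb'
print(params.server_settings)  # {'application_name': 'demo'} -- unrecognized query args
```
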
last_error = None + addr = None + for addr in addrs: + before = time.monotonic() + try: + con = await _connect_addr( + addr=addr, loop=loop, timeout=timeout, + params=params, config=config, + connection_class=connection_class) + except (OSError, asyncio.TimeoutError, ConnectionError) as ex: + last_error = ex + else: + return con + finally: + timeout -= time.monotonic() - before + + raise last_error + + +async def _get_ssl_ready_socket(host, port, *, loop): + reader, writer = await asyncio.open_connection(host, port, loop=loop) + + tr = writer.transport + try: + sock = _get_socket(tr) + _set_nodelay(sock) + + writer.write(struct.pack('!ll', 8, 80877103)) # SSLRequest message. + await writer.drain() + resp = await reader.readexactly(1) + + if resp == b'S': + return sock.dup() + else: + raise ConnectionError( + 'PostgreSQL server at "{}:{}" rejected SSL upgrade'.format( + host, port)) + finally: + tr.close() + + +async def _create_ssl_connection(protocol_factory, host, port, *, + loop, ssl_context): + sock = await _get_ssl_ready_socket(host, port, loop=loop) + try: + return await loop.create_connection( + protocol_factory, sock=sock, ssl=ssl_context, + server_hostname=host) + except Exception: + sock.close() + raise + + +async def _open_connection(*, loop, addr, params: _ConnectionParameters): + if isinstance(addr, str): + r, w = await asyncio.open_unix_connection(addr, loop=loop) + else: + if params.ssl: + sock = await _get_ssl_ready_socket(*addr, loop=loop) + + try: + r, w = await asyncio.open_connection( + sock=sock, + loop=loop, + ssl=params.ssl, + server_hostname=addr[0]) + except Exception: + sock.close() + raise + + else: + r, w = await asyncio.open_connection(*addr, loop=loop) + _set_nodelay(_get_socket(w.transport)) + + return r, w + + +def _get_socket(transport): + sock = transport.get_extra_info('socket') + if sock is None: + # Shouldn't happen with any asyncio-compliant event loop. + raise ConnectionError( + 'could not get the socket for transport {!r}'.format(transport)) + return sock + + +def _set_nodelay(sock): + if not hasattr(socket, 'AF_UNIX') or sock.family != socket.AF_UNIX: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + +def _create_future(loop): + try: + create_future = loop.create_future + except AttributeError: + return asyncio.Future(loop=loop) + else: + return create_future() diff --git a/asyncpg/connection.py b/asyncpg/connection.py index b48e3830..a744683d 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -7,14 +7,10 @@ import asyncio import collections -import getpass -import os -import socket import struct import time -import urllib.parse -import warnings +from . import connect_utils from . import cursor from . import exceptions from . 
import introspection @@ -39,17 +35,15 @@ class Connection(metaclass=ConnectionMeta): __slots__ = ('_protocol', '_transport', '_loop', '_types_stmt', '_type_by_name_stmt', '_top_xact', '_uid', '_aborted', - '_stmt_cache', '_stmts_to_close', - '_addr', '_opts', '_command_timeout', '_listeners', + '_stmt_cache', '_stmts_to_close', '_listeners', '_server_version', '_server_caps', '_intro_query', '_reset_query', '_proxy', '_stmt_exclusive_section', - '_max_cacheable_statement_size', '_ssl_context') + '_config', '_params', '_addr') - def __init__(self, protocol, transport, loop, addr, opts, *, - statement_cache_size, command_timeout, - max_cached_statement_lifetime, - max_cacheable_statement_size, - ssl_context): + def __init__(self, protocol, transport, loop, + addr: (str, int) or str, + config: connect_utils._ClientConfiguration, + params: connect_utils._ConnectionParameters): self._protocol = protocol self._transport = transport self._loop = loop @@ -60,20 +54,17 @@ def __init__(self, protocol, transport, loop, addr, opts, *, self._aborted = False self._addr = addr - self._opts = opts - self._ssl_context = ssl_context + self._config = config + self._params = params - self._max_cacheable_statement_size = max_cacheable_statement_size self._stmt_cache = _StatementCache( loop=loop, - max_size=statement_cache_size, + max_size=config.statement_cache_size, on_remove=self._maybe_gc_stmt, - max_lifetime=max_cached_statement_lifetime) + max_lifetime=config.max_cached_statement_lifetime) self._stmts_to_close = set() - self._command_timeout = command_timeout - self._listeners = {} settings = self._protocol.get_settings() @@ -252,8 +243,8 @@ async def _get_statement(self, query, timeout, *, named: bool=False): # * query size is less than `max_cacheable_statement_size`. use_cache = self._stmt_cache.get_max_size() > 0 if (use_cache and - self._max_cacheable_statement_size and - len(query) > self._max_cacheable_statement_size): + self._config.max_cacheable_statement_size and + len(query) > self._config.max_cacheable_statement_size): use_cache = False if use_cache or named: @@ -496,38 +487,18 @@ def _cancel_current_command(self, waiter): async def cancel(): try: # Open new connection to the server - if isinstance(self._addr, str): - r, w = await asyncio.open_unix_connection( - self._addr, loop=self._loop) - else: - if self._ssl_context: - sock = await _get_ssl_ready_socket( - *self._addr, loop=self._loop) - - try: - r, w = await asyncio.open_connection( - sock=sock, - loop=self._loop, - ssl=self._ssl_context, - server_hostname=self._addr[0]) - except Exception: - sock.close() - raise - - else: - r, w = await asyncio.open_connection( - *self._addr, loop=self._loop) - _set_nodelay(_get_socket(w.transport)) + r, w = await connect_utils._open_connection( + loop=self._loop, addr=self._addr, params=self._params) + except Exception as ex: + waiter.set_exception(ex) + return + try: # Pack CancelRequest message msg = struct.pack('!llll', 16, 80877102, self._protocol.backend_pid, self._protocol.backend_secret) - except Exception as ex: - waiter.set_exception(ex) - return - try: w.write(msg) await r.read() # Wait until EOF except ConnectionResetError: @@ -701,8 +672,8 @@ async def connect(dsn=None, *, max_cacheable_statement_size=1024 * 15, command_timeout=None, ssl=None, - __connection_class__=Connection, - **opts): + connection_class=Connection, + server_settings=None): r"""A coroutine to establish a connection to a PostgreSQL server. Returns a new :class:`~asyncpg.connection.Connection` object. 
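
Note: the reworked `_cancel_current_command()` above delegates the socket setup to `connect_utils._open_connection()` but still assembles the CancelRequest frame by hand. For reference, a standalone sketch of that frame layout; the PID and secret values are placeholders for what the server reports in its BackendKeyData message:

```python
# Standalone illustration of the CancelRequest frame packed in the diff above.
import struct

backend_pid = 4242          # placeholder; normally from BackendKeyData
backend_secret = 987654321  # placeholder; normally from BackendKeyData

# length (16 bytes), cancel request code, backend PID, backend secret
msg = struct.pack('!llll', 16, 80877102, backend_pid, backend_secret)

assert len(msg) == 16
# 80877102 == (1234 << 16) | 5678, the protocol's "cancel request" code.
print(msg.hex())
```
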
@@ -765,6 +736,13 @@ async def connect(dsn=None, *, returned by `ssl.create_default_context() `_ will be used. + :param dict server_settings: + an optional dict of server parameters. + + :param Connection connection_class: + class of the returned connection object. Must be a subclass of + :class:`~asyncpg.connection.Connection`. + :return: A :class:`~asyncpg.connection.Connection` instance. Example: @@ -780,149 +758,38 @@ async def connect(dsn=None, *, >>> asyncio.get_event_loop().run_until_complete(run()) [ connection parameter conversion read here: - # https://www.postgresql.org/docs/current/static/libpq-envars.html - # Note that env values may be an empty string in cases when - # the variable is "unset" by setting it to an empty value - # - if host is None: - host = os.getenv('PGHOST') - if not host: - host = ['/tmp', '/private/tmp', - '/var/pgsql_socket', '/run/postgresql', - 'localhost'] - if not isinstance(host, list): - host = [host] - - if port is None: - port = os.getenv('PGPORT') - if port: - port = int(port) - else: - port = 5432 - else: - port = int(port) - - if user is None: - user = os.getenv('PGUSER') - if not user: - user = getpass.getuser() - - if password is None: - password = os.getenv('PGPASSWORD') - - if database is None: - database = os.getenv('PGDATABASE') - - if user is not None: - opts['user'] = user - if password is not None: - opts['password'] = password - if database is not None: - opts['database'] = database - - for param in opts: - if not isinstance(param, str): - raise ValueError( - 'invalid connection parameter {!r} (str expected)' - .format(param)) - if not isinstance(opts[param], str): - raise ValueError( - 'invalid connection parameter {!r}: {!r} (str expected)' - .format(param, opts[param])) - - addrs = [] - for h in host: - if h.startswith('/'): - # UNIX socket name - if '.s.PGSQL.' not in h: - h = os.path.join(h, '.s.PGSQL.{}'.format(port)) - addrs.append(h) - else: - # TCP host/port - addrs.append((h, port)) - - return addrs, opts - - -def _create_future(loop): - try: - create_future = loop.create_future - except AttributeError: - return asyncio.Future(loop=loop) - else: - return create_future() - - ServerCapabilities = collections.namedtuple( 'ServerCapabilities', ['advisory_locks', 'notifications', 'plpgsql', 'sql_reset', diff --git a/asyncpg/pool.py b/asyncpg/pool.py index a0399f48..82b90796 100644 --- a/asyncpg/pool.py +++ b/asyncpg/pool.py @@ -10,6 +10,7 @@ import inspect from . import connection +from . import connect_utils from . import exceptions @@ -110,27 +111,26 @@ async def connect(self): if self._pool._working_addr is None: # First connection attempt on this pool. - con = await self._pool._connect(*self._connect_args, - loop=self._pool._loop, - **self._connect_kwargs) + con = await connection.connect( + *self._connect_args, + loop=self._pool._loop, + connection_class=self._pool._connection_class, + **self._connect_kwargs) + self._pool._working_addr = con._addr - self._pool._working_opts = con._opts - self._pool._working_ssl_context = con._ssl_context + self._pool._working_config = con._config + self._pool._working_params = con._params else: - # We've connected before and have a resolved address - # and parsed options in `pool._working_addr` and - # `pool._working_opts`. 
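
Note: together with the docstring hunk above, `connection_class` and `server_settings` are now part of the public `connect()` signature. A usage sketch; the subclass name, credentials, and the `application_name` value are illustrative and not taken from the patch:

```python
# Usage sketch for the two new connect() keywords.
import asyncio
import asyncpg

class TracingConnection(asyncpg.Connection):
    """Any subclass of asyncpg.Connection can be passed as connection_class."""

async def main():
    con = await asyncpg.connect(
        user='postgres', database='postgres',
        connection_class=TracingConnection,
        server_settings={'application_name': 'review_demo'})
    try:
        assert isinstance(con, TracingConnection)
        # server_settings are sent as startup parameters, so the backend sees them:
        print(await con.fetchval('SHOW application_name'))  # -> review_demo
    finally:
        await con.close()

asyncio.get_event_loop().run_until_complete(main())
```
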
- if isinstance(self._pool._working_addr, str): - host = self._pool._working_addr - port = 0 - else: - host, port = self._pool._working_addr - - con = await self._pool._connect( - host=host, port=port, loop=self._pool._loop, - ssl=self._pool._working_ssl_context, - **self._pool._working_opts) + # We've connected before and have a resolved address, + # and parsed options and config. + con = await connect_utils._connect_addr( + loop=self._pool._loop, + addr=self._pool._working_addr, + timeout=self._pool._working_params.connect_timeout, + config=self._pool._working_config, + params=self._pool._working_params, + connection_class=self._pool._connection_class) if self._init is not None: await self._init(con) @@ -250,8 +250,9 @@ class Pool: """ __slots__ = ('_queue', '_loop', '_minsize', '_maxsize', - '_working_addr', '_working_opts', '_working_ssl_context', - '_holders', '_initialized', '_closed') + '_working_addr', '_working_config', '_working_params', + '_holders', '_initialized', '_closed', + '_connection_class') def __init__(self, *connect_args, min_size, @@ -261,6 +262,7 @@ def __init__(self, *connect_args, setup, init, loop, + connection_class, **connect_kwargs): if loop is None: @@ -293,8 +295,10 @@ def __init__(self, *connect_args, self._queue = asyncio.LifoQueue(maxsize=self._maxsize, loop=self._loop) self._working_addr = None - self._working_opts = None - self._working_ssl_context = None + self._working_config = None + self._working_params = None + + self._connection_class = connection_class self._closed = False @@ -311,10 +315,6 @@ def __init__(self, *connect_args, self._holders.append(ch) self._queue.put_nowait(ch) - async def _connect(self, *args, **kwargs): - # Used by PoolConnectionHolder. - return await connection.connect(*args, **kwargs) - async def _async__init__(self): if self._initialized: return @@ -555,6 +555,7 @@ def create_pool(dsn=None, *, setup=None, init=None, loop=None, + connection_class=connection.Connection, **connect_kwargs): r"""Create a connection pool. @@ -625,8 +626,14 @@ def create_pool(dsn=None, *, An :exc:`~asyncpg.exceptions.InterfaceError` will be raised on any attempted operation on a released connection. 
""" + if not issubclass(connection_class, connection.Connection): + raise TypeError( + 'connection_class is expected to be a subclass of ' + 'asyncpg.Connection, got {!r}'.format(connection_class)) + return Pool( dsn, + connection_class=connection_class, min_size=min_size, max_size=max_size, max_queries=max_queries, loop=loop, setup=setup, init=init, max_inactive_connection_lifetime=max_inactive_connection_lifetime, diff --git a/asyncpg/protocol/coreproto.pxd b/asyncpg/protocol/coreproto.pxd index efe5d1bf..a1527aff 100644 --- a/asyncpg/protocol/coreproto.pxd +++ b/asyncpg/protocol/coreproto.pxd @@ -83,8 +83,8 @@ cdef class CoreProtocol: object transport - # Dict with all connection arguments - dict con_args + # Instance of _ConnectionParameters + object con_params readonly int32_t backend_pid readonly int32_t backend_secret diff --git a/asyncpg/protocol/coreproto.pyx b/asyncpg/protocol/coreproto.pyx index b099042e..815f5f63 100644 --- a/asyncpg/protocol/coreproto.pyx +++ b/asyncpg/protocol/coreproto.pyx @@ -10,12 +10,13 @@ from hashlib import md5 as hashlib_md5 # for MD5 authentication cdef class CoreProtocol: - def __init__(self, con_args): + def __init__(self, con_params): + # type of `con_params` is `_ConnectionParameters` self.buffer = ReadBuffer() - self.user = con_args.get('user') - self.password = con_args.get('password') + self.user = con_params.user + self.password = con_params.password self.auth_msg = None - self.con_args = con_args + self.con_params = con_params self.transport = None self.con_status = CONNECTION_BAD self.state = PROTOCOL_IDLE @@ -560,11 +561,16 @@ cdef class CoreProtocol: buf.write_bytestring(b'client_encoding') buf.write_bytestring("'{}'".format(self.encoding).encode('ascii')) - for param in self.con_args: - if param == 'password': - continue - buf.write_str(param, self.encoding) - buf.write_str(self.con_args[param], self.encoding) + buf.write_str('user', self.encoding) + buf.write_str(self.con_params.user, self.encoding) + + buf.write_str('database', self.encoding) + buf.write_str(self.con_params.database, self.encoding) + + if self.con_params.server_settings is not None: + for k, v in self.con_params.server_settings.items(): + buf.write_str(k, self.encoding) + buf.write_str(v, self.encoding) buf.write_bytestring(b'') diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index 00d4aa2e..a25002c3 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -83,8 +83,9 @@ NO_TIMEOUT = object() cdef class BaseProtocol(CoreProtocol): - def __init__(self, addr, connected_fut, con_args, loop): - CoreProtocol.__init__(self, con_args) + def __init__(self, addr, connected_fut, con_params, loop): + # type of `con_params` is `_ConnectionParameters` + CoreProtocol.__init__(self, con_params) self.loop = loop self.waiter = connected_fut @@ -92,8 +93,7 @@ cdef class BaseProtocol(CoreProtocol): self.cancel_sent_waiter = None self.address = addr - self.settings = ConnectionSettings( - (self.address, con_args.get('database'))) + self.settings = ConnectionSettings((self.address, con_params.database)) self.statement = None self.return_extra = False @@ -371,7 +371,7 @@ cdef class BaseProtocol(CoreProtocol): cdef inline _get_timeout_impl(self, timeout): if timeout is None: - timeout = self.connection._command_timeout + timeout = self.connection._config.command_timeout elif timeout is NO_TIMEOUT: timeout = None else: diff --git a/tests/test_connect.py b/tests/test_connect.py index 740b904f..acb05b39 100644 --- a/tests/test_connect.py +++ 
b/tests/test_connect.py @@ -16,6 +16,7 @@ import asyncpg from asyncpg import _testbase as tb from asyncpg import connection +from asyncpg import connect_utils from asyncpg import cluster as pg_cluster from asyncpg.serverversion import split_server_version_string @@ -289,7 +290,24 @@ class TestConnectParams(tb.TestCase): 'password': 'ask', 'database': 'db', 'result': ([('127.0.0.1', 888)], { - 'param': '123', + 'server_settings': {'param': '123'}, + 'user': 'me', + 'password': 'ask', + 'database': 'db'}) + }, + + { + 'dsn': 'postgresql://user3:123123@localhost:5555/' + 'abcdef?param=sss&param=123&host=testhost&user=testuser' + '&port=2222&database=testdb', + 'host': '127.0.0.1', + 'port': '888', + 'user': 'me', + 'password': 'ask', + 'database': 'db', + 'server_settings': {'aa': 'bb'}, + 'result': ([('127.0.0.1', 888)], { + 'server_settings': {'aa': 'bb', 'param': '123'}, 'user': 'me', 'password': 'ask', 'database': 'db'}) @@ -339,12 +357,12 @@ def run_testcase(self, testcase): test_env.update(env) dsn = testcase.get('dsn') - opts = testcase.get('opts', {}) user = testcase.get('user') port = testcase.get('port') host = testcase.get('host') password = testcase.get('password') database = testcase.get('database') + server_settings = testcase.get('server_settings') expected = testcase.get('result') expected_error = testcase.get('error') @@ -358,15 +376,20 @@ def run_testcase(self, testcase): 'has to be specified, got both') with contextlib.ExitStack() as es: - es.enter_context(self.subTest(dsn=dsn, opts=opts, env=env)) + es.enter_context(self.subTest(dsn=dsn, env=env)) es.enter_context(self.environ(**test_env)) if expected_error: es.enter_context(self.assertRaisesRegex(*expected_error)) - result = connection._parse_connect_params( + addrs, params = connect_utils._parse_connect_dsn_and_args( dsn=dsn, host=host, port=port, user=user, password=password, - database=database, opts=opts) + database=database, ssl=None, connect_timeout=None, + server_settings=server_settings) + + params = {k: v for k, v in params._asdict().items() + if v is not None} + result = (addrs, params) if expected is not None: self.assertEqual(expected, result) @@ -405,16 +428,10 @@ def test_test_connect_params_run_testcase(self): 'PGUSER': '__test__' }, 'host': 'abc', - 'result': ([('abc', 5432)], {'user': '__test__'}) - }) - - with self.assertRaises(AssertionError): - self.run_testcase({ - 'env': { - 'PGUSER': '__test__' - }, - 'host': 'abc', - 'result': ([('abc', 5432)], {'user': 'wrong_user'}) + 'result': ( + [('abc', 5432)], + {'user': '__test__', 'database': '__test__'} + ) }) def test_connect_params(self): @@ -492,7 +509,7 @@ async def test_connection_ssl_unix(self): with self.assertRaisesRegex(asyncpg.InterfaceError, 'can only be enabled for TCP addresses'): await self.cluster.connect( - host=['localhost', '/tmp'], + host='/tmp', loop=self.loop, ssl=ssl_context) diff --git a/tests/test_pool.py b/tests/test_pool.py index 35ed1943..85d84252 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -39,7 +39,7 @@ async def reset(self): class SlowResetConnectionPool(pg_pool.Pool): async def _connect(self, *args, **kwargs): return await pg_connection.connect( - *args, __connection_class__=SlowResetConnection, **kwargs) + *args, connection_class=SlowResetConnection, **kwargs) class TestPool(tb.ConnectedTestCase): @@ -351,6 +351,30 @@ async def sleep_and_release(): async with pool.acquire() as con: await con.fetchval('SELECT 1') + async def test_pool_config_persistence(self): + N = 100 + cons = set() + + class 
MyConnection(asyncpg.Connection): + pass + + async def test(pool): + async with pool.acquire() as con: + await con.fetchval('SELECT 1') + self.assertTrue(isinstance(con, MyConnection)) + self.assertEqual(con._con._config.statement_cache_size, 3) + cons.add(con) + + async with self.create_pool( + database='postgres', min_size=10, max_size=10, + max_queries=1, connection_class=MyConnection, + statement_cache_size=3) as pool: + + await asyncio.gather(*[test(pool) for _ in range(N)], + loop=self.loop) + + self.assertEqual(len(cons), N) + async def test_pool_release_in_xact(self): """Test that Connection.reset() closes any open transaction.""" async with self.create_pool(database='postgres', diff --git a/tests/test_timeout.py b/tests/test_timeout.py index 6ca5d63f..bf9daef7 100644 --- a/tests/test_timeout.py +++ b/tests/test_timeout.py @@ -147,7 +147,7 @@ async def _get_statement(self, query, timeout): class TestTimeoutCoversPrepare(tb.ConnectedTestCase): - @tb.with_connection_options(__connection_class__=SlowPrepareConnection, + @tb.with_connection_options(connection_class=SlowPrepareConnection, command_timeout=0.3) async def test_timeout_covers_prepare_01(self): for methname in {'fetch', 'fetchrow', 'fetchval', 'execute'}:
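
Note: outside the test harness, the behaviour covered by the new `test_pool_config_persistence` above boils down to the sketch below — the pool hands out instances of the requested `Connection` subclass and threads the per-connection configuration through. Pool sizes, credentials, and the subclass name are illustrative:

```python
# Standalone variant of what test_pool_config_persistence asserts: pooled
# connections are created from the supplied connection_class and carry the
# client configuration (statement_cache_size here).
import asyncio
import asyncpg

class AppConnection(asyncpg.Connection):
    pass

async def main():
    pool = await asyncpg.create_pool(
        user='postgres', database='postgres',
        min_size=2, max_size=5,
        connection_class=AppConnection,
        statement_cache_size=3)
    try:
        async with pool.acquire() as con:
            # The pool proxy passes isinstance() checks against the subclass.
            assert isinstance(con, AppConnection)
            print(await con.fetchval('SELECT 1'))
    finally:
        await pool.close()

asyncio.get_event_loop().run_until_complete(main())
```
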