Fix Connection.reset() on read-only connections #50

Merged 1 commit on Dec 16, 2016
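
Connection.reset() returns a session to a clean state. Before this change it always issued UNLISTEN *; a hot-standby server is in recovery and rejects LISTEN/UNLISTEN/NOTIFY, so reset() failed on read-only connections. The change guards UNLISTEN * behind a pg_listening_channels() check and adds the cluster and test machinery needed to exercise the code against a streaming replica. A minimal sketch of the now-working call, assuming a standby reachable on localhost:5433 with a 'postgres' database (the address and database are illustrative, not taken from this PR):

    import asyncio
    import asyncpg

    async def main():
        # Assumes a hot-standby PostgreSQL server on localhost:5433.
        con = await asyncpg.connect(host='localhost', port=5433,
                                    database='postgres')
        try:
            # Before this change, reset() always executed "UNLISTEN *",
            # which a server in recovery rejects; with the guarded DO
            # block it completes cleanly on a standby.
            await con.reset()
        finally:
            await con.close()

    asyncio.run(main())
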
41 changes: 27 additions & 14 deletions asyncpg/_testbase.py
@@ -70,17 +70,19 @@ def wrapper(self, *args, __meth__=meth, **kwargs):

 class TestCase(unittest.TestCase, metaclass=TestCaseMeta):

-    def setUp(self):
+    @classmethod
+    def setUpClass(cls):
         if os.environ.get('USE_UVLOOP'):
             import uvloop
             asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(None)
-        self.loop = loop
+        cls.loop = loop

-    def tearDown(self):
-        self.loop.close()
+    @classmethod
+    def tearDownClass(cls):
+        cls.loop.close()
         asyncio.set_event_loop(None)

     @contextlib.contextmanager
@@ -97,7 +99,16 @@ def assertRunUnder(self, delta):
 _default_cluster = None


-def _start_cluster(server_settings={}):
+def _start_cluster(ClusterCls, cluster_kwargs, server_settings):
+    cluster = ClusterCls(**cluster_kwargs)
+    cluster.init()
+    cluster.trust_local_connections()
+    cluster.start(port='dynamic', server_settings=server_settings)
+    atexit.register(_shutdown_cluster, cluster)
+    return cluster
+
+
+def _start_default_cluster(server_settings={}):
     global _default_cluster

     if _default_cluster is None:
@@ -106,12 +117,8 @@ def _start_cluster(server_settings={}):
             # Using existing cluster, assuming it is initialized and running
             _default_cluster = pg_cluster.RunningCluster()
         else:
-            _default_cluster = pg_cluster.TempCluster()
-            _default_cluster.init()
-            _default_cluster.trust_local_connections()
-            _default_cluster.start(port='dynamic',
-                                   server_settings=server_settings)
-            atexit.register(_shutdown_cluster, _default_cluster)
+            _default_cluster = _start_cluster(
+                pg_cluster.TempCluster, {}, server_settings)

     return _default_cluster

@@ -122,9 +129,10 @@ def _shutdown_cluster(cluster):


 class ClusterTestCase(TestCase):
-    def setUp(self):
-        super().setUp()
-        self.cluster = _start_cluster({
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        cls.cluster = _start_default_cluster({
             'log_connections': 'on'
         })

@@ -133,6 +141,11 @@ def create_pool(self, **kwargs):
         conn_spec.update(kwargs)
         return pg_pool.create_pool(loop=self.loop, **conn_spec)

+    @classmethod
+    def start_cluster(cls, ClusterCls, *,
+                      cluster_kwargs={}, server_settings={}):
+        return _start_cluster(ClusterCls, cluster_kwargs, server_settings)
+

 class ConnectedTestCase(ClusterTestCase):

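The test-base changes above move the event loop and the default cluster from per-test setUp()/tearDown() to per-class setUpClass()/tearDownClass(), and split cluster startup into a reusable _start_cluster() helper exposed to test classes as ClusterTestCase.start_cluster(). A rough sketch of how a test class could bring up an extra throw-away cluster with that helper (the class name and settings below are illustrative, not part of the diff):

    from asyncpg import _testbase as tb
    from asyncpg import cluster as pg_cluster

    class ExtraClusterTest(tb.ClusterTestCase):
        @classmethod
        def setUpClass(cls):
            super().setUpClass()
            # The helper inits the cluster, trusts local connections,
            # starts it on a dynamic port, and registers an atexit shutdown.
            cls.extra_cluster = cls.start_cluster(
                pg_cluster.TempCluster,
                server_settings={'log_connections': 'on'})
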
64 changes: 64 additions & 0 deletions asyncpg/cluster.py
@@ -16,6 +16,7 @@
 import socket
 import subprocess
 import tempfile
+import textwrap
 import time

 import asyncpg
@@ -332,6 +333,20 @@ def trust_local_connections(self):
         if status == 'running':
             self.reload()

+    def trust_local_replication_by(self, user):
+        if _system != 'Windows':
+            self.add_hba_entry(type='local', database='replication',
+                               user=user, auth_method='trust')
+        self.add_hba_entry(type='host', address='127.0.0.1/32',
+                           database='replication', user=user,
+                           auth_method='trust')
+        self.add_hba_entry(type='host', address='::1/128',
+                           database='replication', user=user,
+                           auth_method='trust')
+        status = self.get_status()
+        if status == 'running':
+            self.reload()
+
     def _init_env(self):
         self._pg_config = self._find_pg_config(self._pg_config_path)
         self._pg_config_data = self._run_pg_config(self._pg_config)
@@ -489,6 +504,55 @@ def __init__(self, *,
         super().__init__(self._data_dir, pg_config_path=pg_config_path)


+class HotStandbyCluster(TempCluster):
+    def __init__(self, *,
+                 master, replication_user,
+                 data_dir_suffix=None, data_dir_prefix=None,
+                 data_dir_parent=None, pg_config_path=None):
+        self._master = master
+        self._repl_user = replication_user
+        super().__init__(
+            data_dir_suffix=data_dir_suffix,
+            data_dir_prefix=data_dir_prefix,
+            data_dir_parent=data_dir_parent,
+            pg_config_path=pg_config_path)
+
+    def _init_env(self):
+        super()._init_env()
+        self._pg_basebackup = self._find_pg_binary('pg_basebackup')
+
+    def init(self, **settings):
+        """Initialize cluster."""
+        if self.get_status() != 'not-initialized':
+            raise ClusterError(
+                'cluster in {!r} has already been initialized'.format(
+                    self._data_dir))
+
+        process = subprocess.run(
+            [self._pg_basebackup, '-h', self._master['host'],
+             '-p', self._master['port'], '-D', self._data_dir,
+             '-U', self._repl_user],
+            stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+
+        output = process.stdout
+
+        if process.returncode != 0:
+            raise ClusterError(
+                'pg_basebackup init exited with status {:d}:\n{}'.format(
+                    process.returncode, output.decode()))
+
+        with open(os.path.join(self._data_dir, 'recovery.conf'), 'w') as f:
+            f.write(textwrap.dedent("""\
+                standby_mode = 'on'
+                primary_conninfo = 'host={host} port={port} user={user}'
+            """.format(
+                host=self._master['host'],
+                port=self._master['port'],
+                user=self._repl_user)))
+
+        return output.decode()
+
+
 class RunningCluster(Cluster):
     def __init__(self, **kwargs):
         self.conn_spec = kwargs
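HotStandbyCluster.init() does not run initdb; it clones the master with pg_basebackup and writes a recovery.conf with standby_mode = 'on', so the new data directory comes up as a streaming replica of the given master. trust_local_replication_by() adds the pg_hba.conf entries that let the replication user connect locally without a password. A rough usage sketch outside the test harness (the role name and settings are assumptions, and a role with LOGIN REPLICATION must already exist on the master):

    from asyncpg import cluster as pg_cluster

    master = pg_cluster.TempCluster()
    master.init()
    master.trust_local_connections()
    master.trust_local_replication_by('replication')  # new in this PR
    master.start(port='dynamic',
                 server_settings={'wal_level': 'hot_standby',
                                  'max_wal_senders': 10})
    # ... create the 'replication' role on the running master here ...

    standby = pg_cluster.HotStandbyCluster(
        master=master.get_connection_spec(),   # host/port of the master
        replication_user='replication')
    standby.init()   # pg_basebackup + recovery.conf
    standby.start(port='dynamic',
                  server_settings={'hot_standby': True})
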
10 changes: 9 additions & 1 deletion asyncpg/connection.py
@@ -369,11 +369,19 @@ def terminate(self):

     async def reset(self):
         self._listeners = {}
+
         await self.execute('''
+            DO $$
+            BEGIN
+                PERFORM * FROM pg_listening_channels() LIMIT 1;
+                IF FOUND THEN
+                    UNLISTEN *;
+                END IF;
+            END;
+            $$;
             SET SESSION AUTHORIZATION DEFAULT;
             RESET ALL;
             CLOSE ALL;
-            UNLISTEN *;
             SELECT pg_advisory_unlock_all();
         ''')

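PostgreSQL does not allow LISTEN, UNLISTEN, or NOTIFY while a server is in recovery, which is exactly the state of a hot-standby replica, so the old unconditional UNLISTEN * made reset() fail on read-only connections with an error along the lines of "cannot execute UNLISTEN during recovery". Because LISTEN is equally impossible on a standby, pg_listening_channels() can only return rows on a primary; the DO block therefore issues UNLISTEN * only when the session actually has active listens, and keeps the whole reset in a single round trip. The same check expressed client-side, purely to illustrate the logic (a hedged sketch, not how the PR implements it):

    async def conditional_unlisten(con):
        """Client-side equivalent of the guard in reset()'s DO block."""
        channels = await con.fetch('SELECT * FROM pg_listening_channels()')
        if channels:
            # Only reachable on a primary: a standby rejects LISTEN, so
            # the session can never be listening on any channel.
            await con.execute('UNLISTEN *')
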
72 changes: 71 additions & 1 deletion tests/test_pool.py
@@ -9,7 +9,8 @@
 import platform

 from asyncpg import _testbase as tb
-
+from asyncpg import cluster as pg_cluster
+from asyncpg import pool as pg_pool

 _system = platform.uname().system

@@ -148,3 +149,72 @@ async def worker():
         # Reset cluster's pg_hba.conf since we've meddled with it
         self.cluster.trust_local_connections()
         self.cluster.reload()
+
+
+class TestHostStandby(tb.ConnectedTestCase):
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+
+        cls.master_cluster = cls.start_cluster(
+            pg_cluster.TempCluster,
+            server_settings={
+                'max_wal_senders': 10,
+                'wal_level': 'hot_standby'
+            })
+
+        con = None
+
+        try:
+            con = cls.loop.run_until_complete(
+                cls.master_cluster.connect(database='postgres', loop=cls.loop))
+
+            cls.loop.run_until_complete(
+                con.execute('''
+                    CREATE ROLE replication WITH LOGIN REPLICATION
+                '''))
+
+            cls.master_cluster.trust_local_replication_by('replication')
+
+            conn_spec = cls.master_cluster.get_connection_spec()
+
+            cls.standby_cluster = cls.start_cluster(
+                pg_cluster.HotStandbyCluster,
+                cluster_kwargs={
+                    'master': conn_spec,
+                    'replication_user': 'replication'
+                },
+                server_settings={
+                    'hot_standby': True
+                })
+
+        finally:
+            if con is not None:
+                cls.loop.run_until_complete(con.close())
+
+    @classmethod
+    def tearDownMethod(cls):
+        cls.standby_cluster.stop()
+        cls.standby_cluster.destroy()
+        cls.master_cluster.stop()
+        cls.master_cluster.destroy()
+
+    def create_pool(self, **kwargs):
+        conn_spec = self.standby_cluster.get_connection_spec()
+        conn_spec.update(kwargs)
+        return pg_pool.create_pool(loop=self.loop, **conn_spec)
+
+    async def test_standby_pool_01(self):
+        for n in {1, 3, 5, 10, 20, 100}:
+            with self.subTest(tasksnum=n):
+                pool = await self.create_pool(database='postgres',
+                                              min_size=5, max_size=10)
+
+                async def worker():
+                    con = await pool.acquire()
+                    self.assertEqual(await con.fetchval('SELECT 1'), 1)
+                    await pool.release(con)
+
+                tasks = [worker() for _ in range(n)]
+                await asyncio.gather(*tasks, loop=self.loop)
+                await pool.close()
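
This new test exercises the fix end to end: the pool resets each connection when it is released, so every worker's pool.release() against the standby would previously have hit the unconditional UNLISTEN * and failed with the recovery error; with the guarded reset() the workers complete against the read-only replica.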