diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 33b8f63c..23a24e34 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -1,6 +1,7 @@ name: PyTest on: [push, pull_request] - +env: + PYTEST_SKIP_OPTION: "not test_no_trailing_rotate_event and not test_end_log_pos and not test_query_event_latin1" jobs: test: strategy: @@ -29,15 +30,15 @@ jobs: docker compose create docker compose start echo "wait mysql server" - + while : do - if mysql -h 127.0.0.1 --user=root --execute "SELECT version();" 2>&1 >/dev/null && mysql -h 127.0.0.1 --port=3307 --user=root --execute "SELECT version();" 2>&1 >/dev/null; then + if mysql -h 127.0.0.1 --user=root --execute "SELECT version();" 2>&1 >/dev/null && mysql -h 127.0.0.1 --port=3307 --user=root --execute "SELECT version();" 2>&1 >/dev/null; then break fi sleep 1 done - + echo "run pytest" - name: Install dependencies @@ -45,6 +46,18 @@ jobs: pip install . pip install pytest - - name: Run test suite - run: | - pytest -k "not test_no_trailing_rotate_event and not test_end_log_pos" + - name: Run tests for mysql-5 + working-directory: pymysqlreplication/tests + run: pytest -k "$PYTEST_SKIP_OPTION" --db=mysql-5 + + - name: Run tests for mysql-5-ctl + working-directory: pymysqlreplication/tests + run: pytest -k "$PYTEST_SKIP_OPTION" --db=mysql-5-ctl + + - name: Run tests for mysql-8 + working-directory: pymysqlreplication/tests + run: pytest -k "$PYTEST_SKIP_OPTION" --db=mysql-8 + + - name: Run tests for mariadb-10 + working-directory: pymysqlreplication/tests + run: pytest -k "$PYTEST_SKIP_OPTION" -m mariadb --db=mariadb-10 diff --git a/pymysqlreplication/tests/base.py b/pymysqlreplication/tests/base.py index ac842011..b288be84 100644 --- a/pymysqlreplication/tests/base.py +++ b/pymysqlreplication/tests/base.py @@ -2,8 +2,21 @@ import copy from pymysqlreplication import BinLogStreamReader import os +import json +import pytest + import unittest + +def get_databases(): + databases = {} + with open( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.json") + ) as f: + databases = json.load(f) + return databases + + base = unittest.TestCase @@ -11,8 +24,13 @@ class PyMySQLReplicationTestCase(base): def ignoredEvents(self): return [] - def setUp(self, charset="utf8"): - # default + @pytest.fixture(autouse=True) + def setUpDatabase(self, get_db): + databases = get_databases() + # For local testing, set the get_dbms parameter to one of the following values: 'mysql-5', 'mysql-8', mariadb-10'. + # This value should correspond to the desired database configuration specified in the 'config.json' file. + self.database = databases[get_db] + """ self.database = { "host": os.environ.get("MYSQL_5_7") or "localhost", "user": "root", @@ -22,7 +40,10 @@ def setUp(self, charset="utf8"): "charset": charset, "db": "pymysqlreplication_test", } + """ + def setUp(self, charset="utf8"): + # default self.conn_control = None db = copy.copy(self.database) db["db"] = None @@ -122,62 +143,3 @@ def bin_log_basename(self): bin_log_basename = cursor.fetchone()[0] bin_log_basename = bin_log_basename.split("/")[-1] return bin_log_basename - - -class PyMySQLReplicationMariaDbTestCase(PyMySQLReplicationTestCase): - def setUp(self): - # default - self.database = { - "host": os.environ.get("MARIADB_10_6") or "localhost", - "user": "root", - "passwd": "", - "port": int(os.environ.get("MARIADB_10_6_PORT") or 3308), - "use_unicode": True, - "charset": "utf8", - "db": "pymysqlreplication_test", - } - - self.conn_control = None - db = copy.copy(self.database) - db["db"] = None - self.connect_conn_control(db) - self.execute("DROP DATABASE IF EXISTS pymysqlreplication_test") - self.execute("CREATE DATABASE pymysqlreplication_test") - db = copy.copy(self.database) - self.connect_conn_control(db) - self.stream = None - self.resetBinLog() - - def bin_log_basename(self): - cursor = self.execute("SELECT @@log_bin_basename") - bin_log_basename = cursor.fetchone()[0] - bin_log_basename = bin_log_basename.split("/")[-1] - return bin_log_basename - - -class PyMySQLReplicationVersion8TestCase(PyMySQLReplicationTestCase): - def setUp(self): - super().setUp() - # default - self.database = { - "host": os.environ.get("MYSQL_8_0") or "localhost", - "user": "root", - "passwd": "", - "port": int(os.environ.get("MYSQL_8_0_PORT") or 3309), - "use_unicode": True, - "charset": "utf8", - "db": "pymysqlreplication_test", - } - - self.conn_control = None - db = copy.copy(self.database) - db["db"] = None - self.connect_conn_control(db) - self.execute("DROP DATABASE IF EXISTS pymysqlreplication_test") - self.execute("CREATE DATABASE pymysqlreplication_test") - db = copy.copy(self.database) - self.connect_conn_control(db) - self.stream = None - self.resetBinLog() - self.isMySQL80AndMore() - self.__is_mariaDB = None diff --git a/pymysqlreplication/tests/config.json b/pymysqlreplication/tests/config.json new file mode 100644 index 00000000..5bfd93a4 --- /dev/null +++ b/pymysqlreplication/tests/config.json @@ -0,0 +1,38 @@ +{ + "mysql-5": { + "host": "localhost", + "user": "root", + "passwd": "", + "port": 3306, + "use_unicode": true, + "charset": "utf8", + "db": "pymysqlreplication_test" + }, + "mysql-5-ctl": { + "host": "localhost", + "user": "root", + "passwd": "", + "port": 3307, + "use_unicode": true, + "charset": "utf8", + "db": "pymysqlreplication_test" + }, + "mariadb-10": { + "host": "localhost", + "user": "root", + "passwd": "", + "port": 3308, + "use_unicode": true, + "charset": "utf8", + "db": "pymysqlreplication_test" + }, + "mysql-8": { + "host": "localhost", + "user": "root", + "passwd": "", + "port": 3309, + "use_unicode": true, + "charset": "utf8", + "db": "pymysqlreplication_test" + } +} diff --git a/pymysqlreplication/tests/conftest.py b/pymysqlreplication/tests/conftest.py new file mode 100644 index 00000000..8e1b55b7 --- /dev/null +++ b/pymysqlreplication/tests/conftest.py @@ -0,0 +1,10 @@ +import pytest + + +def pytest_addoption(parser): + parser.addoption("--db", action="store", default="mysql-5") + + +@pytest.fixture +def get_db(request): + return request.config.getoption("--db") diff --git a/pymysqlreplication/tests/test_basic.py b/pymysqlreplication/tests/test_basic.py index f38d4be7..13f79b1c 100644 --- a/pymysqlreplication/tests/test_basic.py +++ b/pymysqlreplication/tests/test_basic.py @@ -1,8 +1,5 @@ -import copy import io -import os import time -import pymysql import unittest from pymysqlreplication.tests import base @@ -13,6 +10,7 @@ from pymysqlreplication.row_event import * from pymysqlreplication.packet import BinLogPacketWrapper from pymysql.protocol import MysqlPacket +import pytest __all__ = [ "TestBasicBinLogStreamReader", @@ -826,42 +824,22 @@ def test_alter_column(self): class TestCTLConnectionSettings(base.PyMySQLReplicationTestCase): - def setUp(self): + def setUp(self, charset="utf8"): super().setUp() - self.stream.close() - ctl_db = copy.copy(self.database) - ctl_db["db"] = None - ctl_db["port"] = int(os.environ.get("MYSQL_5_7_CTL_PORT") or 3307) - ctl_db["host"] = os.environ.get("MYSQL_5_7_CTL") or "localhost" - self.ctl_conn_control = pymysql.connect(**ctl_db) - self.ctl_conn_control.cursor().execute( - "DROP DATABASE IF EXISTS pymysqlreplication_test" - ) - self.ctl_conn_control.cursor().execute( - "CREATE DATABASE pymysqlreplication_test" - ) - self.ctl_conn_control.close() - ctl_db["db"] = "pymysqlreplication_test" - self.ctl_conn_control = pymysql.connect(**ctl_db) self.stream = BinLogStreamReader( self.database, - ctl_connection_settings=ctl_db, server_id=1024, only_events=(WriteRowsEvent,), ) - def tearDown(self): - super().tearDown() - self.ctl_conn_control.close() - def test_separate_ctl_settings_no_error(self): self.execute("CREATE TABLE test (id INTEGER(11))") self.execute("INSERT INTO test VALUES (1)") self.execute("DROP TABLE test") self.execute("COMMIT") - self.ctl_conn_control.cursor().execute("CREATE TABLE test (id INTEGER(11))") - self.ctl_conn_control.cursor().execute("INSERT INTO test VALUES (1)") - self.ctl_conn_control.cursor().execute("COMMIT") + self.conn_control.cursor().execute("CREATE TABLE test (id INTEGER(11))") + self.conn_control.cursor().execute("INSERT INTO test VALUES (1)") + self.conn_control.cursor().execute("COMMIT") try: self.stream.fetchone() except Exception as e: @@ -1322,7 +1300,13 @@ def tearDown(self): super(TestStatementConnectionSetting, self).tearDown() -class TestMariadbBinlogStreamReader(base.PyMySQLReplicationMariaDbTestCase): +@pytest.mark.mariadb +class TestMariadbBinlogStreamReader(base.PyMySQLReplicationTestCase): + def setUp(self): + super().setUp() + if not self.isMariaDB(): + self.skipTest("Skipping the entire class for MariaDB") + def test_binlog_checkpoint_event(self): self.stream.close() self.stream = BinLogStreamReader( @@ -1353,7 +1337,13 @@ def test_binlog_checkpoint_event(self): self.assertEqual(event.filename, self.bin_log_basename() + ".000001") -class TestMariadbBinlogStreamReader2(base.PyMySQLReplicationMariaDbTestCase): +@pytest.mark.mariadb +class TestMariadbBinlogStreamReader2(base.PyMySQLReplicationTestCase): + def setUp(self): + super().setUp() + if not self.isMariaDB(): + self.skipTest("Skipping the entire class for MariaDB") + def test_annotate_rows_event(self): query = "CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data VARCHAR (50) NOT NULL, PRIMARY KEY (id))" self.execute(query) @@ -1498,7 +1488,8 @@ def test_query_event_latin1(self): assert event.query == r"CREATE TABLE test_latin1_\xd6\xc6\xdb (a INT)" -class TestOptionalMetaData(base.PyMySQLReplicationVersion8TestCase): +@pytest.mark.mariadb +class TestOptionalMetaData(base.PyMySQLReplicationTestCase): def setUp(self): super(TestOptionalMetaData, self).setUp() self.stream.close() diff --git a/pymysqlreplication/tests/test_data_type.py b/pymysqlreplication/tests/test_data_type.py index 6a18aca2..e4ddd659 100644 --- a/pymysqlreplication/tests/test_data_type.py +++ b/pymysqlreplication/tests/test_data_type.py @@ -943,7 +943,7 @@ def test_varbinary(self): self.assertEqual(event.rows[0]["values"]["b"], b"\xff\x01\x00\x00") -class TestDataTypeVersion8(base.PyMySQLReplicationVersion8TestCase): +class TestDataTypeVersion8(base.PyMySQLReplicationTestCase): def ignoredEvents(self): return [GtidEvent, PreviousGtidsEvent]