diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..b830265d --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,22 @@ +name: Lint +on: [push, pull_request] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.x' # This will install the latest version of Python 3 + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install ruff black + + - name: Run lint script + run: bash scripts/lint.sh \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..f9bcdb60 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,19 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.0.286 + hooks: + - id: ruff + args: [ --fix, --exit-non-zero-on-fix ] + +- repo: https://github.com/psf/black + rev: 23.7.0 + hooks: + - id: black \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..7ee24417 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,51 @@ +# Contributing to python-mysql-replication + +Firstly, thank you for considering to contribute to `python-mysql-replication`. We appreciate your effort, and to ensure that your contributions align with the project's coding standards, we employ the use of `pre-commit` hooks. This guide will walk you through setting them up. + +## Setting up pre-commit + +1. **Install pre-commit** + + Before you can use `pre-commit`, you need to install it. You can do so using `pip`: + + ```bash + pip install pre-commit + ``` + +2. **Install the pre-commit hooks** + + Navigate to the root directory of your cloned `python-mysql-replication` repository and run: + + ```bash + pre-commit install + ``` + + This will install the `pre-commit` hooks to your local repository. + +3. **Make sure to stage your changes** + + `pre-commit` will only check the files that are staged in git. So make sure to `git add` any new changes you made before running `pre-commit`. + +4. **Run pre-commit manually (Optional)** + + Before committing, you can manually run: + + ```bash + pre-commit run --all-files + ``` + + This will run the hooks on all the files. If there's any issue, the hooks will let you know. + +## If you encounter issues + +If you run into any problems with the hooks, you can always skip them using: + +```bash +git commit -m "Your commit message" --no-verify +``` + +However, please note that skipping hooks might lead to CI failures if we use these checks in our CI pipeline. It's always recommended to adhere to the checks to ensure a smooth contribution process. + +--- + +That's it! With these steps, you should be well on your way to contributing to `python-mysql-replication`. We look forward to your contributions! \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index 53a25342..4d46352e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -11,201 +11,205 @@ # All configuration values have a default; values that are commented out # serve to show the default. -import sys, os +import sys +import os # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath("..")) # -- General configuration ----------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +# needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be extensions # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. -extensions = ['sphinx.ext.autodoc', 'sphinx.ext.viewcode'] +extensions = ["sphinx.ext.autodoc", "sphinx.ext.viewcode"] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix of source filenames. -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. -#source_encoding = 'utf-8-sig' +# source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'Python MySQL Replication' -copyright = u'2012-2023, Julien Duponchelle' +project = "Python MySQL Replication" +copyright = "2012-2023, Julien Duponchelle" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = '0.44' +version = "0.44" # The full version, including alpha/beta/rc tags. -release = '0.44' +release = "0.44" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. -#language = None +# language = None # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +# today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['_build'] +exclude_patterns = ["_build"] # The reST default role (used for this markup: `text`) to use for all documents. -#default_role = None +# default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). -#add_module_names = True +# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] +# modindex_common_prefix = [] # -- Options for HTML output --------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'default' +html_theme = "default" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -#html_theme_options = {} +# html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] +# html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +# html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. -#html_favicon = None +# html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' +# html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. -#html_use_smartypants = True +# html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +# html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +# html_domain_indices = True # If false, no index is generated. -#html_use_index = True +# html_use_index = True # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True +# html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True +# html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True +# html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None +# html_file_suffix = None # Output file base name for HTML help builder. -htmlhelp_basename = 'PythonMySQLReplicationdoc' +htmlhelp_basename = "PythonMySQLReplicationdoc" # -- Options for LaTeX output -------------------------------------------------- latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', - -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', - -# Additional stuff for the LaTeX preamble. -#'preamble': '', + # The paper size ('letterpaper' or 'a4paper'). + #'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + #'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + #'preamble': '', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass [howto/manual]). latex_documents = [ - ('index', 'PythonMySQLReplication.tex', u'Python MySQL Replication Documentation', - u'Julien Duponchelle', 'manual'), + ( + "index", + "PythonMySQLReplication.tex", + "Python MySQL Replication Documentation", + "Julien Duponchelle", + "manual", + ), ] # The name of an image file (relative to this directory) to place at the top of # the title page. -#latex_logo = None +# latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -#latex_use_parts = False +# latex_use_parts = False # If true, show page references after internal links. -#latex_show_pagerefs = False +# latex_show_pagerefs = False # If true, show URL addresses after external links. -#latex_show_urls = False +# latex_show_urls = False # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. -#latex_domain_indices = True +# latex_domain_indices = True # -- Options for manual page output -------------------------------------------- @@ -213,12 +217,17 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - ('index', 'pythonmysqlreplication', u'Python MySQL Replication Documentation', - [u'Julien Duponchelle'], 1) + ( + "index", + "pythonmysqlreplication", + "Python MySQL Replication Documentation", + ["Julien Duponchelle"], + 1, + ) ] # If true, show URL addresses after external links. -#man_show_urls = False +# man_show_urls = False # -- Options for Texinfo output ------------------------------------------------ @@ -227,16 +236,22 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - ('index', 'PythonMySQLReplication', u'Python MySQL Replication Documentation', - u'Julien Duponchelle', 'PythonMySQLReplication', 'One line description of project.', - 'Miscellaneous'), + ( + "index", + "PythonMySQLReplication", + "Python MySQL Replication Documentation", + "Julien Duponchelle", + "PythonMySQLReplication", + "One line description of project.", + "Miscellaneous", + ), ] # Documents to append as an appendix to all manuals. -#texinfo_appendices = [] +# texinfo_appendices = [] # If false, no module index is generated. -#texinfo_domain_indices = True +# texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' +# texinfo_show_urls = 'footnote' diff --git a/examples/dump_events.py b/examples/dump_events.py index e0de09f1..cd69fa30 100644 --- a/examples/dump_events.py +++ b/examples/dump_events.py @@ -7,21 +7,16 @@ from pymysqlreplication import BinLogStreamReader -MYSQL_SETTINGS = { - "host": "127.0.0.1", - "port": 3306, - "user": "root", - "passwd": "" -} +MYSQL_SETTINGS = {"host": "127.0.0.1", "port": 3306, "user": "root", "passwd": ""} def main(): # server_id is your slave identifier, it should be unique. # set blocking to True if you want to block and wait for the next event at # the end of the stream - stream = BinLogStreamReader(connection_settings=MYSQL_SETTINGS, - server_id=3, - blocking=True) + stream = BinLogStreamReader( + connection_settings=MYSQL_SETTINGS, server_id=3, blocking=True + ) for binlogevent in stream: binlogevent.dump() diff --git a/examples/logstash/mysql_to_logstash.py b/examples/logstash/mysql_to_logstash.py index ea83ce00..e5a70a6a 100644 --- a/examples/logstash/mysql_to_logstash.py +++ b/examples/logstash/mysql_to_logstash.py @@ -17,19 +17,15 @@ WriteRowsEvent, ) -MYSQL_SETTINGS = { - "host": "127.0.0.1", - "port": 3306, - "user": "root", - "passwd": "" -} +MYSQL_SETTINGS = {"host": "127.0.0.1", "port": 3306, "user": "root", "passwd": ""} def main(): stream = BinLogStreamReader( connection_settings=MYSQL_SETTINGS, server_id=3, - only_events=[DeleteRowsEvent, WriteRowsEvent, UpdateRowsEvent]) + only_events=[DeleteRowsEvent, WriteRowsEvent, UpdateRowsEvent], + ) for binlogevent in stream: for row in binlogevent.rows: @@ -44,10 +40,9 @@ def main(): elif isinstance(binlogevent, WriteRowsEvent): event["action"] = "insert" event = dict(event.items() + row["values"].items()) - print json.dumps(event) + print(json.dumps(event)) sys.stdout.flush() - stream.close() diff --git a/examples/mariadb_gtid/read_event.py b/examples/mariadb_gtid/read_event.py index 607bfa34..205b9ce4 100644 --- a/examples/mariadb_gtid/read_event.py +++ b/examples/mariadb_gtid/read_event.py @@ -1,8 +1,17 @@ import pymysql -from pymysqlreplication import BinLogStreamReader, gtid -from pymysqlreplication.event import GtidEvent, RotateEvent, MariadbGtidEvent, QueryEvent,MariadbAnnotateRowsEvent, MariadbBinLogCheckPointEvent -from pymysqlreplication.row_event import WriteRowsEvent, UpdateRowsEvent, DeleteRowsEvent +from pymysqlreplication import BinLogStreamReader +from pymysqlreplication.event import ( + RotateEvent, + MariadbGtidEvent, + MariadbAnnotateRowsEvent, + MariadbBinLogCheckPointEvent, +) +from pymysqlreplication.row_event import ( + WriteRowsEvent, + UpdateRowsEvent, + DeleteRowsEvent, +) MARIADB_SETTINGS = { "host": "127.0.0.1", @@ -41,7 +50,9 @@ def extract_gtid(self, gtid: str, server_id: str): return None def query_gtid_current_pos(self, server_id: str): - return self.extract_gtid(self.query_single_value("SELECT @@gtid_current_pos"), server_id) + return self.extract_gtid( + self.query_single_value("SELECT @@gtid_current_pos"), server_id + ) def query_server_id(self): return int(self.query_single_value("SELECT @@server_id")) @@ -51,10 +62,10 @@ def query_server_id(self): db = MariaDbGTID(MARIADB_SETTINGS) server_id = db.query_server_id() - print('Server ID: ', server_id) + print("Server ID: ", server_id) # gtid = db.query_gtid_current_pos(server_id) - gtid = '0-1-1' # initial pos + gtid = "0-1-1" # initial pos stream = BinLogStreamReader( connection_settings=MARIADB_SETTINGS, @@ -67,20 +78,20 @@ def query_server_id(self): WriteRowsEvent, UpdateRowsEvent, DeleteRowsEvent, - MariadbAnnotateRowsEvent + MariadbAnnotateRowsEvent, ], auto_position=gtid, is_mariadb=True, - annotate_rows_event=True + annotate_rows_event=True, ) - print('Starting reading events from GTID ', gtid) + print("Starting reading events from GTID ", gtid) for binlogevent in stream: binlogevent.dump() if isinstance(binlogevent, MariadbGtidEvent): gtid = binlogevent.gtid - print('Last encountered GTID: ', gtid) + print("Last encountered GTID: ", gtid) stream.close() diff --git a/examples/redis_cache.py b/examples/redis_cache.py index 973821e6..485a700e 100644 --- a/examples/redis_cache.py +++ b/examples/redis_cache.py @@ -15,12 +15,7 @@ WriteRowsEvent, ) -MYSQL_SETTINGS = { - "host": "127.0.0.1", - "port": 3306, - "user": "root", - "passwd": "" -} +MYSQL_SETTINGS = {"host": "127.0.0.1", "port": 3306, "user": "root", "passwd": ""} def main(): @@ -29,7 +24,8 @@ def main(): stream = BinLogStreamReader( connection_settings=MYSQL_SETTINGS, server_id=3, # server_id is your slave identifier, it should be unique - only_events=[DeleteRowsEvent, WriteRowsEvent, UpdateRowsEvent]) + only_events=[DeleteRowsEvent, WriteRowsEvent, UpdateRowsEvent], + ) for binlogevent in stream: prefix = "%s:%s:" % (binlogevent.schema, binlogevent.table) diff --git a/pymysqlreplication/_compat.py b/pymysqlreplication/_compat.py index a61b248c..854b5528 100644 --- a/pymysqlreplication/_compat.py +++ b/pymysqlreplication/_compat.py @@ -1,7 +1 @@ -import sys - - -if sys.version_info > (3,): - text_type = str -else: - text_type = unicode +text_type = str diff --git a/pymysqlreplication/binlogstream.py b/pymysqlreplication/binlogstream.py index 8b93b3b6..36782f84 100644 --- a/pymysqlreplication/binlogstream.py +++ b/pymysqlreplication/binlogstream.py @@ -10,25 +10,38 @@ from .constants.BINLOG import TABLE_MAP_EVENT, ROTATE_EVENT, FORMAT_DESCRIPTION_EVENT from .event import ( - QueryEvent, RotateEvent, FormatDescriptionEvent, - XidEvent, GtidEvent, StopEvent, XAPrepareEvent, - BeginLoadQueryEvent, ExecuteLoadQueryEvent, - HeartbeatLogEvent, NotImplementedEvent, MariadbGtidEvent, - MariadbAnnotateRowsEvent, RandEvent, MariadbStartEncryptionEvent, RowsQueryLogEvent, - MariadbGtidListEvent, MariadbBinLogCheckPointEvent, UserVarEvent, - PreviousGtidsEvent) + QueryEvent, + RotateEvent, + FormatDescriptionEvent, + XidEvent, + GtidEvent, + StopEvent, + XAPrepareEvent, + BeginLoadQueryEvent, + ExecuteLoadQueryEvent, + HeartbeatLogEvent, + NotImplementedEvent, + MariadbGtidEvent, + MariadbAnnotateRowsEvent, + RandEvent, + MariadbStartEncryptionEvent, + RowsQueryLogEvent, + MariadbGtidListEvent, + MariadbBinLogCheckPointEvent, + UserVarEvent, + PreviousGtidsEvent, +) from .exceptions import BinLogNotEnabled from .gtid import GtidSet from .packet import BinLogPacketWrapper -from .row_event import ( - UpdateRowsEvent, WriteRowsEvent, DeleteRowsEvent, TableMapEvent) +from .row_event import UpdateRowsEvent, WriteRowsEvent, DeleteRowsEvent, TableMapEvent try: from pymysql.constants.COMMAND import COM_BINLOG_DUMP_GTID except ImportError: # Handle old pymysql versions # See: https://github.com/PyMySQL/PyMySQL/pull/261 - COM_BINLOG_DUMP_GTID = 0x1e + COM_BINLOG_DUMP_GTID = 0x1E # 2013 Connection Lost # 2006 MySQL server has gone away @@ -39,9 +52,9 @@ class ReportSlave(object): """Represent the values that you may report when connecting as a slave to a master. SHOW SLAVE HOSTS related""" - hostname = '' - username = '' - password = '' + hostname = "" + username = "" + password = "" port = 0 def __init__(self, value): @@ -61,7 +74,7 @@ def __init__(self, value): except IndexError: pass elif isinstance(value, dict): - for key in ['hostname', 'username', 'password', 'port']: + for key in ["hostname", "username", "password", "port"]: try: setattr(self, key, value[key]) except KeyError: @@ -70,8 +83,12 @@ def __init__(self, value): self.hostname = value def __repr__(self): - return '' % \ - (self.hostname, self.username, self.password, self.port) + return "" % ( + self.hostname, + self.username, + self.password, + self.port, + ) def encoded(self, server_id, master_id=0): """ @@ -96,57 +113,77 @@ def encoded(self, server_id, master_id=0): lusername = len(self.username.encode()) lpassword = len(self.password.encode()) - packet_len = (1 + # command - 4 + # server-id - 1 + # hostname length - lhostname + - 1 + # username length - lusername + - 1 + # password length - lpassword + - 2 + # slave mysql port - 4 + # replication rank - 4) # master-id + packet_len = ( + 1 + + 4 # command + + 1 # server-id + + lhostname # hostname length + + 1 + + lusername # username length + + 1 + + lpassword # password length + + 2 + + 4 # slave mysql port + + 4 # replication rank + ) # master-id MAX_STRING_LEN = 257 # one byte for length + 256 chars - return (struct.pack(' 4294967: heartbeat = 4294967 @@ -362,22 +403,23 @@ def __connect_to_stream(self): self.log_file, self.log_pos = master_status[:2] cur.close() - prelude = struct.pack('> 8 if real_type == FIELD_TYPE.SET or real_type == FIELD_TYPE.ENUM: self.type = real_type - self.size = metadata & 0x00ff + self.size = metadata & 0x00FF self.__read_enum_metadata(column_schema) else: - self.max_length = (((metadata >> 4) & 0x300) ^ 0x300) \ - + (metadata & 0x00ff) + self.max_length = (((metadata >> 4) & 0x300) ^ 0x300) + (metadata & 0x00FF) def __read_enum_metadata(self, column_schema): enums = column_schema["COLUMN_TYPE"] if self.type == FIELD_TYPE.ENUM: - self.enum_values = [''] + enums.replace('enum(', '')\ - .replace(')', '').replace('\'', '').split(',') + self.enum_values = [""] + enums.replace("enum(", "").replace( + ")", "" + ).replace("'", "").split(",") else: - self.set_values = enums.replace('set(', '')\ - .replace(')', '').replace('\'', '').split(',') + self.set_values = ( + enums.replace("set(", "").replace(")", "").replace("'", "").split(",") + ) def __eq__(self, other): return self.data == other.data @@ -96,4 +97,4 @@ def serializable_data(self): @property def data(self): - return dict((k, v) for (k, v) in self.__dict__.items() if not k.startswith('_')) + return dict((k, v) for (k, v) in self.__dict__.items() if not k.startswith("_")) diff --git a/pymysqlreplication/constants/BINLOG.py b/pymysqlreplication/constants/BINLOG.py index 71e5faf4..2cf8f219 100644 --- a/pymysqlreplication/constants/BINLOG.py +++ b/pymysqlreplication/constants/BINLOG.py @@ -10,12 +10,12 @@ SLAVE_EVENT = 0x07 CREATE_FILE_EVENT = 0x08 APPEND_BLOCK_EVENT = 0x09 -EXEC_LOAD_EVENT = 0x0a -DELETE_FILE_EVENT = 0x0b -NEW_LOAD_EVENT = 0x0c -RAND_EVENT = 0x0d -USER_VAR_EVENT = 0x0e -FORMAT_DESCRIPTION_EVENT = 0x0f +EXEC_LOAD_EVENT = 0x0A +DELETE_FILE_EVENT = 0x0B +NEW_LOAD_EVENT = 0x0C +RAND_EVENT = 0x0D +USER_VAR_EVENT = 0x0E +FORMAT_DESCRIPTION_EVENT = 0x0F XID_EVENT = 0x10 BEGIN_LOAD_QUERY_EVENT = 0x11 EXECUTE_LOAD_QUERY_EVENT = 0x12 @@ -26,12 +26,12 @@ WRITE_ROWS_EVENT_V1 = 0x17 UPDATE_ROWS_EVENT_V1 = 0x18 DELETE_ROWS_EVENT_V1 = 0x19 -INCIDENT_EVENT = 0x1a -HEARTBEAT_LOG_EVENT = 0x1b -IGNORABLE_LOG_EVENT = 0x1c -ROWS_QUERY_LOG_EVENT = 0x1d -WRITE_ROWS_EVENT_V2 = 0x1e -UPDATE_ROWS_EVENT_V2 = 0x1f +INCIDENT_EVENT = 0x1A +HEARTBEAT_LOG_EVENT = 0x1B +IGNORABLE_LOG_EVENT = 0x1C +ROWS_QUERY_LOG_EVENT = 0x1D +WRITE_ROWS_EVENT_V2 = 0x1E +UPDATE_ROWS_EVENT_V2 = 0x1F DELETE_ROWS_EVENT_V2 = 0x20 GTID_LOG_EVENT = 0x21 ANONYMOUS_GTID_LOG_EVENT = 0x22 @@ -44,11 +44,11 @@ INTVAR_INSERT_ID_EVENT = 0x02 # MariaDB events -MARIADB_ANNOTATE_ROWS_EVENT = 0xa0 -MARIADB_BINLOG_CHECKPOINT_EVENT = 0xa1 -MARIADB_GTID_EVENT = 0xa2 -MARIADB_GTID_GTID_LIST_EVENT = 0xa3 -MARIADB_START_ENCRYPTION_EVENT = 0xa4 +MARIADB_ANNOTATE_ROWS_EVENT = 0xA0 +MARIADB_BINLOG_CHECKPOINT_EVENT = 0xA1 +MARIADB_GTID_EVENT = 0xA2 +MARIADB_GTID_GTID_LIST_EVENT = 0xA3 +MARIADB_START_ENCRYPTION_EVENT = 0xA4 # Common-Footer -BINLOG_CHECKSUM_LEN = 4 \ No newline at end of file +BINLOG_CHECKSUM_LEN = 4 diff --git a/pymysqlreplication/constants/FIELD_TYPE.py b/pymysqlreplication/constants/FIELD_TYPE.py index 51791d62..18a1357e 100644 --- a/pymysqlreplication/constants/FIELD_TYPE.py +++ b/pymysqlreplication/constants/FIELD_TYPE.py @@ -3,23 +3,23 @@ # Original code from PyMySQL # Copyright (c) 2010 PyMySQL contributors # -#Permission is hereby granted, free of charge, to any person obtaining a copy -#of this software and associated documentation files (the "Software"), to deal -#in the Software without restriction, including without limitation the rights -#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -#copies of the Software, and to permit persons to whom the Software is -#furnished to do so, subject to the following conditions: +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: # -#The above copyright notice and this permission notice shall be included in -#all copies or substantial portions of the Software. +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. # -#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -#THE SOFTWARE. +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. DECIMAL = 0 TINY = 1 @@ -41,7 +41,7 @@ TIMESTAMP2 = 17 DATETIME2 = 18 TIME2 = 19 -JSON = 245 # Introduced in 5.7.8 +JSON = 245 # Introduced in 5.7.8 NEWDECIMAL = 246 ENUM = 247 SET = 248 diff --git a/pymysqlreplication/constants/STATUS_VAR_KEY.py b/pymysqlreplication/constants/STATUS_VAR_KEY.py index 18aa43ae..c8d2ac10 100644 --- a/pymysqlreplication/constants/STATUS_VAR_KEY.py +++ b/pymysqlreplication/constants/STATUS_VAR_KEY.py @@ -1,6 +1,6 @@ -#from enum import IntEnum +# from enum import IntEnum -#class StatusVarsKey(IntEnum): +# class StatusVarsKey(IntEnum): """List of Query_event_status_vars A status variable in query events is a sequence of status KEY-VALUE pairs. @@ -16,7 +16,7 @@ # KEY Q_FLAGS2_CODE = 0x00 -Q_SQL_MODE_CODE = 0X01 +Q_SQL_MODE_CODE = 0x01 Q_CATALOG_CODE = 0x02 Q_AUTO_INCREMENT = 0x03 Q_CHARSET_CODE = 0x04 @@ -27,14 +27,14 @@ Q_TABLE_MAP_FOR_UPDATE_CODE = 0x09 Q_MASTER_DATA_WRITTEN_CODE = 0x0A Q_INVOKER = 0x0B -Q_UPDATED_DB_NAMES = 0x0C # MySQL only -Q_MICROSECONDS = 0x0D # MySQL only +Q_UPDATED_DB_NAMES = 0x0C # MySQL only +Q_MICROSECONDS = 0x0D # MySQL only Q_COMMIT_TS = 0x0E -Q_COMMIT_TS2 = 0X0F -Q_EXPLICIT_DEFAULTS_FOR_TIMESTAMP = 0X10 -Q_DDL_LOGGED_WITH_XID = 0X11 -Q_DEFAULT_COLLATION_FOR_UTF8MB4 = 0X12 -Q_SQL_REQUIRE_PRIMARY_KEY = 0X13 -Q_DEFAULT_TABLE_ENCRYPTION = 0X14 +Q_COMMIT_TS2 = 0x0F +Q_EXPLICIT_DEFAULTS_FOR_TIMESTAMP = 0x10 +Q_DDL_LOGGED_WITH_XID = 0x11 +Q_DEFAULT_COLLATION_FOR_UTF8MB4 = 0x12 +Q_SQL_REQUIRE_PRIMARY_KEY = 0x13 +Q_DEFAULT_TABLE_ENCRYPTION = 0x14 Q_HRNOW = 0x80 # MariaDB only -Q_XID = 0x81 # MariaDB only \ No newline at end of file +Q_XID = 0x81 # MariaDB only diff --git a/pymysqlreplication/event.py b/pymysqlreplication/event.py index aeea07aa..fe0d64ca 100644 --- a/pymysqlreplication/event.py +++ b/pymysqlreplication/event.py @@ -12,16 +12,22 @@ class BinLogEvent(object): - def __init__(self, from_packet, event_size, table_map, ctl_connection, - mysql_version=(0,0,0), - only_tables=None, - ignored_tables=None, - only_schemas=None, - ignored_schemas=None, - freeze_schema=False, - fail_on_table_metadata_unavailable=False, - ignore_decode_errors=False, - verify_checksum=False,): + def __init__( + self, + from_packet, + event_size, + table_map, + ctl_connection, + mysql_version=(0, 0, 0), + only_tables=None, + ignored_tables=None, + only_schemas=None, + ignored_schemas=None, + freeze_schema=False, + fail_on_table_metadata_unavailable=False, + ignore_decode_errors=False, + verify_checksum=False, + ): self.packet = from_packet self.table_map = table_map self.event_type = self.packet.event_type @@ -43,7 +49,7 @@ def _read_table_id(self): # Table ID is 6 byte # pad little-endian number table_id = self.packet.read(6) + b"\x00\x00" - return struct.unpack('= (5, 7): - self.last_committed = struct.unpack(' float: """ Read real data. """ - return struct.unpack(' int: """ Read integer data. """ - fmt = ' decimal.Decimal: @@ -729,7 +784,9 @@ def _read_decimal(self, buffer: bytes) -> decimal.Decimal: self.precision = self.temp_value_buffer[0] self.decimals = self.temp_value_buffer[1] raw_decimal = self.temp_value_buffer[2:] - return self._parse_decimal_from_bytes(raw_decimal, self.precision, self.decimals) + return self._parse_decimal_from_bytes( + raw_decimal, self.precision, self.decimals + ) def _read_default(self) -> bytes: """ @@ -739,7 +796,9 @@ def _read_default(self) -> bytes: return self.packet.read(self.value_len) @staticmethod - def _parse_decimal_from_bytes(raw_decimal: bytes, precision: int, decimals: int) -> decimal.Decimal: + def _parse_decimal_from_bytes( + raw_decimal: bytes, precision: int, decimals: int + ) -> decimal.Decimal: """ Parse decimal from bytes. """ @@ -760,27 +819,31 @@ def decode_decimal_decompress_value(comp_indx, data, mask): databuff = bytearray(data[:size]) for i in range(size): databuff[i] = (databuff[i] ^ mask) & 0xFF - return size, int.from_bytes(databuff, byteorder='big') + return size, int.from_bytes(databuff, byteorder="big") return 0, 0 - pointer, value = decode_decimal_decompress_value(comp_integral, raw_decimal, mask) + pointer, value = decode_decimal_decompress_value( + comp_integral, raw_decimal, mask + ) res += str(value) for _ in range(uncomp_integral): - value = struct.unpack('>i', raw_decimal[pointer:pointer+4])[0] ^ mask - res += '%09d' % value + value = struct.unpack(">i", raw_decimal[pointer : pointer + 4])[0] ^ mask + res += "%09d" % value pointer += 4 res += "." for _ in range(uncomp_fractional): - value = struct.unpack('>i', raw_decimal[pointer:pointer+4])[0] ^ mask - res += '%09d' % value + value = struct.unpack(">i", raw_decimal[pointer : pointer + 4])[0] ^ mask + res += "%09d" % value pointer += 4 - size, value = decode_decimal_decompress_value(comp_fractional, raw_decimal[pointer:], mask) + size, value = decode_decimal_decompress_value( + comp_fractional, raw_decimal[pointer:], mask + ) if size > 0: - res += '%0*d' % (comp_fractional, value) + res += "%0*d" % (comp_fractional, value) return decimal.Decimal(res) def _dump(self) -> None: @@ -788,11 +851,15 @@ def _dump(self) -> None: print("User variable name: %s" % self.name) print("Is NULL: %s" % ("Yes" if self.is_null else "No")) if not self.is_null: - print("Type: %s" % self.type_to_codes_and_method.get(self.type, ['UNKNOWN_TYPE'])[0]) + print( + "Type: %s" + % self.type_to_codes_and_method.get(self.type, ["UNKNOWN_TYPE"])[0] + ) print("Charset: %s" % self.charset) print("Value: %s" % self.value) print("Flags: %s" % self.flags) + class MariadbStartEncryptionEvent(BinLogEvent): """ Since MariaDB 10.1.7, @@ -832,11 +899,14 @@ class RowsQueryLogEvent(BinLogEvent): :ivar query_length: uint - Length of the SQL statement :ivar query: str - The executed SQL statement """ + def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs): - super(RowsQueryLogEvent, self).__init__(from_packet, event_size, table_map, - ctl_connection, **kwargs) + super(RowsQueryLogEvent, self).__init__( + from_packet, event_size, table_map, ctl_connection, **kwargs + ) self.query_length = self.packet.read_uint8() - self.query = self.packet.read(self.query_length).decode('utf-8') + self.query = self.packet.read(self.query_length).decode("utf-8") + def dump(self): print("=== %s ===" % (self.__class__.__name__)) print("Query length: %d" % self.query_length) @@ -849,7 +919,7 @@ class NotImplementedEvent(BinLogEvent): The event referencing this class skips parsing. """ + def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs): - super().__init__( - from_packet, event_size, table_map, ctl_connection, **kwargs) + super().__init__(from_packet, event_size, table_map, ctl_connection, **kwargs) self.packet.advance(event_size) diff --git a/pymysqlreplication/exceptions.py b/pymysqlreplication/exceptions.py index 434d8d76..7d99ce82 100644 --- a/pymysqlreplication/exceptions.py +++ b/pymysqlreplication/exceptions.py @@ -1,6 +1,6 @@ class TableMetadataUnavailableError(Exception): def __init__(self, table): - Exception.__init__(self,"Unable to find metadata for table {0}".format(table)) + Exception.__init__(self, "Unable to find metadata for table {0}".format(table)) class BinLogNotEnabled(Exception): @@ -10,10 +10,13 @@ def __init__(self): class StatusVariableMismatch(Exception): def __init__(self): - Exception.__init__(self, " ".join( - "Unknown status variable in query event." - , "Possible parse failure in preceding fields" - , "or outdated constants.STATUS_VAR_KEY" - , "Refer to MySQL documentation/source code" - , "or create an issue on GitHub" - )) + Exception.__init__( + self, + " ".join( + "Unknown status variable in query event.", + "Possible parse failure in preceding fields", + "or outdated constants.STATUS_VAR_KEY", + "Refer to MySQL documentation/source code", + "or create an issue on GitHub", + ), + ) diff --git a/pymysqlreplication/gtid.py b/pymysqlreplication/gtid.py index 3b2554da..04830b0a 100644 --- a/pymysqlreplication/gtid.py +++ b/pymysqlreplication/gtid.py @@ -6,12 +6,15 @@ from copy import deepcopy from io import BytesIO + def overlap(i1, i2): return i1[0] < i2[1] and i1[1] > i2[0] + def contains(i1, i2): return i2[0] >= i1[0] and i2[1] <= i1[1] + class Gtid(object): """A mysql GTID is composed of a server-id and a set of right-open intervals [a,b), and represent all transactions x that happened on @@ -48,6 +51,7 @@ class Gtid(object): Exception: Adding an already present transaction number (one that overlaps). Exception: Adding a Gtid with a different SID. """ + @staticmethod def parse_interval(interval): """ @@ -57,12 +61,12 @@ def parse_interval(interval): Raises: - ValueError if GTID format is incorrect """ - m = re.search('^([0-9]+)(?:-([0-9]+))?$', interval) + m = re.search("^([0-9]+)(?:-([0-9]+))?$", interval) if not m: - raise ValueError('GTID format is incorrect: %r' % (interval, )) + raise ValueError("GTID format is incorrect: %r" % (interval,)) a = int(m.group(1)) b = int(m.group(2) or a) - return (a, b+1) + return (a, b + 1) @staticmethod def parse(gtid): @@ -71,16 +75,18 @@ def parse(gtid): Raises: - ValueError: if GTID format is incorrect. """ - m = re.search('^([0-9a-fA-F]{8}(?:-[0-9a-fA-F]{4}){3}-[0-9a-fA-F]{12})' - '((?::[0-9-]+)+)$', gtid) + m = re.search( + "^([0-9a-fA-F]{8}(?:-[0-9a-fA-F]{4}){3}-[0-9a-fA-F]{12})" + "((?::[0-9-]+)+)$", + gtid, + ) if not m: - raise ValueError('GTID format is incorrect: %r' % (gtid, )) + raise ValueError("GTID format is incorrect: %r" % (gtid,)) sid = m.group(1) intervals = m.group(2) - intervals_parsed = [Gtid.parse_interval(x) - for x in intervals.split(':')[1:]] + intervals_parsed = [Gtid.parse_interval(x) for x in intervals.split(":")[1:]] return (sid, intervals_parsed) @@ -95,10 +101,10 @@ def __add_interval(self, itvl): new = [] if itvl[0] > itvl[1]: - raise Exception('Malformed interval %s' % (itvl,)) + raise Exception("Malformed interval %s" % (itvl,)) if any(overlap(x, itvl) for x in self.intervals): - raise Exception('Overlapping interval %s' % (itvl,)) + raise Exception("Overlapping interval %s" % (itvl,)) ## Merge: arrange interval to fit existing set for existing in sorted(self.intervals): @@ -121,7 +127,7 @@ def __sub_interval(self, itvl): new = [] if itvl[0] > itvl[1]: - raise Exception('Malformed interval %s' % (itvl,)) + raise Exception("Malformed interval %s" % (itvl,)) if not any(overlap(x, itvl) for x in self.intervals): # No raise @@ -149,8 +155,9 @@ def __contains__(self, other): if other.sid != self.sid: return False - return all(any(contains(me, them) for me in self.intervals) - for them in other.intervals) + return all( + any(contains(me, them) for me in self.intervals) for them in other.intervals + ) def __init__(self, gtid, sid=None, intervals=[]): if sid: @@ -169,8 +176,9 @@ def __add__(self, other): Raises: Exception: if the attempted merge has different SID""" if self.sid != other.sid: - raise Exception('Attempt to merge different SID' - '%s != %s' % (self.sid, other.sid)) + raise Exception( + "Attempt to merge different SID" "%s != %s" % (self.sid, other.sid) + ) result = deepcopy(self) @@ -194,21 +202,26 @@ def __sub__(self, other): def __str__(self): """We represent the human value here - a single number for one transaction, or a closed interval (decrementing b)""" - return '%s:%s' % (self.sid, - ':'.join(('%d-%d' % (x[0], x[1]-1)) if x[0] +1 != x[1] - else str(x[0]) - for x in self.intervals)) + return "%s:%s" % ( + self.sid, + ":".join( + ("%d-%d" % (x[0], x[1] - 1)) if x[0] + 1 != x[1] else str(x[0]) + for x in self.intervals + ), + ) def __repr__(self): return '' % self @property def encoded_length(self): - return (16 + # sid - 8 + # n_intervals - 2 * # stop/start - 8 * # stop/start mark encoded as int64 - len(self.intervals)) + return ( + 16 + + 8 # sid + + 2 # n_intervals + * 8 # stop/start + * len(self.intervals) # stop/start mark encoded as int64 + ) def encode(self): """Encode a Gtid in binary @@ -236,17 +249,17 @@ def encode(self): - - - - - - - - - - - ``` """ - buffer = b'' + buffer = b"" # sid - buffer += binascii.unhexlify(self.sid.replace('-', '')) + buffer += binascii.unhexlify(self.sid.replace("-", "")) # n_intervals - buffer += struct.pack('' % self.gtids + return "" % self.gtids @property def encoded_length(self): - return (8 + # n_sids - sum(x.encoded_length for x in self.gtids)) + return 8 + sum(x.encoded_length for x in self.gtids) # n_sids def encoded(self): """Encode a GtidSet in binary @@ -415,8 +434,10 @@ def encoded(self): - - - - - - - - - - - ``` """ - return b'' + (struct.pack('b', self.read(size))[0] + return struct.unpack(">b", self.read(size))[0] elif size == 2: - return struct.unpack('>h', self.read(size))[0] + return struct.unpack(">h", self.read(size))[0] elif size == 3: return self.read_int24_be() elif size == 4: - return struct.unpack('>i', self.read(size))[0] + return struct.unpack(">i", self.read(size))[0] elif size == 5: return self.read_int40_be() elif size == 8: - return struct.unpack('>l', self.read(size))[0] + return struct.unpack(">l", self.read(size))[0] def read_uint_by_size(self, size): - '''Read a little endian integer values based on byte number''' + """Read a little endian integer values based on byte number""" if size == 1: return self.read_uint8() elif size == 2: @@ -281,7 +288,7 @@ def read_variable_length_string(self): bits_read = 0 while byte & 0x80 != 0: byte = struct.unpack("!B", self.read(1))[0] - length = length | ((byte & 0x7f) << bits_read) + length = length | ((byte & 0x7F) << bits_read) bits_read = bits_read + 7 return self.read(length) @@ -293,30 +300,30 @@ def read_int24(self): return res def read_int24_be(self): - a, b, c = struct.unpack('BBB', self.read(3)) + a, b, c = struct.unpack("BBB", self.read(3)) res = (a << 16) | (b << 8) | c if res >= 0x800000: res -= 0x1000000 return res def read_uint8(self): - return struct.unpack(' length: - raise ValueError('Json length is larger than packet length') + raise ValueError("Json length is larger than packet length") if large: - key_offset_lengths = [( - self.read_uint32(), # offset (we don't actually need that) - self.read_uint16() # size of the key - ) for _ in range(elements)] + key_offset_lengths = [ + ( + self.read_uint32(), # offset (we don't actually need that) + self.read_uint16(), # size of the key + ) + for _ in range(elements) + ] else: - key_offset_lengths = [( - self.read_uint16(), # offset (we don't actually need that) - self.read_uint16() # size of key - ) for _ in range(elements)] - - value_type_inlined_lengths = [read_offset_or_inline(self, large) - for _ in range(elements)] + key_offset_lengths = [ + ( + self.read_uint16(), # offset (we don't actually need that) + self.read_uint16(), # size of key + ) + for _ in range(elements) + ] + + value_type_inlined_lengths = [ + read_offset_or_inline(self, large) for _ in range(elements) + ] keys = [self.read(x[1]) for x in key_offset_lengths] @@ -471,11 +489,11 @@ def read_binary_json_array(self, length, large): size = self.read_uint16() if size > length: - raise ValueError('Json length is larger than packet length') + raise ValueError("Json length is larger than packet length") values_type_offset_inline = [ - read_offset_or_inline(self, large) - for _ in range(elements)] + read_offset_or_inline(self, large) for _ in range(elements) + ] def _read(x): if x[1] is None: @@ -492,14 +510,14 @@ def read_string(self): Returns: Binary string parsed from __data_buffer """ - string = b'' + string = b"" while True: char = self.read(1) - if char == b'\0': + if char == b"\0": break string += char return string def bytes_to_read(self): - return len(self.packet._data) - self.packet._position \ No newline at end of file + return len(self.packet._data) - self.packet._position diff --git a/pymysqlreplication/row_event.py b/pymysqlreplication/row_event.py index 18ff2728..34267b63 100644 --- a/pymysqlreplication/row_event.py +++ b/pymysqlreplication/row_event.py @@ -3,7 +3,6 @@ import struct import decimal import datetime -import json from pymysql.charset import charset_by_name from enum import Enum @@ -16,17 +15,17 @@ from .table import Table from .bitmap import BitCount, BitGet + class RowsEvent(BinLogEvent): def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs): - super().__init__(from_packet, event_size, table_map, - ctl_connection, **kwargs) + super().__init__(from_packet, event_size, table_map, ctl_connection, **kwargs) self.__rows = None self.__only_tables = kwargs["only_tables"] self.__ignored_tables = kwargs["ignored_tables"] self.__only_schemas = kwargs["only_schemas"] self.__ignored_schemas = kwargs["ignored_schemas"] - #Header + # Header self.table_id = self._read_table_id() # Additional information @@ -34,7 +33,7 @@ def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs) self.primary_key = table_map[self.table_id].data["primary_key"] self.schema = self.table_map[self.table_id].schema self.table = self.table_map[self.table_id].table - except KeyError: #If we have filter the corresponding TableMap Event + except KeyError: # If we have filter the corresponding TableMap Event self._processed = False return @@ -48,41 +47,52 @@ def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs) if self.__only_schemas is not None and self.schema not in self.__only_schemas: self._processed = False return - elif self.__ignored_schemas is not None and self.schema in self.__ignored_schemas: + elif ( + self.__ignored_schemas is not None and self.schema in self.__ignored_schemas + ): self._processed = False return - - #Event V2 - if self.event_type == BINLOG.WRITE_ROWS_EVENT_V2 or \ - self.event_type == BINLOG.DELETE_ROWS_EVENT_V2 or \ - self.event_type == BINLOG.UPDATE_ROWS_EVENT_V2: - self.flags, self.extra_data_length = struct.unpack(' 2: - self.extra_data_type = struct.unpack(' 2: + self.extra_data_type = struct.unpack(" 255 else self.__read_string(1, column) + elif column.type == FIELD_TYPE.VARCHAR or column.type == FIELD_TYPE.STRING: + ret = ( + self.__read_string(2, column) + if column.max_length > 255 + else self.__read_string(1, column) + ) if fixed_binary_length and len(ret) < fixed_binary_length: # Fixed-length binary fields are stored in the binlog # without trailing zeros and must be padded with zeros up # to the specified length at read time. nr_pad = fixed_binary_length - len(ret) - ret += b'\x00' * nr_pad + ret += b"\x00" * nr_pad return ret elif column.type == FIELD_TYPE.NEWDECIMAL: return self.__read_new_decimal(column) @@ -189,8 +219,7 @@ def __read_values_name(self, column, null_bitmap, null_bitmap_index, cols_bitmap elif column.type == FIELD_TYPE.DATE: return self.__read_date() elif column.type == FIELD_TYPE.TIMESTAMP: - return datetime.datetime.fromtimestamp( - self.packet.read_uint32()) + return datetime.datetime.utcfromtimestamp(self.packet.read_uint32()) # For new date format: elif column.type == FIELD_TYPE.DATETIME2: @@ -199,40 +228,42 @@ def __read_values_name(self, column, null_bitmap, null_bitmap_index, cols_bitmap return self.__read_time2(column) elif column.type == FIELD_TYPE.TIMESTAMP2: return self.__add_fsp_to_time( - datetime.datetime.fromtimestamp( - self.packet.read_int_be_by_size(4)), column) + datetime.datetime.utcfromtimestamp(self.packet.read_int_be_by_size(4)), + column, + ) elif column.type == FIELD_TYPE.LONGLONG: if unsigned: ret = self.packet.read_uint64() if zerofill: - ret = format(ret, '020d') + ret = format(ret, "020d") return ret else: return self.packet.read_int64() elif column.type == FIELD_TYPE.YEAR: return self.packet.read_uint8() + 1900 elif column.type == FIELD_TYPE.ENUM: - return column.enum_values[ - self.packet.read_uint_by_size(column.size)] + return column.enum_values[self.packet.read_uint_by_size(column.size)] elif column.type == FIELD_TYPE.SET: # We read set columns as a bitmap telling us which options # are enabled bit_mask = self.packet.read_uint_by_size(column.size) - return set( - val for idx, val in enumerate(column.set_values) - if bit_mask & 2 ** idx - ) or None + return ( + set( + val + for idx, val in enumerate(column.set_values) + if bit_mask & 2**idx + ) + or None + ) elif column.type == FIELD_TYPE.BIT: return self.__read_bit(column) elif column.type == FIELD_TYPE.GEOMETRY: - return self.packet.read_length_coded_pascal_string( - column.length_size) + return self.packet.read_length_coded_pascal_string(column.length_size) elif column.type == FIELD_TYPE.JSON: return self.packet.read_binary_json(column.length_size) else: - raise NotImplementedError("Unknown MySQL column type: %d" % - (column.type)) + raise NotImplementedError("Unknown MySQL column type: %d" % (column.type)) def __add_fsp_to_time(self, time, column): """Read and add the fractional part of time @@ -256,7 +287,7 @@ def __read_fsp(self, column): microsecond = self.packet.read_int_be_by_size(read) if column.fsp % 2: microsecond = int(microsecond / 10) - return microsecond * (10 ** (6-column.fsp)) + return microsecond * (10 ** (6 - column.fsp)) return 0 @staticmethod @@ -300,7 +331,8 @@ def __read_time(self): date = datetime.timedelta( hours=int(time / 10000), minutes=int((time % 10000) / 100), - seconds=int(time % 100)) + seconds=int(time % 100), + ) return date def __read_time2(self, column): @@ -322,12 +354,15 @@ def __read_time2(self, column): # hence take 2's compliment again to get the right value. data = ~data + 1 - t = datetime.timedelta( - hours=self.__read_binary_slice(data, 2, 10, 24), - minutes=self.__read_binary_slice(data, 12, 6, 24), - seconds=self.__read_binary_slice(data, 18, 6, 24), - microseconds=self.__read_fsp(column) - ) * sign + t = ( + datetime.timedelta( + hours=self.__read_binary_slice(data, 2, 10, 24), + minutes=self.__read_binary_slice(data, 12, 6, 24), + seconds=self.__read_binary_slice(data, 18, 6, 24), + microseconds=self.__read_fsp(column), + ) + * sign + ) return t def __read_date(self): @@ -337,15 +372,11 @@ def __read_date(self): year = (time & ((1 << 15) - 1) << 9) >> 9 month = (time & ((1 << 4) - 1) << 5) >> 5 - day = (time & ((1 << 5) - 1)) + day = time & ((1 << 5) - 1) if year == 0 or month == 0 or day == 0: return None - date = datetime.date( - year=year, - month=month, - day=day - ) + date = datetime.date(year=year, month=month, day=day) return date def __read_datetime(self): @@ -368,7 +399,8 @@ def __read_datetime(self): day=day, hour=int(time / 10000), minute=int((time % 10000) / 100), - second=int(time % 100)) + second=int(time % 100), + ) return date def __read_datetime2(self, column): @@ -392,7 +424,8 @@ def __read_datetime2(self, column): day=self.__read_binary_slice(data, 18, 5, 40), hour=self.__read_binary_slice(data, 23, 5, 40), minute=self.__read_binary_slice(data, 28, 6, 40), - second=self.__read_binary_slice(data, 34, 6, 40)) + second=self.__read_binary_slice(data, 34, 6, 40), + ) except ValueError: self.__read_fsp(column) return None @@ -407,12 +440,11 @@ def __read_new_decimal(self, column): digits_per_integer = 9 compressed_bytes = [0, 1, 1, 2, 2, 3, 3, 4, 4, 4] - integral = (column.precision - column.decimals) + integral = column.precision - column.decimals uncomp_integral = int(integral / digits_per_integer) uncomp_fractional = int(column.decimals / digits_per_integer) comp_integral = integral - (uncomp_integral * digits_per_integer) - comp_fractional = column.decimals - (uncomp_fractional - * digits_per_integer) + comp_fractional = column.decimals - (uncomp_fractional * digits_per_integer) # Support negative # The sign is encoded in the high bit of the the byte @@ -424,7 +456,7 @@ def __read_new_decimal(self, column): else: mask = -1 res = "-" - self.packet.unread(struct.pack(' 0: @@ -432,19 +464,19 @@ def __read_new_decimal(self, column): res += str(value) for i in range(0, uncomp_integral): - value = struct.unpack('>i', self.packet.read(4))[0] ^ mask - res += '%09d' % value + value = struct.unpack(">i", self.packet.read(4))[0] ^ mask + res += "%09d" % value res += "." for i in range(0, uncomp_fractional): - value = struct.unpack('>i', self.packet.read(4))[0] ^ mask - res += '%09d' % value + value = struct.unpack(">i", self.packet.read(4))[0] ^ mask + res += "%09d" % value size = compressed_bytes[comp_fractional] if size > 0: value = self.packet.read_int_be_by_size(size) ^ mask - res += '%0*d' % (comp_fractional, value) + res += "%0*d" % (comp_fractional, value) return decimal.Decimal(res) @@ -457,7 +489,7 @@ def __read_binary_slice(self, binary, start, size, data_length): data_length: data size """ binary = binary >> data_length - (start + size) - mask = ((1 << size) - 1) + mask = (1 << size) - 1 return binary & mask def _dump(self): @@ -489,11 +521,11 @@ class DeleteRowsEvent(RowsEvent): """ def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs): - super().__init__(from_packet, event_size, - table_map, ctl_connection, **kwargs) + super().__init__(from_packet, event_size, table_map, ctl_connection, **kwargs) if self._processed: self.columns_present_bitmap = self.packet.read( - (self.number_of_columns + 7) / 8) + (self.number_of_columns + 7) / 8 + ) def _fetch_one_row(self): row = {} @@ -517,11 +549,11 @@ class WriteRowsEvent(RowsEvent): """ def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs): - super().__init__(from_packet, event_size, - table_map, ctl_connection, **kwargs) + super().__init__(from_packet, event_size, table_map, ctl_connection, **kwargs) if self._processed: self.columns_present_bitmap = self.packet.read( - (self.number_of_columns + 7) / 8) + (self.number_of_columns + 7) / 8 + ) def _fetch_one_row(self): row = {} @@ -550,14 +582,15 @@ class UpdateRowsEvent(RowsEvent): """ def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs): - super().__init__(from_packet, event_size, - table_map, ctl_connection, **kwargs) + super().__init__(from_packet, event_size, table_map, ctl_connection, **kwargs) if self._processed: # Body self.columns_present_bitmap = self.packet.read( - (self.number_of_columns + 7) / 8) + (self.number_of_columns + 7) / 8 + ) self.columns_present_bitmap2 = self.packet.read( - (self.number_of_columns + 7) / 8) + (self.number_of_columns + 7) / 8 + ) def _fetch_one_row(self): row = {} @@ -574,9 +607,11 @@ def _dump(self): for row in self.rows: print("--") for key in row["before_values"]: - print("*%s:%s=>%s" % (key, - row["before_values"][key], - row["after_values"][key])) + print( + "*%s:%s=>%s" + % (key, row["before_values"][key], row["after_values"][key]) + ) + class OptionalMetaData: def __init__(self): @@ -613,6 +648,7 @@ def dump(self): print("charset_collation_list: %s" % self.charset_collation_list) print("enum_and_set_collation_list: %s" % self.enum_and_set_collation_list) + class TableMapEvent(BinLogEvent): """This event describes the structure of a table. It's sent before a change happens on a table. @@ -620,8 +656,7 @@ class TableMapEvent(BinLogEvent): """ def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs): - super().__init__(from_packet, event_size, - table_map, ctl_connection, **kwargs) + super().__init__(from_packet, event_size, table_map, ctl_connection, **kwargs) self.__only_tables = kwargs["only_tables"] self.__ignored_tables = kwargs["ignored_tables"] self.__only_schemas = kwargs["only_schemas"] @@ -635,7 +670,7 @@ def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs) self._processed = False return - self.flags = struct.unpack(' BINLOG.BINLOG_CHECKSUM_LEN: option_metadata_type = self.packet.read(1)[0] length = self.packet.read_length_coded_binary() - field_type: MetadataFieldType = MetadataFieldType.by_index(option_metadata_type) + field_type: MetadataFieldType = MetadataFieldType.by_index( + option_metadata_type + ) if field_type == MetadataFieldType.SIGNEDNESS: signed_column_list = self._convert_include_non_numeric_column( - self._read_bool_list(length, True)) + self._read_bool_list(length, True) + ) optional_metadata.unsigned_column_list = signed_column_list elif field_type == MetadataFieldType.DEFAULT_CHARSET: - optional_metadata.default_charset_collation, optional_metadata.charset_collation = self._read_default_charset( - length) - optional_metadata.charset_collation_list = self._parsed_column_charset_by_default_charset( + ( optional_metadata.default_charset_collation, optional_metadata.charset_collation, - self._is_character_column) + ) = self._read_default_charset(length) + optional_metadata.charset_collation_list = ( + self._parsed_column_charset_by_default_charset( + optional_metadata.default_charset_collation, + optional_metadata.charset_collation, + self._is_character_column, + ) + ) elif field_type == MetadataFieldType.COLUMN_CHARSET: optional_metadata.column_charset = self._read_ints(length) - optional_metadata.charset_collation_list = self._parsed_column_charset_by_column_charset( - optional_metadata.column_charset, self._is_character_column) + optional_metadata.charset_collation_list = ( + self._parsed_column_charset_by_column_charset( + optional_metadata.column_charset, self._is_character_column + ) + ) elif field_type == MetadataFieldType.COLUMN_NAME: optional_metadata.column_name_list = self._read_column_names(length) @@ -765,7 +816,9 @@ def _get_optional_meta_data(self): optional_metadata.set_str_value_list = self._read_type_values(length) elif field_type == MetadataFieldType.ENUM_STR_VALUE: - optional_metadata.set_enum_str_value_list = self._read_type_values(length) + optional_metadata.set_enum_str_value_list = self._read_type_values( + length + ) elif field_type == MetadataFieldType.GEOMETRY_TYPE: optional_metadata.geometry_type_list = self._read_ints(length) @@ -774,22 +827,35 @@ def _get_optional_meta_data(self): optional_metadata.simple_primary_key_list = self._read_ints(length) elif field_type == MetadataFieldType.PRIMARY_KEY_WITH_PREFIX: - optional_metadata.primary_keys_with_prefix = self._read_primary_keys_with_prefix(length) + optional_metadata.primary_keys_with_prefix = ( + self._read_primary_keys_with_prefix(length) + ) elif field_type == MetadataFieldType.ENUM_AND_SET_DEFAULT_CHARSET: - optional_metadata.enum_and_set_default_charset, optional_metadata.enum_and_set_charset_collation = self._read_default_charset( - length) - - optional_metadata.enum_and_set_collation_list = self._parsed_column_charset_by_default_charset( + ( optional_metadata.enum_and_set_default_charset, optional_metadata.enum_and_set_charset_collation, - self._is_enum_or_set_column) + ) = self._read_default_charset(length) + + optional_metadata.enum_and_set_collation_list = ( + self._parsed_column_charset_by_default_charset( + optional_metadata.enum_and_set_default_charset, + optional_metadata.enum_and_set_charset_collation, + self._is_enum_or_set_column, + ) + ) elif field_type == MetadataFieldType.ENUM_AND_SET_COLUMN_CHARSET: - optional_metadata.enum_and_set_default_column_charset_list = self._read_ints(length) + optional_metadata.enum_and_set_default_column_charset_list = ( + self._read_ints(length) + ) - optional_metadata.enum_and_set_collation_list = self._parsed_column_charset_by_column_charset( - optional_metadata.enum_and_set_default_column_charset_list, self._is_enum_or_set_column) + optional_metadata.enum_and_set_collation_list = ( + self._parsed_column_charset_by_column_charset( + optional_metadata.enum_and_set_default_column_charset_list, + self._is_enum_or_set_column, + ) + ) elif field_type == MetadataFieldType.VISIBILITY: optional_metadata.visibility_list = self._read_bool_list(length, False) @@ -814,8 +880,12 @@ def _convert_include_non_numeric_column(self, signedness_bool_list): return bool_list - def _parsed_column_charset_by_default_charset(self, default_charset_collation: int, column_charset_collation: dict, - column_type_detect_function): + def _parsed_column_charset_by_default_charset( + self, + default_charset_collation: int, + column_charset_collation: dict, + column_type_detect_function, + ): column_charset = [] for i in range(self.column_count): column_type = self.columns[i].type @@ -828,7 +898,9 @@ def _parsed_column_charset_by_default_charset(self, default_charset_collation: i return column_charset - def _parsed_column_charset_by_column_charset(self, column_charset_list: list, column_type_detect_function): + def _parsed_column_charset_by_column_charset( + self, column_charset_list: list, column_type_detect_function + ): column_charset = [] position = 0 if len(column_charset_list) == 0: @@ -920,10 +992,15 @@ def _read_primary_keys_with_prefix(self, length): return result @staticmethod - def _is_character_column(column_type, dbms='mysql'): - if column_type in [FIELD_TYPE.STRING, FIELD_TYPE.VAR_STRING, FIELD_TYPE.VARCHAR, FIELD_TYPE.BLOB]: + def _is_character_column(column_type, dbms="mysql"): + if column_type in [ + FIELD_TYPE.STRING, + FIELD_TYPE.VAR_STRING, + FIELD_TYPE.VARCHAR, + FIELD_TYPE.BLOB, + ]: return True - if column_type == FIELD_TYPE.GEOMETRY and dbms == 'mariadb': + if column_type == FIELD_TYPE.GEOMETRY and dbms == "mariadb": return True return False @@ -940,20 +1017,28 @@ def _is_set_column(column_type): return False @staticmethod - def _is_enum_or_set_column(column_type, dbms='mysql'): + def _is_enum_or_set_column(column_type, dbms="mysql"): if column_type in [FIELD_TYPE.ENUM, FIELD_TYPE.SET]: return True return False @staticmethod def _is_numeric_column(column_type): - if column_type in [FIELD_TYPE.TINY, FIELD_TYPE.SHORT, FIELD_TYPE.INT24, FIELD_TYPE.LONG, - FIELD_TYPE.LONGLONG, FIELD_TYPE.NEWDECIMAL, FIELD_TYPE.FLOAT, - FIELD_TYPE.DOUBLE, - FIELD_TYPE.YEAR]: + if column_type in [ + FIELD_TYPE.TINY, + FIELD_TYPE.SHORT, + FIELD_TYPE.INT24, + FIELD_TYPE.LONG, + FIELD_TYPE.LONGLONG, + FIELD_TYPE.NEWDECIMAL, + FIELD_TYPE.FLOAT, + FIELD_TYPE.DOUBLE, + FIELD_TYPE.YEAR, + ]: return True return False + class MetadataFieldType(Enum): SIGNEDNESS = 1 # Signedness of numeric columns DEFAULT_CHARSET = 2 # Charsets of character columns diff --git a/pymysqlreplication/table.py b/pymysqlreplication/table.py index ed473cd6..a0fb025e 100644 --- a/pymysqlreplication/table.py +++ b/pymysqlreplication/table.py @@ -2,28 +2,32 @@ class Table(object): - def __init__(self, column_schemas, table_id, schema, table, columns, primary_key=None): + def __init__( + self, column_schemas, table_id, schema, table, columns, primary_key=None + ): if primary_key is None: primary_key = [c.data["name"] for c in columns if c.data["is_primary"]] if len(primary_key) == 0: - primary_key = '' + primary_key = "" elif len(primary_key) == 1: - primary_key, = primary_key + (primary_key,) = primary_key else: primary_key = tuple(primary_key) - self.__dict__.update({ - "column_schemas": column_schemas, - "table_id": table_id, - "schema": schema, - "table": table, - "columns": columns, - "primary_key": primary_key - }) + self.__dict__.update( + { + "column_schemas": column_schemas, + "table_id": table_id, + "schema": schema, + "table": table, + "columns": columns, + "primary_key": primary_key, + } + ) @property def data(self): - return dict((k, v) for (k, v) in self.__dict__.items() if not k.startswith('_')) + return dict((k, v) for (k, v) in self.__dict__.items() if not k.startswith("_")) def __eq__(self, other): return self.data == other.data diff --git a/pymysqlreplication/tests/base.py b/pymysqlreplication/tests/base.py index 88acda33..f2ddbcef 100644 --- a/pymysqlreplication/tests/base.py +++ b/pymysqlreplication/tests/base.py @@ -27,7 +27,7 @@ def setUp(self, charset="utf8"): "port": 3306, "use_unicode": True, "charset": charset, - "db": "pymysqlreplication_test" + "db": "pymysqlreplication_test", } self.conn_control = None @@ -47,20 +47,20 @@ def getMySQLVersion(self): """Return the MySQL version of the server If version is 5.6.10-log the result is 5.6.10 """ - return self.execute("SELECT VERSION()").fetchone()[0].split('-')[0] + return self.execute("SELECT VERSION()").fetchone()[0].split("-")[0] def isMySQL56AndMore(self): - version = float(self.getMySQLVersion().rsplit('.', 1)[0]) + version = float(self.getMySQLVersion().rsplit(".", 1)[0]) if version >= 5.6: return True return False def isMySQL57(self): - version = float(self.getMySQLVersion().rsplit('.', 1)[0]) + version = float(self.getMySQLVersion().rsplit(".", 1)[0]) return version == 5.7 def isMySQL80AndMore(self): - version = float(self.getMySQLVersion().rsplit('.', 1)[0]) + version = float(self.getMySQLVersion().rsplit(".", 1)[0]) return version >= 8.0 def isMySQL8014AndMore(self): @@ -72,7 +72,9 @@ def isMySQL8014AndMore(self): def isMariaDB(self): if self.__is_mariaDB is None: - self.__is_mariaDB = "MariaDB" in self.execute("SELECT VERSION()").fetchone()[0] + self.__is_mariaDB = ( + "MariaDB" in self.execute("SELECT VERSION()").fetchone()[0] + ) return self.__is_mariaDB @property @@ -96,7 +98,7 @@ def execute(self, query): c = self.conn_control.cursor() c.execute(query) return c - + def execute_with_args(self, query, args): c = self.conn_control.cursor() c.execute(query, args) @@ -106,12 +108,13 @@ def resetBinLog(self): self.execute("RESET MASTER") if self.stream is not None: self.stream.close() - self.stream = BinLogStreamReader(self.database, server_id=1024, - ignored_events=self.ignoredEvents()) + self.stream = BinLogStreamReader( + self.database, server_id=1024, ignored_events=self.ignoredEvents() + ) def set_sql_mode(self): """set sql_mode to test with same sql_mode (mysql 5.7 sql_mode default is changed)""" - version = float(self.getMySQLVersion().rsplit('.', 1)[0]) + version = float(self.getMySQLVersion().rsplit(".", 1)[0]) if version == 5.7: self.execute("SET @@sql_mode='NO_ENGINE_SUBSTITUTION'") @@ -122,7 +125,7 @@ def bin_log_format(self): return result[0] def bin_log_basename(self): - cursor = self.execute('SELECT @@log_bin_basename') + cursor = self.execute("SELECT @@log_bin_basename") bin_log_basename = cursor.fetchone()[0] bin_log_basename = bin_log_basename.split("/")[-1] return bin_log_basename @@ -138,7 +141,7 @@ def setUp(self): "port": int(os.environ.get("MARIADB_10_6_PORT") or 3308), "use_unicode": True, "charset": "utf8", - "db": "pymysqlreplication_test" + "db": "pymysqlreplication_test", } self.conn_control = None @@ -151,9 +154,9 @@ def setUp(self): self.connect_conn_control(db) self.stream = None self.resetBinLog() - + def bin_log_basename(self): - cursor = self.execute('SELECT @@log_bin_basename') + cursor = self.execute("SELECT @@log_bin_basename") bin_log_basename = cursor.fetchone()[0] bin_log_basename = bin_log_basename.split("/")[-1] return bin_log_basename diff --git a/pymysqlreplication/tests/benchmark.py b/pymysqlreplication/tests/benchmark.py index c947d116..36b5dbfd 100644 --- a/pymysqlreplication/tests/benchmark.py +++ b/pymysqlreplication/tests/benchmark.py @@ -7,11 +7,9 @@ import pymysql import time -import random import os from pymysqlreplication import BinLogStreamReader from pymysqlreplication.row_event import * -import cProfile def execute(con, query): @@ -19,19 +17,22 @@ def execute(con, query): c.execute(query) return c + def consume_events(): - stream = BinLogStreamReader(connection_settings=database, - server_id=3, - resume_stream=False, - blocking=True, - only_events = [UpdateRowsEvent], - only_tables = ['test'] ) + stream = BinLogStreamReader( + connection_settings=database, + server_id=3, + resume_stream=False, + blocking=True, + only_events=[UpdateRowsEvent], + only_tables=["test"], + ) start = time.clock() i = 0.0 for binlogevent in stream: - i += 1.0 - if i % 1000 == 0: - print("%d event by seconds (%d total)" % (i / (time.clock() - start), i)) + i += 1.0 + if i % 1000 == 0: + print("%d event by seconds (%d total)" % (i / (time.clock() - start), i)) stream.close() @@ -41,7 +42,7 @@ def consume_events(): "passwd": "", "use_unicode": True, "charset": "utf8", - "db": "pymysqlreplication_test" + "db": "pymysqlreplication_test", } conn = pymysql.connect(**database) @@ -63,5 +64,4 @@ def consume_events(): execute(conn, "UPDATE test2 SET i = i + 1;") else: consume_events() - #cProfile.run('consume_events()') - + # cProfile.run('consume_events()') diff --git a/pymysqlreplication/tests/binlogfilereader.py b/pymysqlreplication/tests/binlogfilereader.py index 7075039e..6dabbeb2 100644 --- a/pymysqlreplication/tests/binlogfilereader.py +++ b/pymysqlreplication/tests/binlogfilereader.py @@ -1,4 +1,4 @@ -'''Read binlog files''' +"""Read binlog files""" import struct from pymysqlreplication import constants @@ -9,10 +9,11 @@ from pymysqlreplication.row_event import TableMapEvent from pymysqlreplication.row_event import WriteRowsEvent + class SimpleBinLogFileReader(object): - '''Read binlog files''' + """Read binlog files""" - _expected_magic = b'\xfebin' + _expected_magic = b"\xfebin" def __init__(self, file_path, only_events=None): self._current_event = None @@ -22,7 +23,7 @@ def __init__(self, file_path, only_events=None): self._pos = None def fetchone(self): - '''Fetch one record from the binlog file''' + """Fetch one record from the binlog file""" if self._pos is None or self._pos < 4: self._read_magic() while True: @@ -34,12 +35,12 @@ def fetchone(self): return event def truncatebinlog(self): - '''Truncate the binlog file at the current event''' + """Truncate the binlog file at the current event""" if self._current_event is not None: self._file.truncate(self._current_event.pos) def _filter_events(self, event): - '''Return True if an event can be returned''' + """Return True if an event can be returned""" # It would be good if we could reuse the __event_map in # packet.BinLogPacketWrapper. event_type = { @@ -53,14 +54,14 @@ def _filter_events(self, event): return event_type in self._only_events def _open_file(self): - '''Open the file at ``self._file_path``''' + """Open the file at ``self._file_path``""" if self._file is None: - self._file = open(self._file_path, 'rb+') + self._file = open(self._file_path, "rb+") self._pos = self._file.tell() assert self._pos == 0 def _read_event(self): - '''Read an event from the binlog file''' + """Read an event from the binlog file""" # Assuming a binlog version > 1 headerlength = 19 header = self._file.read(headerlength) @@ -71,7 +72,7 @@ def _read_event(self): event = SimpleBinLogEvent(header) event.set_pos(event_pos) if event.event_size < headerlength: - messagefmt = 'Event size {0} is too small' + messagefmt = "Event size {0} is too small" message = messagefmt.format(event.event_size) raise EventSizeTooSmallError(message) else: @@ -81,14 +82,14 @@ def _read_event(self): return event def _read_magic(self): - '''Read the first four *magic* bytes of the binlog file''' + """Read the first four *magic* bytes of the binlog file""" self._open_file() if self._pos == 0: magic = self._file.read(4) if magic == self._expected_magic: self._pos += len(magic) else: - messagefmt = 'Magic bytes {0!r} did not match expected {1!r}' + messagefmt = "Magic bytes {0!r} did not match expected {1!r}" message = messagefmt.format(magic, self._expected_magic) raise BadMagicBytesError(message) @@ -100,17 +101,17 @@ def __repr__(self): mod = cls.__module__ name = cls.__name__ only = [type(x).__name__ for x in self._only_events] - fmt = '<{mod}.{name}(file_path={fpath}, only_events={only})>' + fmt = "<{mod}.{name}(file_path={fpath}, only_events={only})>" return fmt.format(mod=mod, name=name, fpath=self._file_path, only=only) # pylint: disable=too-many-instance-attributes class SimpleBinLogEvent(object): - '''An event from a binlog file''' + """An event from a binlog file""" def __init__(self, header): - '''Initialize the Event with the event header''' - unpacked = struct.unpack(' U(0110111) modified_byte = b"U" - wrong_event_data = correct_event_data[:1] + modified_byte + correct_event_data[2:] + wrong_event_data = ( + correct_event_data[:1] + modified_byte + correct_event_data[2:] + ) packet = MysqlPacket(correct_event_data, 0) wrong_packet = MysqlPacket(wrong_event_data, 0) @@ -593,7 +614,7 @@ def test_insert_multiple_row_event(self): self.assertIsInstance(self.stream.fetchone(), RotateEvent) self.assertIsInstance(self.stream.fetchone(), FormatDescriptionEvent) - #QueryEvent for the BEGIN + # QueryEvent for the BEGIN self.assertIsInstance(self.stream.fetchone(), QueryEvent) self.assertIsInstance(self.stream.fetchone(), TableMapEvent) @@ -627,7 +648,7 @@ def test_update_multiple_row_event(self): self.assertIsInstance(self.stream.fetchone(), RotateEvent) self.assertIsInstance(self.stream.fetchone(), FormatDescriptionEvent) - #QueryEvent for the BEGIN + # QueryEvent for the BEGIN self.assertIsInstance(self.stream.fetchone(), QueryEvent) self.assertIsInstance(self.stream.fetchone(), TableMapEvent) @@ -666,7 +687,7 @@ def test_delete_multiple_row_event(self): self.assertIsInstance(self.stream.fetchone(), RotateEvent) self.assertIsInstance(self.stream.fetchone(), FormatDescriptionEvent) - #QueryEvent for the BEGIN + # QueryEvent for the BEGIN self.assertIsInstance(self.stream.fetchone(), QueryEvent) self.assertIsInstance(self.stream.fetchone(), TableMapEvent) @@ -690,14 +711,14 @@ def test_drop_table(self): self.execute("DROP TABLE test") self.execute("COMMIT") - #RotateEvent + # RotateEvent self.stream.fetchone() - #FormatDescription + # FormatDescription self.stream.fetchone() - #QueryEvent for the Create Table + # QueryEvent for the Create Table self.stream.fetchone() - #QueryEvent for the BEGIN + # QueryEvent for the BEGIN self.stream.fetchone() event = self.stream.fetchone() @@ -723,11 +744,11 @@ def test_drop_table_tablemetadata_unavailable(self): self.database, server_id=1024, only_events=(WriteRowsEvent,), - fail_on_table_metadata_unavailable=True + fail_on_table_metadata_unavailable=True, ) had_error = False try: - event = self.stream.fetchone() + self.stream.fetchone() except TableMetadataUnavailableError as e: had_error = True assert "test" in e.args[0] @@ -736,10 +757,14 @@ def test_drop_table_tablemetadata_unavailable(self): assert had_error def test_ignore_decode_errors(self): - problematic_unicode_string = b'[{"text":"\xed\xa0\xbd \xed\xb1\x8d Some string"}]' + problematic_unicode_string = ( + b'[{"text":"\xed\xa0\xbd \xed\xb1\x8d Some string"}]' + ) self.stream.close() self.execute("CREATE TABLE test (data VARCHAR(50) CHARACTER SET utf8mb4)") - self.execute_with_args("INSERT INTO test (data) VALUES (%s)", (problematic_unicode_string)) + self.execute_with_args( + "INSERT INTO test (data) VALUES (%s)", (problematic_unicode_string) + ) self.execute("COMMIT") # Initialize with ignore_decode_errors=False @@ -747,11 +772,11 @@ def test_ignore_decode_errors(self): self.database, server_id=1024, only_events=(WriteRowsEvent,), - ignore_decode_errors=False + ignore_decode_errors=False, ) event = self.stream.fetchone() event = self.stream.fetchone() - with self.assertRaises(UnicodeError) as exception: + with self.assertRaises(UnicodeError): event = self.stream.fetchone() data = event.rows[0]["values"]["data"] @@ -760,7 +785,7 @@ def test_ignore_decode_errors(self): self.database, server_id=1024, only_events=(WriteRowsEvent,), - ignore_decode_errors=True + ignore_decode_errors=True, ) self.stream.fetchone() self.stream.fetchone() @@ -778,10 +803,8 @@ def test_drop_column(self): self.execute("COMMIT") self.stream = BinLogStreamReader( - self.database, - server_id=1024, - only_events=(WriteRowsEvent,) - ) + self.database, server_id=1024, only_events=(WriteRowsEvent,) + ) try: self.stream.fetchone() # insert with two values self.stream.fetchone() # insert with one value @@ -793,19 +816,25 @@ def test_drop_column(self): @unittest.expectedFailure def test_alter_column(self): self.stream.close() - self.execute("CREATE TABLE test_alter_column (id INTEGER(11), data VARCHAR(50))") + self.execute( + "CREATE TABLE test_alter_column (id INTEGER(11), data VARCHAR(50))" + ) self.execute("INSERT INTO test_alter_column VALUES (1, 'A value')") self.execute("COMMIT") # this is a problem only when column is added in position other than at the end - self.execute("ALTER TABLE test_alter_column ADD COLUMN another_data VARCHAR(50) AFTER id") - self.execute("INSERT INTO test_alter_column VALUES (2, 'Another value', 'A value')") + self.execute( + "ALTER TABLE test_alter_column ADD COLUMN another_data VARCHAR(50) AFTER id" + ) + self.execute( + "INSERT INTO test_alter_column VALUES (2, 'Another value', 'A value')" + ) self.execute("COMMIT") self.stream = BinLogStreamReader( self.database, server_id=1024, only_events=(WriteRowsEvent,), - ) + ) event = self.stream.fetchone() # insert with two values # both of these asserts fail because of issue underlying proble described in issue #118 # because it got table schema info after the alter table, it wrongly assumes the second @@ -814,12 +843,11 @@ def test_alter_column(self): # AR: {'id': 1, 'another_data': 'A value'} self.assertIn("data", event.rows[0]["values"]) self.assertNot("another_data", event.rows[0]["values"]) - self.assertEqual(event.rows[0]["values"]["data"], 'A value') + self.assertEqual(event.rows[0]["values"]["data"], "A value") self.stream.fetchone() # insert with three values class TestCTLConnectionSettings(base.PyMySQLReplicationTestCase): - def setUp(self): super().setUp() self.stream.close() @@ -828,8 +856,12 @@ def setUp(self): ctl_db["port"] = int(os.environ.get("MYSQL_5_7_CTL_PORT") or 3307) ctl_db["host"] = os.environ.get("MYSQL_5_7_CTL") or "localhost" self.ctl_conn_control = pymysql.connect(**ctl_db) - self.ctl_conn_control.cursor().execute("DROP DATABASE IF EXISTS pymysqlreplication_test") - self.ctl_conn_control.cursor().execute("CREATE DATABASE pymysqlreplication_test") + self.ctl_conn_control.cursor().execute( + "DROP DATABASE IF EXISTS pymysqlreplication_test" + ) + self.ctl_conn_control.cursor().execute( + "CREATE DATABASE pymysqlreplication_test" + ) self.ctl_conn_control.close() ctl_db["db"] = "pymysqlreplication_test" self.ctl_conn_control = pymysql.connect(**ctl_db) @@ -838,7 +870,7 @@ def setUp(self): ctl_connection_settings=ctl_db, server_id=1024, only_events=(WriteRowsEvent,), - fail_on_table_metadata_unavailable=True + fail_on_table_metadata_unavailable=True, ) def tearDown(self): @@ -852,7 +884,7 @@ def test_separate_ctl_settings_table_metadata_unavailable(self): had_error = False try: - event = self.stream.fetchone() + self.stream.fetchone() except TableMetadataUnavailableError as e: had_error = True assert "test" in e.args[0] @@ -880,7 +912,9 @@ class TestGtidBinLogStreamReader(base.PyMySQLReplicationTestCase): def setUp(self): super().setUp() if not self.supportsGTID: - raise unittest.SkipTest("database does not support GTID, skipping GTID tests") + raise unittest.SkipTest( + "database does not support GTID, skipping GTID tests" + ) def test_read_query_event(self): query = "CREATE TABLE test (id INT NOT NULL, data VARCHAR (50) NOT NULL, PRIMARY KEY (id))" @@ -890,8 +924,12 @@ def test_read_query_event(self): self.stream.close() self.stream = BinLogStreamReader( - self.database, server_id=1024, blocking=True, auto_position=gtid, - ignored_events=[HeartbeatLogEvent]) + self.database, + server_id=1024, + blocking=True, + auto_position=gtid, + ignored_events=[HeartbeatLogEvent], + ) self.assertIsInstance(self.stream.fetchone(), RotateEvent) self.assertIsInstance(self.stream.fetchone(), FormatDescriptionEvent) @@ -949,8 +987,12 @@ def test_position_gtid(self): self.stream.close() self.stream = BinLogStreamReader( - self.database, server_id=1024, blocking=True, auto_position=gtid, - ignored_events=[HeartbeatLogEvent]) + self.database, + server_id=1024, + blocking=True, + auto_position=gtid, + ignored_events=[HeartbeatLogEvent], + ) self.assertIsInstance(self.stream.fetchone(), RotateEvent) self.assertIsInstance(self.stream.fetchone(), FormatDescriptionEvent) @@ -958,29 +1000,40 @@ def test_position_gtid(self): self.assertIsInstance(self.stream.fetchone(), GtidEvent) event = self.stream.fetchone() - self.assertEqual(event.query, 'CREATE TABLE test2 (id INT NOT NULL, data VARCHAR (50) NOT NULL, PRIMARY KEY (id))'); + self.assertEqual( + event.query, + "CREATE TABLE test2 (id INT NOT NULL, data VARCHAR (50) NOT NULL, PRIMARY KEY (id))", + ) class TestGtidRepresentation(unittest.TestCase): def test_gtidset_representation(self): - set_repr = '57b70f4e-20d3-11e5-a393-4a63946f7eac:1-56,' \ - '4350f323-7565-4e59-8763-4b1b83a0ce0e:1-20' + set_repr = ( + "57b70f4e-20d3-11e5-a393-4a63946f7eac:1-56," + "4350f323-7565-4e59-8763-4b1b83a0ce0e:1-20" + ) myset = GtidSet(set_repr) self.assertEqual(str(myset), set_repr) def test_gtidset_representation_newline(self): - set_repr = '57b70f4e-20d3-11e5-a393-4a63946f7eac:1-56,' \ - '4350f323-7565-4e59-8763-4b1b83a0ce0e:1-20' - mysql_repr = '57b70f4e-20d3-11e5-a393-4a63946f7eac:1-56,\n' \ - '4350f323-7565-4e59-8763-4b1b83a0ce0e:1-20' + set_repr = ( + "57b70f4e-20d3-11e5-a393-4a63946f7eac:1-56," + "4350f323-7565-4e59-8763-4b1b83a0ce0e:1-20" + ) + mysql_repr = ( + "57b70f4e-20d3-11e5-a393-4a63946f7eac:1-56,\n" + "4350f323-7565-4e59-8763-4b1b83a0ce0e:1-20" + ) myset = GtidSet(mysql_repr) self.assertEqual(str(myset), set_repr) def test_gtidset_representation_payload(self): - set_repr = '57b70f4e-20d3-11e5-a393-4a63946f7eac:1-56,' \ - '4350f323-7565-4e59-8763-4b1b83a0ce0e:1-20' + set_repr = ( + "57b70f4e-20d3-11e5-a393-4a63946f7eac:1-56," + "4350f323-7565-4e59-8763-4b1b83a0ce0e:1-20" + ) myset = GtidSet(set_repr) payload = myset.encode() @@ -988,8 +1041,10 @@ def test_gtidset_representation_payload(self): self.assertEqual(str(myset), str(parsedset)) - set_repr = '57b70f4e-20d3-11e5-a393-4a63946f7eac:1,' \ - '4350f323-7565-4e59-8763-4b1b83a0ce0e:1-20' + set_repr = ( + "57b70f4e-20d3-11e5-a393-4a63946f7eac:1," + "4350f323-7565-4e59-8763-4b1b83a0ce0e:1-20" + ) myset = GtidSet(set_repr) payload = myset.encode() @@ -1055,15 +1110,18 @@ def test_sub_interval(self): assert (gtid - within).intervals == [(1, 25), (27, 57)] def test_parsing(self): - with self.assertRaises(ValueError) as exc: - gtid = Gtid("57b70f4e-20d3-11e5-a393-4a63946f7eac:1-5 57b70f4e-20d3-11e5-a393-4a63946f7eac:1-56") - gtid = Gtid("NNNNNNNN-20d3-11e5-a393-4a63946f7eac:1-5") - gtid = Gtid("-20d3-11e5-a393-4a63946f7eac:1-5") - gtid = Gtid("-20d3-11e5-a393-4a63946f7eac:1-") - gtid = Gtid("57b70f4e-20d3-11e5-a393-4a63946f7eac:A-1") - gtid = Gtid("57b70f4e-20d3-11e5-a393-4a63946f7eac:-1") - gtid = Gtid("57b70f4e-20d3-11e5-a393-4a63946f7eac:1-:1") - gtid = Gtid("57b70f4e-20d3-11e5-a393-4a63946f7eac::1") + with self.assertRaises(ValueError): + Gtid( + "57b70f4e-20d3-11e5-a393-4a63946f7eac:1-5 57b70f4e-20d3-11e5-a393-4a63946f7eac:1-56" + ) + Gtid("NNNNNNNN-20d3-11e5-a393-4a63946f7eac:1-5") + Gtid("-20d3-11e5-a393-4a63946f7eac:1-5") + Gtid("-20d3-11e5-a393-4a63946f7eac:1-") + Gtid("57b70f4e-20d3-11e5-a393-4a63946f7eac:A-1") + Gtid("57b70f4e-20d3-11e5-a393-4a63946f7eac:-1") + Gtid("57b70f4e-20d3-11e5-a393-4a63946f7eac:1-:1") + Gtid("57b70f4e-20d3-11e5-a393-4a63946f7eac::1") + class TestStatementConnectionSetting(base.PyMySQLReplicationTestCase): def setUp(self): @@ -1073,12 +1131,14 @@ def setUp(self): self.database, server_id=1024, only_events=(RandEvent, UserVarEvent, QueryEvent), - fail_on_table_metadata_unavailable=True + fail_on_table_metadata_unavailable=True, ) self.execute("SET @@binlog_format='STATEMENT'") def test_rand_event(self): - self.execute("CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data INT NOT NULL, PRIMARY KEY (id))") + self.execute( + "CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data INT NOT NULL, PRIMARY KEY (id))" + ) self.execute("INSERT INTO test (data) VALUES(RAND())") self.execute("COMMIT") @@ -1092,7 +1152,9 @@ def test_rand_event(self): self.assertEqual(type(expected_rand_event.seed2), int) def test_user_var_string_event(self): - self.execute("CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data VARCHAR(50), PRIMARY KEY (id))") + self.execute( + "CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data VARCHAR(50), PRIMARY KEY (id))" + ) self.execute("SET @test_user_var = 'foo'") self.execute("INSERT INTO test (data) VALUES(@test_user_var)") self.execute("COMMIT") @@ -1111,7 +1173,9 @@ def test_user_var_string_event(self): self.assertEqual(expected_user_var_event.charset, 33) def test_user_var_real_event(self): - self.execute("CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data REAL, PRIMARY KEY (id))") + self.execute( + "CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data REAL, PRIMARY KEY (id))" + ) self.execute("SET @test_user_var = @@timestamp") self.execute("INSERT INTO test (data) VALUES(@test_user_var)") self.execute("COMMIT") @@ -1124,17 +1188,21 @@ def test_user_var_real_event(self): self.assertIsInstance(expected_user_var_event, UserVarEvent) self.assertIsInstance(expected_user_var_event.name_len, int) self.assertEqual(expected_user_var_event.name, "test_user_var") - self.assertIsInstance(expected_user_var_event.value,float) + self.assertIsInstance(expected_user_var_event.value, float) self.assertEqual(expected_user_var_event.is_null, 0) self.assertEqual(expected_user_var_event.type, 1) self.assertEqual(expected_user_var_event.charset, 33) def test_user_var_int_event(self): - self.execute("CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data1 INT, data2 INT, data3 INT, PRIMARY KEY (id))") + self.execute( + "CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data1 INT, data2 INT, data3 INT, PRIMARY KEY (id))" + ) self.execute("SET @test_user_var1 = 5") self.execute("SET @test_user_var2 = 0") self.execute("SET @test_user_var3 = -5") - self.execute("INSERT INTO test (data1, data2, data3) VALUES(@test_user_var1, @test_user_var2, @test_user_var3)") + self.execute( + "INSERT INTO test (data1, data2, data3) VALUES(@test_user_var1, @test_user_var2, @test_user_var3)" + ) self.execute("COMMIT") self.assertEqual(self.bin_log_format(), "STATEMENT") @@ -1169,11 +1237,15 @@ def test_user_var_int_event(self): self.assertEqual(expected_user_var_event.charset, 33) def test_user_var_int24_event(self): - self.execute("CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data1 MEDIUMINT, data2 MEDIUMINT, data3 MEDIUMINT UNSIGNED, PRIMARY KEY (id))") + self.execute( + "CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data1 MEDIUMINT, data2 MEDIUMINT, data3 MEDIUMINT UNSIGNED, PRIMARY KEY (id))" + ) self.execute("SET @test_user_var1 = 8388607") self.execute("SET @test_user_var2 = -8388607") self.execute("SET @test_user_var3 = 16777215") - self.execute("INSERT INTO test (data1, data2, data3) VALUES(@test_user_var1, @test_user_var2, @test_user_var3)") + self.execute( + "INSERT INTO test (data1, data2, data3) VALUES(@test_user_var1, @test_user_var2, @test_user_var3)" + ) self.execute("COMMIT") self.assertEqual(self.bin_log_format(), "STATEMENT") @@ -1208,11 +1280,15 @@ def test_user_var_int24_event(self): self.assertEqual(expected_user_var_event.charset, 33) def test_user_var_longlong_event(self): - self.execute("CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data1 BIGINT, data2 BIGINT, data3 BIGINT UNSIGNED, PRIMARY KEY (id))") + self.execute( + "CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data1 BIGINT, data2 BIGINT, data3 BIGINT UNSIGNED, PRIMARY KEY (id))" + ) self.execute("SET @test_user_var1 = 9223372036854775807") self.execute("SET @test_user_var2 = -9223372036854775808") self.execute("SET @test_user_var3 = 18446744073709551615") - self.execute("INSERT INTO test (data1, data2, data3) VALUES(@test_user_var1, @test_user_var2, @test_user_var3)") + self.execute( + "INSERT INTO test (data1, data2, data3) VALUES(@test_user_var1, @test_user_var2, @test_user_var3)" + ) self.execute("COMMIT") self.assertEqual(self.bin_log_format(), "STATEMENT") @@ -1247,10 +1323,14 @@ def test_user_var_longlong_event(self): self.assertEqual(expected_user_var_event.charset, 33) def test_user_var_decimal_event(self): - self.execute("CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data1 DECIMAL, data2 DECIMAL, PRIMARY KEY (id))") + self.execute( + "CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data1 DECIMAL, data2 DECIMAL, PRIMARY KEY (id))" + ) self.execute("SET @test_user_var1 = 5.25") self.execute("SET @test_user_var2 = -5.25") - self.execute("INSERT INTO test (data1, data2) VALUES(@test_user_var1, @test_user_var2)") + self.execute( + "INSERT INTO test (data1, data2) VALUES(@test_user_var1, @test_user_var2)" + ) self.execute("COMMIT") self.assertEqual(self.bin_log_format(), "STATEMENT") @@ -1280,14 +1360,12 @@ def tearDown(self): self.assertEqual(self.bin_log_format(), "ROW") super(TestStatementConnectionSetting, self).tearDown() + class TestMariadbBinlogStreamReader(base.PyMySQLReplicationMariaDbTestCase): def test_binlog_checkpoint_event(self): self.stream.close() self.stream = BinLogStreamReader( - self.database, - server_id=1023, - blocking=False, - is_mariadb=True + self.database, server_id=1023, blocking=False, is_mariadb=True ) query = "DROP TABLE IF EXISTS test" @@ -1298,17 +1376,23 @@ def test_binlog_checkpoint_event(self): self.stream.close() event = self.stream.fetchone() - self.assertIsInstance(event, RotateEvent) - + self.assertIsInstance(event, RotateEvent) + + event = self.stream.fetchone() + self.assertIsInstance(event, FormatDescriptionEvent) + event = self.stream.fetchone() - self.assertIsInstance(event,FormatDescriptionEvent) + self.assertIsInstance(event, MariadbStartEncryptionEvent) + + event = self.stream.fetchone() + self.assertIsInstance(event, MariadbGtidListEvent) event = self.stream.fetchone() self.assertIsInstance(event, MariadbBinLogCheckPointEvent) - self.assertEqual(event.filename, self.bin_log_basename()+".000001") + self.assertEqual(event.filename, self.bin_log_basename() + ".000001") -class TestMariadbBinlogStreamReader(base.PyMySQLReplicationMariaDbTestCase): - + +class TestMariadbBinlogStreamReader2(base.PyMySQLReplicationMariaDbTestCase): def test_annotate_rows_event(self): query = "CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data VARCHAR (50) NOT NULL, PRIMARY KEY (id))" self.execute(query) @@ -1322,20 +1406,21 @@ def test_annotate_rows_event(self): self.stream.close() self.stream = BinLogStreamReader( - self.database, - server_id=1024, + self.database, + server_id=1024, blocking=False, only_events=[MariadbAnnotateRowsEvent], is_mariadb=True, annotate_rows_event=True, - ) - + ) + event = self.stream.fetchone() - #Check event type 160,MariadbAnnotateRowsEvent - self.assertEqual(event.event_type,160) - #Check self.sql_statement - self.assertEqual(event.sql_statement,insert_query) - self.assertIsInstance(event,MariadbAnnotateRowsEvent) + # Check event type 160,MariadbAnnotateRowsEvent + self.assertEqual(event.event_type, 160) + # Check self.sql_statement + self.assertEqual(event.sql_statement, insert_query) + self.assertIsInstance(event, MariadbAnnotateRowsEvent) + def test_start_encryption_event(self): query = "CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data VARCHAR (50) NOT NULL, PRIMARY KEY (id))" self.execute(query) @@ -1358,7 +1443,9 @@ def test_start_encryption_event(self): encryption_key_file_path = Path(__file__).parent.parent.parent try: - with open(f"{encryption_key_file_path}/.mariadb/no_encryption_key.key", "r") as key_file: + with open( + f"{encryption_key_file_path}/.mariadb/no_encryption_key.key", "r" + ) as key_file: first_line = key_file.readline() key_version_from_key_file = int(first_line.split(";")[0]) except Exception as e: @@ -1370,17 +1457,17 @@ def test_start_encryption_event(self): self.assertEqual(schema, 1) self.assertEqual(key_version, key_version_from_key_file) self.assertEqual(type(nonce), bytes) - self.assertEqual(len(nonce), 12) + self.assertEqual(len(nonce), 12) def test_gtid_list_event(self): # set max_binlog_size to create new binlog file - query = 'SET GLOBAL max_binlog_size=4096' + query = "SET GLOBAL max_binlog_size=4096" self.execute(query) # parse only Maradb GTID list event self.stream.close() self.stream = BinLogStreamReader( - self.database, - server_id=1024, + self.database, + server_id=1024, blocking=False, only_events=[MariadbGtidListEvent], is_mariadb=True, @@ -1389,20 +1476,20 @@ def test_gtid_list_event(self): query = "CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data VARCHAR (50) NOT NULL, PRIMARY KEY (id))" self.execute(query) query = "INSERT INTO test (data) VALUES('Hello World')" - - for cnt in range(0,15): + + for cnt in range(0, 15): self.execute(query) self.execute("COMMIT") # 'mariadb gtid list event' of first binlog file event = self.stream.fetchone() - self.assertEqual(event.event_type,163) - self.assertIsInstance(event,MariadbGtidListEvent) + self.assertEqual(event.event_type, 163) + self.assertIsInstance(event, MariadbGtidListEvent) # 'mariadb gtid list event' of second binlog file event = self.stream.fetchone() - self.assertEqual(event.event_type,163) - self.assertEqual(event.gtid_list[0].gtid, '0-1-15') + self.assertEqual(event.event_type, 163) + self.assertEqual(event.gtid_list[0].gtid, "0-1-15") class TestRowsQueryLogEvents(base.PyMySQLReplicationTestCase): @@ -1421,25 +1508,29 @@ def test_rows_query_log_event(self): server_id=1024, only_events=[RowsQueryLogEvent], ) - self.execute("CREATE TABLE IF NOT EXISTS test (id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255))") + self.execute( + "CREATE TABLE IF NOT EXISTS test (id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255))" + ) self.execute("INSERT INTO test (name) VALUES ('Soul Lee')") self.execute("COMMIT") event = self.stream.fetchone() self.assertIsInstance(event, RowsQueryLogEvent) -class TestLatin1(base.PyMySQLReplicationTestCase): +class TestLatin1(base.PyMySQLReplicationTestCase): def setUp(self): - super().setUp(charset='latin1') + super().setUp(charset="latin1") def test_query_event_latin1(self): """ Ensure query events with a non-utf8 encoded query are parsed without errors. """ - self.stream = BinLogStreamReader(self.database, server_id=1024, only_events=[QueryEvent]) + self.stream = BinLogStreamReader( + self.database, server_id=1024, only_events=[QueryEvent] + ) self.execute("CREATE TABLE test_latin1_ÖÆÛ (a INT)") self.execute("COMMIT") - assert "ÖÆÛ".encode('latin-1') == b'\xd6\xc6\xdb' + assert "ÖÆÛ".encode("latin-1") == b"\xd6\xc6\xdb" event = self.stream.fetchone() assert event.query.startswith("CREATE TABLE test") @@ -1454,7 +1545,7 @@ def setUp(self): self.database, server_id=1024, only_events=(TableMapEvent,), - fail_on_table_metadata_unavailable=True + fail_on_table_metadata_unavailable=True, ) if not self.isMySQL8014AndMore(): self.skipTest("Mysql version is under 8.0.14 - pass TestOptionalMetaData") @@ -1489,7 +1580,9 @@ def test_default_charset(self): def test_column_charset(self): create_query = "CREATE TABLE test_column_charset (col1 VARCHAR(50), col2 VARCHAR(50) CHARACTER SET binary, col3 VARCHAR(50) CHARACTER SET latin1);" - insert_query = "INSERT INTO test_column_charset VALUES ('python', 'mysql', 'replication');" + insert_query = ( + "INSERT INTO test_column_charset VALUES ('python', 'mysql', 'replication');" + ) self.execute(create_query) self.execute(insert_query) @@ -1512,7 +1605,10 @@ def test_column_name(self): event = self.stream.fetchone() self.assertIsInstance(event, TableMapEvent) - self.assertEqual(event.optional_metadata.column_name_list, ['col_int', 'col_varchar', 'col_bool']) + self.assertEqual( + event.optional_metadata.column_name_list, + ["col_int", "col_varchar", "col_bool"], + ) def test_set_str_value(self): create_query = "CREATE TABLE test_set_str_value (skills SET('Programming', 'Writing', 'Design'));" @@ -1524,7 +1620,10 @@ def test_set_str_value(self): event = self.stream.fetchone() self.assertIsInstance(event, TableMapEvent) - self.assertEqual(event.optional_metadata.set_str_value_list, [['Programming', 'Writing', 'Design']]) + self.assertEqual( + event.optional_metadata.set_str_value_list, + [["Programming", "Writing", "Design"]], + ) def test_enum_str_value(self): create_query = "CREATE TABLE test_enum_str_value (pet ENUM('Dog', 'Cat'));" @@ -1536,7 +1635,9 @@ def test_enum_str_value(self): event = self.stream.fetchone() self.assertIsInstance(event, TableMapEvent) - self.assertEqual(event.optional_metadata.set_enum_str_value_list, [['Dog', 'Cat']]) + self.assertEqual( + event.optional_metadata.set_enum_str_value_list, [["Dog", "Cat"]] + ) def test_geometry_type(self): create_query = "CREATE TABLE test_geometry_type (location POINT);" @@ -1564,7 +1665,9 @@ def test_simple_primary_key(self): def test_primary_key_with_prefix(self): create_query = "CREATE TABLE test_primary_key_with_prefix (c_key1 CHAR(100), c_key2 CHAR(10), c_not_key INT, c_key3 CHAR(100), PRIMARY KEY(c_key1(5), c_key2, c_key3(10)));" - insert_query = "INSERT INTO test_primary_key_with_prefix VALUES('1', '2', 3, '4');" + insert_query = ( + "INSERT INTO test_primary_key_with_prefix VALUES('1', '2', 3, '4');" + ) self.execute(create_query) self.execute(insert_query) @@ -1572,11 +1675,15 @@ def test_primary_key_with_prefix(self): event = self.stream.fetchone() self.assertIsInstance(event, TableMapEvent) - self.assertEqual(event.optional_metadata.primary_keys_with_prefix, {0: 5, 1: 0, 3: 10}) + self.assertEqual( + event.optional_metadata.primary_keys_with_prefix, {0: 5, 1: 0, 3: 10} + ) def test_enum_and_set_default_charset(self): create_query = "CREATE TABLE test_enum_and_set_default_charset (pet ENUM('Dog', 'Cat'), skills SET('Programming', 'Writing', 'Design')) CHARACTER SET utf8mb4;" - insert_query = "INSERT INTO test_enum_and_set_default_charset VALUES('Dog', 'Design');" + insert_query = ( + "INSERT INTO test_enum_and_set_default_charset VALUES('Dog', 'Design');" + ) self.execute(create_query) self.execute(insert_query) @@ -1585,13 +1692,19 @@ def test_enum_and_set_default_charset(self): event = self.stream.fetchone() self.assertIsInstance(event, TableMapEvent) if self.isMariaDB(): - self.assertEqual(event.optional_metadata.enum_and_set_collation_list, [45, 45]) + self.assertEqual( + event.optional_metadata.enum_and_set_collation_list, [45, 45] + ) else: - self.assertEqual(event.optional_metadata.enum_and_set_collation_list, [255, 255]) + self.assertEqual( + event.optional_metadata.enum_and_set_collation_list, [255, 255] + ) def test_enum_and_set_column_charset(self): create_query = "CREATE TABLE test_enum_and_set_column_charset (pet ENUM('Dog', 'Cat') CHARACTER SET utf8mb4, number SET('00', '01', '10', '11') CHARACTER SET binary);" - insert_query = "INSERT INTO test_enum_and_set_column_charset VALUES('Cat', '10');" + insert_query = ( + "INSERT INTO test_enum_and_set_column_charset VALUES('Cat', '10');" + ) self.execute(create_query) self.execute(insert_query) @@ -1600,9 +1713,13 @@ def test_enum_and_set_column_charset(self): event = self.stream.fetchone() self.assertIsInstance(event, TableMapEvent) if self.isMariaDB(): - self.assertEqual(event.optional_metadata.enum_and_set_collation_list, [45, 63]) + self.assertEqual( + event.optional_metadata.enum_and_set_collation_list, [45, 63] + ) else: - self.assertEqual(event.optional_metadata.enum_and_set_collation_list, [255, 63]) + self.assertEqual( + event.optional_metadata.enum_and_set_collation_list, [255, 63] + ) def test_visibility(self): create_query = "CREATE TABLE test_visibility (name VARCHAR(50), secret_key VARCHAR(50) DEFAULT 'qwerty' INVISIBLE);" @@ -1621,6 +1738,8 @@ def tearDown(self): self.execute("SET GLOBAL binlog_row_metadata='MINIMAL';") super(TestOptionalMetaData, self).tearDown() + if __name__ == "__main__": import unittest + unittest.main() diff --git a/pymysqlreplication/tests/test_data_objects.py b/pymysqlreplication/tests/test_data_objects.py index dc1281c3..3f9d1cad 100644 --- a/pymysqlreplication/tests/test_data_objects.py +++ b/pymysqlreplication/tests/test_data_objects.py @@ -1,4 +1,5 @@ import sys + if sys.version_info < (2, 7): import unittest2 as unittest else: @@ -18,42 +19,54 @@ def ignoredEvents(self): return [GtidEvent] def test_column_is_primary(self): - col = Column(1, - {"COLUMN_NAME": "test", - "COLLATION_NAME": "utf8_general_ci", - "CHARACTER_SET_NAME": "UTF8", - "CHARACTER_OCTET_LENGTH": None, - "DATA_TYPE": "tinyint", - "COLUMN_COMMENT": "", - "COLUMN_TYPE": "tinyint(2)", - "COLUMN_KEY": "PRI"}, - None) + col = Column( + 1, + { + "COLUMN_NAME": "test", + "COLLATION_NAME": "utf8_general_ci", + "CHARACTER_SET_NAME": "UTF8", + "CHARACTER_OCTET_LENGTH": None, + "DATA_TYPE": "tinyint", + "COLUMN_COMMENT": "", + "COLUMN_TYPE": "tinyint(2)", + "COLUMN_KEY": "PRI", + }, + None, + ) self.assertEqual(True, col.is_primary) def test_column_not_primary(self): - col = Column(1, - {"COLUMN_NAME": "test", - "COLLATION_NAME": "utf8_general_ci", - "CHARACTER_SET_NAME": "UTF8", - "CHARACTER_OCTET_LENGTH": None, - "DATA_TYPE": "tinyint", - "COLUMN_COMMENT": "", - "COLUMN_TYPE": "tinyint(2)", - "COLUMN_KEY": ""}, - None) + col = Column( + 1, + { + "COLUMN_NAME": "test", + "COLLATION_NAME": "utf8_general_ci", + "CHARACTER_SET_NAME": "UTF8", + "CHARACTER_OCTET_LENGTH": None, + "DATA_TYPE": "tinyint", + "COLUMN_COMMENT": "", + "COLUMN_TYPE": "tinyint(2)", + "COLUMN_KEY": "", + }, + None, + ) self.assertEqual(False, col.is_primary) def test_column_serializable(self): - col = Column(1, - {"COLUMN_NAME": "test", - "COLLATION_NAME": "utf8_general_ci", - "CHARACTER_SET_NAME": "UTF8", - "CHARACTER_OCTET_LENGTH": None, - "DATA_TYPE": "tinyint", - "COLUMN_COMMENT": "", - "COLUMN_TYPE": "tinyint(2)", - "COLUMN_KEY": "PRI"}, - None) + col = Column( + 1, + { + "COLUMN_NAME": "test", + "COLLATION_NAME": "utf8_general_ci", + "CHARACTER_SET_NAME": "UTF8", + "CHARACTER_OCTET_LENGTH": None, + "DATA_TYPE": "tinyint", + "COLUMN_COMMENT": "", + "COLUMN_TYPE": "tinyint(2)", + "COLUMN_KEY": "PRI", + }, + None, + ) serialized = col.serializable_data() self.assertIn("type", serialized) diff --git a/pymysqlreplication/tests/test_data_type.py b/pymysqlreplication/tests/test_data_type.py index a1e4001a..ed30cf9c 100644 --- a/pymysqlreplication/tests/test_data_type.py +++ b/pymysqlreplication/tests/test_data_type.py @@ -3,6 +3,8 @@ import copy import platform import sys +import json + if sys.version_info < (2, 7): import unittest2 as unittest else: @@ -27,6 +29,7 @@ def encode_value(v): if isinstance(v, list): return [encode_value(x) for x in v] return v + return dict([(k.encode(), encode_value(v)) for (k, v) in d.items()]) @@ -41,10 +44,10 @@ def create_and_insert_value(self, create_query, insert_query): self.assertIsInstance(self.stream.fetchone(), RotateEvent) self.assertIsInstance(self.stream.fetchone(), FormatDescriptionEvent) - #QueryEvent for the Create Table + # QueryEvent for the Create Table self.assertIsInstance(self.stream.fetchone(), QueryEvent) - #QueryEvent for the BEGIN + # QueryEvent for the BEGIN self.assertIsInstance(self.stream.fetchone(), QueryEvent) self.assertIsInstance(self.stream.fetchone(), TableMapEvent) @@ -58,7 +61,7 @@ def create_and_insert_value(self, create_query, insert_query): return event def create_table(self, create_query): - """Create table + """Create table Create table in db and return query event. @@ -89,10 +92,10 @@ def create_and_get_tablemap_event(self, bit): self.assertIsInstance(self.stream.fetchone(), RotateEvent) self.assertIsInstance(self.stream.fetchone(), FormatDescriptionEvent) - #QueryEvent for the Create Table + # QueryEvent for the Create Table self.assertIsInstance(self.stream.fetchone(), QueryEvent) - #QueryEvent for the BEGIN + # QueryEvent for the BEGIN self.assertIsInstance(self.stream.fetchone(), QueryEvent) event = self.stream.fetchone() @@ -105,13 +108,13 @@ def test_varbinary(self): create_query = "CREATE TABLE test(b VARBINARY(4))" insert_query = "INSERT INTO test VALUES(UNHEX('ff010000'))" event = self.create_and_insert_value(create_query, insert_query) - self.assertEqual(event.rows[0]["values"]["b"], b'\xff\x01\x00\x00') + self.assertEqual(event.rows[0]["values"]["b"], b"\xff\x01\x00\x00") def test_fixed_length_binary(self): create_query = "CREATE TABLE test(b BINARY(4))" insert_query = "INSERT INTO test VALUES(UNHEX('ff010000'))" event = self.create_and_insert_value(create_query, insert_query) - self.assertEqual(event.rows[0]["values"]["b"], b'\xff\x01\x00\x00') + self.assertEqual(event.rows[0]["values"]["b"], b"\xff\x01\x00\x00") def test_decimal(self): create_query = "CREATE TABLE test (test DECIMAL(2,1))" @@ -143,8 +146,9 @@ def test_decimal_long_values_2(self): )" insert_query = "INSERT INTO test VALUES(9000000123.0000012345)" event = self.create_and_insert_value(create_query, insert_query) - self.assertEqual(event.rows[0]["values"]["test"], - Decimal("9000000123.0000012345")) + self.assertEqual( + event.rows[0]["values"]["test"], Decimal("9000000123.0000012345") + ) def test_decimal_negative_values(self): create_query = "CREATE TABLE test (\ @@ -174,7 +178,9 @@ def test_decimal_with_zero_scale_2(self): create_query = "CREATE TABLE test (test DECIMAL(23,0))" insert_query = "INSERT INTO test VALUES(12345678912345678912345)" event = self.create_and_insert_value(create_query, insert_query) - self.assertEqual(event.rows[0]["values"]["test"], Decimal("12345678912345678912345")) + self.assertEqual( + event.rows[0]["values"]["test"], Decimal("12345678912345678912345") + ) def test_decimal_with_zero_scale_3(self): create_query = "CREATE TABLE test (test DECIMAL(23,0))" @@ -192,7 +198,9 @@ def test_decimal_with_zero_scale_6(self): create_query = "CREATE TABLE test (test DECIMAL(23,0))" insert_query = "INSERT INTO test VALUES(-1234567891234567891234)" event = self.create_and_insert_value(create_query, insert_query) - self.assertEqual(event.rows[0]["values"]["test"], Decimal("-1234567891234567891234")) + self.assertEqual( + event.rows[0]["values"]["test"], Decimal("-1234567891234567891234") + ) def test_tiny(self): create_query = "CREATE TABLE test (id TINYINT UNSIGNED NOT NULL, test TINYINT)" @@ -232,7 +240,9 @@ def test_tiny_maps_to_none_2(self): self.assertEqual(event.rows[0]["values"]["test"], None) def test_short(self): - create_query = "CREATE TABLE test (id SMALLINT UNSIGNED NOT NULL, test SMALLINT)" + create_query = ( + "CREATE TABLE test (id SMALLINT UNSIGNED NOT NULL, test SMALLINT)" + ) insert_query = "INSERT INTO test VALUES(65535, -32768)" event = self.create_and_insert_value(create_query, insert_query) self.assertEqual(event.rows[0]["values"]["id"], 65535) @@ -250,51 +260,75 @@ def test_float(self): insert_query = "INSERT INTO test VALUES(42.42, -84.84)" event = self.create_and_insert_value(create_query, insert_query) self.assertEqual(round(event.rows[0]["values"]["id"], 2), 42.42) - self.assertEqual(round(event.rows[0]["values"]["test"],2 ), -84.84) + self.assertEqual(round(event.rows[0]["values"]["test"], 2), -84.84) def test_double(self): create_query = "CREATE TABLE test (id DOUBLE NOT NULL, test DOUBLE)" insert_query = "INSERT INTO test VALUES(42.42, -84.84)" event = self.create_and_insert_value(create_query, insert_query) self.assertEqual(round(event.rows[0]["values"]["id"], 2), 42.42) - self.assertEqual(round(event.rows[0]["values"]["test"],2 ), -84.84) + self.assertEqual(round(event.rows[0]["values"]["test"], 2), -84.84) def test_timestamp(self): create_query = "CREATE TABLE test (test TIMESTAMP);" insert_query = "INSERT INTO test VALUES('1984-12-03 12:33:07')" event = self.create_and_insert_value(create_query, insert_query) - self.assertEqual(event.rows[0]["values"]["test"], datetime.datetime(1984, 12, 3, 12, 33, 7)) + self.assertEqual( + event.rows[0]["values"]["test"], datetime.datetime(1984, 12, 3, 12, 33, 7) + ) def test_timestamp_mysql56(self): if not self.isMySQL56AndMore(): self.skipTest("Not supported in this version of MySQL") self.set_sql_mode() - create_query = '''CREATE TABLE test (test0 TIMESTAMP(0), + create_query = """CREATE TABLE test (test0 TIMESTAMP(0), test1 TIMESTAMP(1), test2 TIMESTAMP(2), test3 TIMESTAMP(3), test4 TIMESTAMP(4), test5 TIMESTAMP(5), - test6 TIMESTAMP(6));''' - insert_query = '''INSERT INTO test VALUES('1984-12-03 12:33:07', + test6 TIMESTAMP(6));""" + insert_query = """INSERT INTO test VALUES('1984-12-03 12:33:07', '1984-12-03 12:33:07.1', '1984-12-03 12:33:07.12', '1984-12-03 12:33:07.123', '1984-12-03 12:33:07.1234', '1984-12-03 12:33:07.12345', - '1984-12-03 12:33:07.123456')''' - event = self.create_and_insert_value(create_query, insert_query) - self.assertEqual(event.rows[0]["values"]["test0"], datetime.datetime(1984, 12, 3, 12, 33, 7)) - self.assertEqual(event.rows[0]["values"]["test1"], datetime.datetime(1984, 12, 3, 12, 33, 7, 100000)) - self.assertEqual(event.rows[0]["values"]["test2"], datetime.datetime(1984, 12, 3, 12, 33, 7, 120000)) - self.assertEqual(event.rows[0]["values"]["test3"], datetime.datetime(1984, 12, 3, 12, 33, 7, 123000)) - self.assertEqual(event.rows[0]["values"]["test4"], datetime.datetime(1984, 12, 3, 12, 33, 7, 123400)) - self.assertEqual(event.rows[0]["values"]["test5"], datetime.datetime(1984, 12, 3, 12, 33, 7, 123450)) - self.assertEqual(event.rows[0]["values"]["test6"], datetime.datetime(1984, 12, 3, 12, 33, 7, 123456)) + '1984-12-03 12:33:07.123456')""" + event = self.create_and_insert_value(create_query, insert_query) + self.assertEqual( + event.rows[0]["values"]["test0"], datetime.datetime(1984, 12, 3, 12, 33, 7) + ) + self.assertEqual( + event.rows[0]["values"]["test1"], + datetime.datetime(1984, 12, 3, 12, 33, 7, 100000), + ) + self.assertEqual( + event.rows[0]["values"]["test2"], + datetime.datetime(1984, 12, 3, 12, 33, 7, 120000), + ) + self.assertEqual( + event.rows[0]["values"]["test3"], + datetime.datetime(1984, 12, 3, 12, 33, 7, 123000), + ) + self.assertEqual( + event.rows[0]["values"]["test4"], + datetime.datetime(1984, 12, 3, 12, 33, 7, 123400), + ) + self.assertEqual( + event.rows[0]["values"]["test5"], + datetime.datetime(1984, 12, 3, 12, 33, 7, 123450), + ) + self.assertEqual( + event.rows[0]["values"]["test6"], + datetime.datetime(1984, 12, 3, 12, 33, 7, 123456), + ) def test_longlong(self): create_query = "CREATE TABLE test (id BIGINT UNSIGNED NOT NULL, test BIGINT)" - insert_query = "INSERT INTO test VALUES(18446744073709551615, -9223372036854775808)" + insert_query = ( + "INSERT INTO test VALUES(18446744073709551615, -9223372036854775808)" + ) event = self.create_and_insert_value(create_query, insert_query) self.assertEqual(event.rows[0]["values"]["id"], 18446744073709551615) self.assertEqual(event.rows[0]["values"]["test"], -9223372036854775808) @@ -343,12 +377,14 @@ def test_time(self): create_query = "CREATE TABLE test (test1 TIME, test2 TIME);" insert_query = "INSERT INTO test VALUES('838:59:59', '-838:59:59')" event = self.create_and_insert_value(create_query, insert_query) - self.assertEqual(event.rows[0]["values"]["test1"], datetime.timedelta( - microseconds=(((838*60) + 59)*60 + 59)*1000000 - )) - self.assertEqual(event.rows[0]["values"]["test2"], datetime.timedelta( - microseconds=-(((838*60) + 59)*60 + 59)*1000000 - )) + self.assertEqual( + event.rows[0]["values"]["test1"], + datetime.timedelta(microseconds=(((838 * 60) + 59) * 60 + 59) * 1000000), + ) + self.assertEqual( + event.rows[0]["values"]["test2"], + datetime.timedelta(microseconds=-(((838 * 60) + 59) * 60 + 59) * 1000000), + ) def test_time2(self): if not self.isMySQL56AndMore(): @@ -358,12 +394,18 @@ def test_time2(self): INSERT INTO test VALUES('838:59:59.000000', '-838:59:59.000000'); """ event = self.create_and_insert_value(create_query, insert_query) - self.assertEqual(event.rows[0]["values"]["test1"], datetime.timedelta( - microseconds=(((838*60) + 59)*60 + 59)*1000000 + 0 - )) - self.assertEqual(event.rows[0]["values"]["test2"], datetime.timedelta( - microseconds=-(((838*60) + 59)*60 + 59)*1000000 + 0 - )) + self.assertEqual( + event.rows[0]["values"]["test1"], + datetime.timedelta( + microseconds=(((838 * 60) + 59) * 60 + 59) * 1000000 + 0 + ), + ) + self.assertEqual( + event.rows[0]["values"]["test2"], + datetime.timedelta( + microseconds=-(((838 * 60) + 59) * 60 + 59) * 1000000 + 0 + ), + ) def test_zero_time(self): create_query = "CREATE TABLE test (id INTEGER, test TIME NOT NULL DEFAULT 0);" @@ -375,11 +417,15 @@ def test_datetime(self): create_query = "CREATE TABLE test (test DATETIME);" insert_query = "INSERT INTO test VALUES('1984-12-03 12:33:07')" event = self.create_and_insert_value(create_query, insert_query) - self.assertEqual(event.rows[0]["values"]["test"], datetime.datetime(1984, 12, 3, 12, 33, 7)) + self.assertEqual( + event.rows[0]["values"]["test"], datetime.datetime(1984, 12, 3, 12, 33, 7) + ) def test_zero_datetime(self): self.set_sql_mode() - create_query = "CREATE TABLE test (id INTEGER, test DATETIME NOT NULL DEFAULT 0);" + create_query = ( + "CREATE TABLE test (id INTEGER, test DATETIME NOT NULL DEFAULT 0);" + ) insert_query = "INSERT INTO test (id) VALUES(1)" event = self.create_and_insert_value(create_query, insert_query) self.assertEqual(event.rows[0]["values"]["test"], None) @@ -405,7 +451,7 @@ def test_varchar(self): create_query = "CREATE TABLE test (test VARCHAR(242)) CHARACTER SET latin1 COLLATE latin1_bin;" insert_query = "INSERT INTO test VALUES('Hello')" event = self.create_and_insert_value(create_query, insert_query) - self.assertEqual(event.rows[0]["values"]["test"], 'Hello') + self.assertEqual(event.rows[0]["values"]["test"], "Hello") self.assertEqual(event.columns[0].max_length, 242) def test_bit(self): @@ -431,73 +477,80 @@ def test_bit(self): self.assertEqual(event.rows[0]["values"]["test2"], "1000101010111000") self.assertEqual(event.rows[0]["values"]["test3"], "100010101101") self.assertEqual(event.rows[0]["values"]["test4"], "101100111") - self.assertEqual(event.rows[0]["values"]["test5"], "1101011010110100100111100011010100010100101110111011101011011010") + self.assertEqual( + event.rows[0]["values"]["test5"], + "1101011010110100100111100011010100010100101110111011101011011010", + ) def test_enum(self): create_query = "CREATE TABLE test (test ENUM('a', 'ba', 'c'), test2 ENUM('a', 'ba', 'c')) CHARACTER SET latin1 COLLATE latin1_bin;" insert_query = "INSERT INTO test VALUES('ba', 'a')" event = self.create_and_insert_value(create_query, insert_query) - self.assertEqual(event.rows[0]["values"]["test"], 'ba') - self.assertEqual(event.rows[0]["values"]["test2"], 'a') + self.assertEqual(event.rows[0]["values"]["test"], "ba") + self.assertEqual(event.rows[0]["values"]["test2"], "a") def test_enum_empty_string(self): create_query = "CREATE TABLE test (test ENUM('a', 'ba', 'c'), test2 ENUM('a', 'ba', 'c')) CHARACTER SET latin1 COLLATE latin1_bin;" insert_query = "INSERT INTO test VALUES('ba', 'asdf')" - last_sql_mode = self.execute("SELECT @@SESSION.sql_mode;"). \ - fetchall()[0][0] + last_sql_mode = self.execute("SELECT @@SESSION.sql_mode;").fetchall()[0][0] self.execute("SET SESSION sql_mode = 'ANSI';") event = self.create_and_insert_value(create_query, insert_query) self.execute("SET SESSION sql_mode = '%s';" % last_sql_mode) - self.assertEqual(event.rows[0]["values"]["test"], 'ba') - self.assertEqual(event.rows[0]["values"]["test2"], '') + self.assertEqual(event.rows[0]["values"]["test"], "ba") + self.assertEqual(event.rows[0]["values"]["test2"], "") def test_set(self): create_query = "CREATE TABLE test (test SET('a', 'ba', 'c'), test2 SET('a', 'ba', 'c')) CHARACTER SET latin1 COLLATE latin1_bin;" insert_query = "INSERT INTO test VALUES('ba,a,c', 'a,c')" event = self.create_and_insert_value(create_query, insert_query) - self.assertEqual(event.rows[0]["values"]["test"], set(('a', 'ba', 'c'))) - self.assertEqual(event.rows[0]["values"]["test2"], set(('a', 'c'))) + self.assertEqual(event.rows[0]["values"]["test"], set(("a", "ba", "c"))) + self.assertEqual(event.rows[0]["values"]["test2"], set(("a", "c"))) def test_tiny_blob(self): create_query = "CREATE TABLE test (test TINYBLOB, test2 TINYTEXT) CHARACTER SET latin1 COLLATE latin1_bin;" insert_query = "INSERT INTO test VALUES('Hello', 'World')" event = self.create_and_insert_value(create_query, insert_query) - self.assertEqual(event.rows[0]["values"]["test"], b'Hello') - self.assertEqual(event.rows[0]["values"]["test2"], 'World') + self.assertEqual(event.rows[0]["values"]["test"], b"Hello") + self.assertEqual(event.rows[0]["values"]["test2"], "World") def test_medium_blob(self): create_query = "CREATE TABLE test (test MEDIUMBLOB, test2 MEDIUMTEXT) CHARACTER SET latin1 COLLATE latin1_bin;" insert_query = "INSERT INTO test VALUES('Hello', 'World')" event = self.create_and_insert_value(create_query, insert_query) - self.assertEqual(event.rows[0]["values"]["test"], b'Hello') - self.assertEqual(event.rows[0]["values"]["test2"], 'World') + self.assertEqual(event.rows[0]["values"]["test"], b"Hello") + self.assertEqual(event.rows[0]["values"]["test2"], "World") def test_long_blob(self): create_query = "CREATE TABLE test (test LONGBLOB, test2 LONGTEXT) CHARACTER SET latin1 COLLATE latin1_bin;" insert_query = "INSERT INTO test VALUES('Hello', 'World')" event = self.create_and_insert_value(create_query, insert_query) - self.assertEqual(event.rows[0]["values"]["test"], b'Hello') - self.assertEqual(event.rows[0]["values"]["test2"], 'World') + self.assertEqual(event.rows[0]["values"]["test"], b"Hello") + self.assertEqual(event.rows[0]["values"]["test2"], "World") def test_blob(self): create_query = "CREATE TABLE test (test BLOB, test2 TEXT) CHARACTER SET latin1 COLLATE latin1_bin;" insert_query = "INSERT INTO test VALUES('Hello', 'World')" event = self.create_and_insert_value(create_query, insert_query) - self.assertEqual(event.rows[0]["values"]["test"], b'Hello') - self.assertEqual(event.rows[0]["values"]["test2"], 'World') + self.assertEqual(event.rows[0]["values"]["test"], b"Hello") + self.assertEqual(event.rows[0]["values"]["test2"], "World") def test_string(self): - create_query = "CREATE TABLE test (test CHAR(12)) CHARACTER SET latin1 COLLATE latin1_bin;" + create_query = ( + "CREATE TABLE test (test CHAR(12)) CHARACTER SET latin1 COLLATE latin1_bin;" + ) insert_query = "INSERT INTO test VALUES('Hello')" event = self.create_and_insert_value(create_query, insert_query) - self.assertEqual(event.rows[0]["values"]["test"], 'Hello') + self.assertEqual(event.rows[0]["values"]["test"], "Hello") def test_geometry(self): create_query = "CREATE TABLE test (test GEOMETRY);" insert_query = "INSERT INTO test VALUES(GeomFromText('POINT(1 1)'))" event = self.create_and_insert_value(create_query, insert_query) - self.assertEqual(event.rows[0]["values"]["test"], b'\x00\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x00\x00\xf0?') + self.assertEqual( + event.rows[0]["values"]["test"], + b"\x00\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x00\x00\xf0?", + ) def test_json(self): if not self.isMySQL57(): @@ -505,22 +558,31 @@ def test_json(self): create_query = "CREATE TABLE test (id int, value json);" insert_query = """INSERT INTO test (id, value) VALUES (1, '{"my_key": "my_val", "my_key2": "my_val2"}');""" event = self.create_and_insert_value(create_query, insert_query) - self.assertEqual(event.rows[0]["values"]["value"], {b"my_key": b"my_val", b"my_key2": b"my_val2"}) + self.assertEqual( + event.rows[0]["values"]["value"], + {b"my_key": b"my_val", b"my_key2": b"my_val2"}, + ) def test_json_array(self): if not self.isMySQL57(): self.skipTest("Json is only supported in mysql 5.7") create_query = "CREATE TABLE test (id int, value json);" - insert_query = """INSERT INTO test (id, value) VALUES (1, '["my_val", "my_val2"]');""" + insert_query = ( + """INSERT INTO test (id, value) VALUES (1, '["my_val", "my_val2"]');""" + ) event = self.create_and_insert_value(create_query, insert_query) - self.assertEqual(event.rows[0]["values"]["value"], [b'my_val', b'my_val2']) + self.assertEqual(event.rows[0]["values"]["value"], [b"my_val", b"my_val2"]) def test_json_large(self): if not self.isMySQL57(): self.skipTest("Json is only supported in mysql 5.7") - data = dict([('foooo%i'%i, 'baaaaar%i'%i) for i in range(2560)]) # Make it large enough to reach 2^16 length + data = dict( + [("foooo%i" % i, "baaaaar%i" % i) for i in range(2560)] + ) # Make it large enough to reach 2^16 length create_query = "CREATE TABLE test (id int, value json);" - insert_query = """INSERT INTO test (id, value) VALUES (1, '%s');""" % json.dumps(data) + insert_query = ( + """INSERT INTO test (id, value) VALUES (1, '%s');""" % json.dumps(data) + ) event = self.create_and_insert_value(create_query, insert_query) self.assertEqual(event.rows[0]["values"]["value"], to_binary_dict(data)) @@ -531,16 +593,22 @@ def test_json_large_array(self): self.skipTest("Json is only supported in mysql 5.7") create_query = "CREATE TABLE test (id int, value json);" large_array = dict(my_key=[i for i in range(100000)]) - insert_query = "INSERT INTO test (id, value) VALUES (1, '%s');" % (json.dumps(large_array),) + insert_query = "INSERT INTO test (id, value) VALUES (1, '%s');" % ( + json.dumps(large_array), + ) event = self.create_and_insert_value(create_query, insert_query) self.assertEqual(event.rows[0]["values"]["value"], to_binary_dict(large_array)) def test_json_large_with_literal(self): if not self.isMySQL57(): self.skipTest("Json is only supported in mysql 5.7") - data = dict([('foooo%i'%i, 'baaaaar%i'%i) for i in range(2560)], literal=True) # Make it large with literal + data = dict( + [("foooo%i" % i, "baaaaar%i" % i) for i in range(2560)], literal=True + ) # Make it large with literal create_query = "CREATE TABLE test (id int, value json);" - insert_query = """INSERT INTO test (id, value) VALUES (1, '%s');""" % json.dumps(data) + insert_query = ( + """INSERT INTO test (id, value) VALUES (1, '%s');""" % json.dumps(data) + ) event = self.create_and_insert_value(create_query, insert_query) self.assertEqual(event.rows[0]["values"]["value"], to_binary_dict(data)) @@ -550,23 +618,25 @@ def test_json_types(self): self.skipTest("Json is only supported in mysql 5.7") types = [ - True, - False, - None, - 1.2, - 2^14, - 2^30, - 2^62, - -1 * 2^14, - -1 * 2^30, - -1 * 2^62, - ['foo', 'bar'] + True, + False, + None, + 1.2, + 2 ^ 14, + 2 ^ 30, + 2 ^ 62, + -1 * 2 ^ 14, + -1 * 2 ^ 30, + -1 * 2 ^ 62, + ["foo", "bar"], ] for t in types: - data = {'foo': t} + data = {"foo": t} create_query = "CREATE TABLE test (id int, value json);" - insert_query = """INSERT INTO test (id, value) VALUES (1, '%s');""" % json.dumps(data) + insert_query = ( + """INSERT INTO test (id, value) VALUES (1, '%s');""" % json.dumps(data) + ) event = self.create_and_insert_value(create_query, insert_query) self.assertEqual(event.rows[0]["values"]["value"], to_binary_dict(data)) @@ -578,21 +648,23 @@ def test_json_basic(self): self.skipTest("Json is only supported in mysql 5.7") types = [ - True, - False, - None, - 1.2, - 2^14, - 2^30, - 2^62, - -1 * 2^14, - -1 * 2^30, - -1 * 2^62, + True, + False, + None, + 1.2, + 2 ^ 14, + 2 ^ 30, + 2 ^ 62, + -1 * 2 ^ 14, + -1 * 2 ^ 30, + -1 * 2 ^ 62, ] for data in types: create_query = "CREATE TABLE test (id int, value json);" - insert_query = """INSERT INTO test (id, value) VALUES (1, '%s');""" % json.dumps(data) + insert_query = ( + """INSERT INTO test (id, value) VALUES (1, '%s');""" % json.dumps(data) + ) event = self.create_and_insert_value(create_query, insert_query) self.assertEqual(event.rows[0]["values"]["value"], data) @@ -603,9 +675,9 @@ def test_json_unicode(self): if not self.isMySQL57(): self.skipTest("Json is only supported in mysql 5.7") create_query = "CREATE TABLE test (id int, value json);" - insert_query = u"""INSERT INTO test (id, value) VALUES (1, '{"miam": "🍔"}');""" + insert_query = """INSERT INTO test (id, value) VALUES (1, '{"miam": "🍔"}');""" event = self.create_and_insert_value(create_query, insert_query) - self.assertEqual(event.rows[0]["values"]["value"][b"miam"], u'🍔'.encode('utf8')) + self.assertEqual(event.rows[0]["values"]["value"][b"miam"], "🍔".encode("utf8")) def test_json_long_string(self): if not self.isMySQL57(): @@ -613,9 +685,14 @@ def test_json_long_string(self): create_query = "CREATE TABLE test (id int, value json);" # The string length needs to be larger than what can fit in a single byte. string_value = "super_long_string" * 100 - insert_query = "INSERT INTO test (id, value) VALUES (1, '{\"my_key\": \"%s\"}');" % (string_value,) + insert_query = ( + 'INSERT INTO test (id, value) VALUES (1, \'{"my_key": "%s"}\');' + % (string_value,) + ) event = self.create_and_insert_value(create_query, insert_query) - self.assertEqual(event.rows[0]["values"]["value"], to_binary_dict({"my_key": string_value})) + self.assertEqual( + event.rows[0]["values"]["value"], to_binary_dict({"my_key": string_value}) + ) def test_null(self): create_query = "CREATE TABLE test ( \ @@ -658,19 +735,23 @@ def test_encoding_latin1(self): else: string = "\u00e9" - create_query = "CREATE TABLE test (test CHAR(12)) CHARACTER SET latin1 COLLATE latin1_bin;" - insert_query = b"INSERT INTO test VALUES('" + string.encode('latin-1') + b"');" + create_query = ( + "CREATE TABLE test (test CHAR(12)) CHARACTER SET latin1 COLLATE latin1_bin;" + ) + insert_query = b"INSERT INTO test VALUES('" + string.encode("latin-1") + b"');" event = self.create_and_insert_value(create_query, insert_query) self.assertEqual(event.rows[0]["values"]["test"], string) def test_encoding_utf8(self): if platform.python_version_tuple()[0] == "2": - string = unichr(0x20ac) + string = unichr(0x20AC) else: string = "\u20ac" - create_query = "CREATE TABLE test (test CHAR(12)) CHARACTER SET utf8 COLLATE utf8_bin;" - insert_query = b"INSERT INTO test VALUES('" + string.encode('utf-8') + b"')" + create_query = ( + "CREATE TABLE test (test CHAR(12)) CHARACTER SET utf8 COLLATE utf8_bin;" + ) + insert_query = b"INSERT INTO test VALUES('" + string.encode("utf-8") + b"')" event = self.create_and_insert_value(create_query, insert_query) self.assertMultiLineEqual(event.rows[0]["values"]["test"], string) @@ -683,13 +764,15 @@ def test_zerofill(self): test4 INT UNSIGNED ZEROFILL DEFAULT NULL, \ test5 BIGINT UNSIGNED ZEROFILL DEFAULT NULL \ )" - insert_query = "INSERT INTO test (test, test2, test3, test4, test5) VALUES(1, 1, 1, 1, 1)" + insert_query = ( + "INSERT INTO test (test, test2, test3, test4, test5) VALUES(1, 1, 1, 1, 1)" + ) event = self.create_and_insert_value(create_query, insert_query) - self.assertEqual(event.rows[0]["values"]["test"], '001') - self.assertEqual(event.rows[0]["values"]["test2"], '00001') - self.assertEqual(event.rows[0]["values"]["test3"], '00000001') - self.assertEqual(event.rows[0]["values"]["test4"], '0000000001') - self.assertEqual(event.rows[0]["values"]["test5"], '00000000000000000001') + self.assertEqual(event.rows[0]["values"]["test"], "001") + self.assertEqual(event.rows[0]["values"]["test2"], "00001") + self.assertEqual(event.rows[0]["values"]["test3"], "00000001") + self.assertEqual(event.rows[0]["values"]["test4"], "0000000001") + self.assertEqual(event.rows[0]["values"]["test5"], "00000000000000000001") def test_partition_id(self): if not self.isMySQL80AndMore(): @@ -720,8 +803,8 @@ def test_status_vars(self): """ create_query = "CREATE TABLE test (id INTEGER)" event = self.create_table(create_query) - self.assertEqual(event.catalog_nz_code, b'std') - self.assertEqual(event.mts_accessed_db_names, [b'pymysqlreplication_test']) + self.assertEqual(event.catalog_nz_code, b"std") + self.assertEqual(event.mts_accessed_db_names, [b"pymysqlreplication_test"]) def test_null_bitmask(self): """Test parse of null-bitmask in table map events @@ -731,11 +814,11 @@ def test_null_bitmask(self): Raises: AssertionError: if null_bitmask isn't set as specified in 'bit_mask' variable - """ + """ # any 2-byte bitmask in little-endian hex bytes format (b'a\x03') ## b'a\x03' = 1101100001(2) - bit_mask = b'a\x03' + bit_mask = b"a\x03" # Prepare create_query create_query = "CREATE TABLE test" @@ -746,7 +829,7 @@ def test_null_bitmask(self): ## column name, column type, nullability column_definition = [] - column_name = chr(ord('a') + i) + column_name = chr(ord("a") + i) column_definition.append(column_name) column_type = "INT" @@ -764,8 +847,8 @@ def test_null_bitmask(self): values = [] for i in range(16): - values.append('0') - + values.append("0") + insert_query += f' ({",".join(values)})' self.execute(create_query) @@ -774,10 +857,10 @@ def test_null_bitmask(self): self.assertIsInstance(self.stream.fetchone(), RotateEvent) self.assertIsInstance(self.stream.fetchone(), FormatDescriptionEvent) - #QueryEvent for the Create Table + # QueryEvent for the Create Table self.assertIsInstance(self.stream.fetchone(), QueryEvent) - #QueryEvent for the BEGIN + # QueryEvent for the BEGIN self.assertIsInstance(self.stream.fetchone(), QueryEvent) event = self.stream.fetchone() @@ -805,7 +888,7 @@ def test_mariadb_only_status_vars(self): event = self.create_table(create_query) # skip dummy events with empty schema - while event.schema == b'': + while event.schema == b"": event = self.stream.fetchone() self.assertEqual(event.query, create_query) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..aba6be01 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,9 @@ +[tool.ruff] +ignore = [ + "E501", # Line too long, handled by black + "F403", # from module import *' used, It should be removed afterwad + "F405", # same to F403 +] + +[tool.ruff.per-file-ignores] +"__init__.py" = ["F401"] \ No newline at end of file diff --git a/scripts/lint.sh b/scripts/lint.sh new file mode 100644 index 00000000..b9efc5d6 --- /dev/null +++ b/scripts/lint.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash + +set -e +set -x + +ruff pymysqlreplication +black pymysqlreplication --check \ No newline at end of file diff --git a/setup.py b/setup.py index 8d2cc77b..9d09a424 100644 --- a/setup.py +++ b/setup.py @@ -40,14 +40,18 @@ def run(self): url="https://github.com/julien-duponchelle/python-mysql-replication", author="Julien Duponchelle", author_email="julien@duponchelle.info", - description=("Pure Python Implementation of MySQL replication protocol " - "build on top of PyMYSQL."), + description=( + "Pure Python Implementation of MySQL replication protocol " + "build on top of PyMYSQL." + ), long_description=long_description, - long_description_content_type='text/markdown', + long_description_content_type="text/markdown", license="Apache 2", - packages=["pymysqlreplication", - "pymysqlreplication.constants", - "pymysqlreplication.tests"], + packages=[ + "pymysqlreplication", + "pymysqlreplication.constants", + "pymysqlreplication.tests", + ], cmdclass={"test": TestCommand}, - install_requires=['pymysql>=1.1.0'], + install_requires=["pymysql>=1.1.0"], )