From 441aaa9ce398ad8c600f3461989046964cd2bb6a Mon Sep 17 00:00:00 2001 From: heehehe Date: Tue, 5 Sep 2023 10:24:47 +0900 Subject: [PATCH 1/7] feat: add gtid.py typing Co-authored-by: mjs1995 Co-authored-by: mikaniz --- pymysqlreplication/gtid.py | 124 +++++++++++++++++++++---------------- 1 file changed, 69 insertions(+), 55 deletions(-) diff --git a/pymysqlreplication/gtid.py b/pymysqlreplication/gtid.py index 3b2554da..df80aac2 100644 --- a/pymysqlreplication/gtid.py +++ b/pymysqlreplication/gtid.py @@ -5,15 +5,18 @@ import binascii from copy import deepcopy from io import BytesIO +from typing import List, Optional, Tuple, Union, Set -def overlap(i1, i2): + +def overlap(i1: Tuple[int, int], i2: Tuple[int, int]) -> bool: return i1[0] < i2[1] and i1[1] > i2[0] -def contains(i1, i2): +def contains(i1: Tuple[int, int], i2: Tuple[int, int]) -> bool: return i2[0] >= i1[0] and i2[1] <= i1[1] class Gtid(object): - """A mysql GTID is composed of a server-id and a set of right-open + """ + A mysql GTID is composed of a server-id and a set of right-open intervals [a,b), and represent all transactions x that happened on server SID such as @@ -49,7 +52,7 @@ class Gtid(object): Exception: Adding a Gtid with a different SID. """ @staticmethod - def parse_interval(interval): + def parse_interval(interval: str) -> Tuple[int, int]: """ We parse a human-generated string here. So our end value b is incremented to conform to the internal representation format. @@ -65,8 +68,9 @@ def parse_interval(interval): return (a, b+1) @staticmethod - def parse(gtid): - """Parse a GTID from mysql textual format. + def parse(gtid: str) -> Tuple[str, List[Tuple[int, int]]]: + """ + Parse a GTID from mysql textual format. Raises: - ValueError: if GTID format is incorrect. @@ -84,7 +88,7 @@ def parse(gtid): return (sid, intervals_parsed) - def __add_interval(self, itvl): + def __add_interval(self, itvl: Tuple[int, int]) -> None: """ Use the internal representation format and add it to our intervals, merging if required. @@ -92,7 +96,7 @@ def __add_interval(self, itvl): Raises: Exception: if Malformated interval or Overlapping interval """ - new = [] + new: List[Tuple[int, int]] = [] if itvl[0] > itvl[1]: raise Exception('Malformed interval %s' % (itvl,)) @@ -114,11 +118,13 @@ def __add_interval(self, itvl): self.intervals = sorted(new + [itvl]) - def __sub_interval(self, itvl): - """Using the internal representation, remove an interval + def __sub_interval(self, itvl: Tuple[int, int]) -> None: + """ + Using the internal representation, remove an interval - Raises: Exception if itvl malformated""" - new = [] + Raises: Exception if itvl malformated + """ + new: List[Tuple[int, int]] = [] if itvl[0] > itvl[1]: raise Exception('Malformed interval %s' % (itvl,)) @@ -139,8 +145,9 @@ def __sub_interval(self, itvl): self.intervals = new - def __contains__(self, other): - """Test if other is contained within self. + def __contains__(self, other: 'Gtid') -> bool: + """ + Test if other is contained within self. First we compare sid they must be equals. 
Then we search if intervals from other are contained within
@@ -152,10 +159,8 @@ def __contains__(self, other):
         return all(any(contains(me, them) for me in self.intervals)
                    for them in other.intervals)

-    def __init__(self, gtid, sid=None, intervals=[]):
-        if sid:
-            intervals = intervals
-        else:
+    def __init__(self, gtid: str, sid: Optional[str] = None, intervals: Optional[List[Tuple[int, int]]] = None) -> None:
+        if sid is None:
             sid, intervals = Gtid.parse(gtid)

         self.sid = sid
@@ -163,11 +168,13 @@ def __init__(self, gtid, sid=None, intervals=[]):
         for itvl in intervals:
             self.__add_interval(itvl)

-    def __add__(self, other):
-        """Include the transactions of this gtid.
+    def __add__(self, other: 'Gtid') -> 'Gtid':
+        """
+        Include the transactions of this gtid.

         Raises:
-        Exception: if the attempted merge has different SID"""
+        Exception: if the attempted merge has a different SID
+        """
         if self.sid != other.sid:
             raise Exception('Attempt to merge different SID'
                             '%s != %s' % (self.sid, other.sid))
@@ -179,9 +186,10 @@ def __add__(self, other):

         return result

-    def __sub__(self, other):
-        """Remove intervals. Do not raise, if different SID simply
-        ignore"""
+    def __sub__(self, other: 'Gtid') -> 'Gtid':
+        """
+        Remove intervals. Do not raise; if the SID differs, simply ignore.
+        """
         result = deepcopy(self)
         if self.sid != other.sid:
             return result
@@ -191,27 +199,30 @@ def __sub__(self, other):

         return result

-    def __str__(self):
-        """We represent the human value here - a single number
-        for one transaction, or a closed interval (decrementing b)"""
+    def __str__(self) -> str:
+        """
+        We represent the human value here - a single number
+        for one transaction, or a closed interval (decrementing b)
+        """
         return '%s:%s' % (self.sid,
                           ':'.join(('%d-%d' % (x[0], x[1]-1)) if x[0] +1 != x[1]
                                    else str(x[0])
                                    for x in self.intervals))

-    def __repr__(self):
+    def __repr__(self) -> str:
         return '<Gtid "%s">' % self

     @property
-    def encoded_length(self):
+    def encoded_length(self) -> int:
         return (16 +  # sid
                 8 +  # n_intervals
                 2 *  # stop/start
                 8 *  # stop/start mark encoded as int64
                 len(self.intervals))

-    def encode(self):
-        """Encode a Gtid in binary
+    def encode(self) -> bytes:
+        """
+        Encode a Gtid in binary
         Bytes are in **little endian**.
Format:
@@ -251,8 +262,9 @@ def encode(self):
         return buffer

     @classmethod
-    def decode(cls, payload):
-        """Decode from binary a Gtid
+    def decode(cls, payload: BytesIO) -> 'Gtid':
+        """
+        Decode a Gtid from binary

         :param BytesIO payload to decode
         """
@@ -281,27 +293,27 @@ def decode(cls, payload):
                                     else '%d' % x
                                     for x in intervals])))

-    def __eq__(self, other):
+    def __eq__(self, other: 'Gtid') -> bool:
         if other.sid != self.sid:
             return False
         return self.intervals == other.intervals

-    def __lt__(self, other):
+    def __lt__(self, other: 'Gtid') -> bool:
         if other.sid != self.sid:
             return self.sid < other.sid
         return self.intervals < other.intervals

-    def __le__(self, other):
+    def __le__(self, other: 'Gtid') -> bool:
         if other.sid != self.sid:
             return self.sid <= other.sid
         return self.intervals <= other.intervals

-    def __gt__(self, other):
+    def __gt__(self, other: 'Gtid') -> bool:
         if other.sid != self.sid:
             return self.sid > other.sid
         return self.intervals > other.intervals

-    def __ge__(self, other):
+    def __ge__(self, other: 'Gtid') -> bool:
         if other.sid != self.sid:
             return self.sid >= other.sid
         return self.intervals >= other.intervals
@@ -309,7 +321,7 @@ def __ge__(self, other):
 class GtidSet(object):
     """Represents a set of Gtid"""

-    def __init__(self, gtid_set):
+    def __init__(self, gtid_set: Optional[Union[None, str, Set[Gtid], List[Gtid], Gtid]] = None) -> None:
         """
         Construct a GtidSet; initial state depends on the nature of
         the `gtid_set` param.
@@ -325,21 +337,21 @@ def __init__(self, gtid_set):
         - Exception: if Gtid intervals are either malformed or overlapping
         """

-        def _to_gtid(element):
+        def _to_gtid(element: Union[Gtid, str]) -> Gtid:
             if isinstance(element, Gtid):
                 return element
             return Gtid(element.strip(' \n'))

         if not gtid_set:
-            self.gtids = []
+            self.gtids: List[Gtid] = []
         elif isinstance(gtid_set, (list, set)):
-            self.gtids = [_to_gtid(x) for x in gtid_set]
+            self.gtids: List[Gtid] = [_to_gtid(x) for x in gtid_set]
         else:
-            self.gtids = [Gtid(x.strip(' \n')) for x in gtid_set.split(',')]
+            self.gtids: List[Gtid] = [Gtid(x.strip(' \n')) for x in gtid_set.split(',')]

-    def merge_gtid(self, gtid):
+    def merge_gtid(self, gtid: Gtid) -> None:
         """Insert a Gtid in current GtidSet."""
-        new_gtids = []
+        new_gtids: List[Gtid] = []
         for existing in self.gtids:
             if existing.sid == gtid.sid:
                 new_gtids.append(existing + gtid)
@@ -349,7 +361,7 @@ def merge_gtid(self, gtid):
             new_gtids.append(gtid)
         self.gtids = new_gtids

-    def __contains__(self, other):
+    def __contains__(self, other: Union['GtidSet', Gtid]) -> bool:
         """
         Test if self contains other; other may be a GtidSet or a Gtid.

@@ -363,7 +375,7 @@ def __contains__(self, other):
             return any(other in x for x in self.gtids)
         raise NotImplementedError

-    def __add__(self, other):
+    def __add__(self, other: Union['GtidSet', Gtid]) -> 'GtidSet':
         """
         Merge current instance with another GtidSet or with a Gtid alone.

@@ -384,22 +396,23 @@ def __add__(self, other):

         raise NotImplementedError

-    def __str__(self):
+    def __str__(self) -> str:
         """
         Returns a comma separated string of gtids.
         """
         return ','.join(str(x) for x in self.gtids)

-    def __repr__(self):
+    def __repr__(self) -> str:
         return '<GtidSet %r>' % self.gtids

     @property
-    def encoded_length(self):
+    def encoded_length(self) -> int:
         return (8 +  # n_sids
                 sum(x.encoded_length for x in self.gtids))

-    def encoded(self):
-        """Encode a GtidSet in binary
+    def encoded(self) -> bytes:
+        """
+        Encode a GtidSet in binary
         Bytes are in **little endian**.
- `n_sid`: u64 is the number of Gtid to read @@ -421,8 +434,9 @@ def encoded(self): encode = encoded @classmethod - def decode(cls, payload): - """Decode a GtidSet from binary. + def decode(cls, payload: BytesIO) -> 'GtidSet': + """ + Decode a GtidSet from binary. :param BytesIO payload to decode """ @@ -432,5 +446,5 @@ def decode(cls, payload): return cls([Gtid.decode(payload) for _ in range(0, n_sid)]) - def __eq__(self, other): + def __eq__(self, other: 'GtidSet') -> bool: return self.gtids == other.gtids From 3b680fddb3740bd65f9809f4a942df70229b9f10 Mon Sep 17 00:00:00 2001 From: heehehe Date: Tue, 5 Sep 2023 10:32:29 +0900 Subject: [PATCH 2/7] feat: add exceptions.py typing Co-authored-by: starcat37 --- pymysqlreplication/exceptions.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pymysqlreplication/exceptions.py b/pymysqlreplication/exceptions.py index 434d8d76..d233a6a6 100644 --- a/pymysqlreplication/exceptions.py +++ b/pymysqlreplication/exceptions.py @@ -1,19 +1,19 @@ class TableMetadataUnavailableError(Exception): - def __init__(self, table): + def __init__(self, table: str) -> None: Exception.__init__(self,"Unable to find metadata for table {0}".format(table)) class BinLogNotEnabled(Exception): - def __init__(self): + def __init__(self) -> None: Exception.__init__(self, "MySQL binary logging is not enabled.") class StatusVariableMismatch(Exception): - def __init__(self): - Exception.__init__(self, " ".join( + def __init__(self) -> None: + Exception.__init__(self, " ".join([ "Unknown status variable in query event." , "Possible parse failure in preceding fields" , "or outdated constants.STATUS_VAR_KEY" , "Refer to MySQL documentation/source code" , "or create an issue on GitHub" - )) + ])) From ec20292cbc866dd1e94458aa5e8e8c737684ae4f Mon Sep 17 00:00:00 2001 From: heehehe Date: Tue, 5 Sep 2023 10:37:43 +0900 Subject: [PATCH 3/7] feat: add binlogstream.py typing Co-authored-by: starcat37 --- pymysqlreplication/binlogstream.py | 331 +++++++++++++++-------------- 1 file changed, 168 insertions(+), 163 deletions(-) diff --git a/pymysqlreplication/binlogstream.py b/pymysqlreplication/binlogstream.py index 71eb10a2..ef59803a 100644 --- a/pymysqlreplication/binlogstream.py +++ b/pymysqlreplication/binlogstream.py @@ -5,7 +5,8 @@ import pymysql from pymysql.constants.COMMAND import COM_BINLOG_DUMP, COM_REGISTER_SLAVE -from pymysql.cursors import DictCursor +from pymysql.cursors import Cursor, DictCursor +from pymysql.connections import Connection, MysqlPacket from .constants.BINLOG import TABLE_MAP_EVENT, ROTATE_EVENT, FORMAT_DESCRIPTION_EVENT from .event import ( @@ -21,6 +22,7 @@ from .packet import BinLogPacketWrapper from .row_event import ( UpdateRowsEvent, WriteRowsEvent, DeleteRowsEvent, TableMapEvent) +from typing import ByteString, Union, Optional, List, Tuple, Dict, Any, Iterator, FrozenSet, Type try: from pymysql.constants.COMMAND import COM_BINLOG_DUMP_GTID @@ -35,28 +37,30 @@ class ReportSlave(object): - """Represent the values that you may report when connecting as a slave - to a master. SHOW SLAVE HOSTS related""" - - hostname = '' - username = '' - password = '' - port = 0 + """ + Represent the values that you may report + when connecting as a slave to a master. SHOW SLAVE HOSTS related. 
+ """ - def __init__(self, value): + def __init__(self, value: Union[str, Tuple[str, str, str, int], Dict[str, Union[str, int]]]) -> None: """ Attributes: - value: string or tuple + value: string, tuple or dict if string, then it will be used hostname if tuple it will be used as (hostname, user, password, port) + if dict, keys 'hostname', 'username', 'password', 'port' will be used. """ + self.hostname: str = '' + self.username: str = '' + self.password: str = '' + self.port: int = 0 if isinstance(value, (tuple, list)): try: - self.hostname = value[0] - self.username = value[1] - self.password = value[2] - self.port = int(value[3]) + self.hostname: str = value[0] + self.username: str = value[1] + self.password: str = value[2] + self.port: int = int(value[3]) except IndexError: pass elif isinstance(value, dict): @@ -66,17 +70,17 @@ def __init__(self, value): except KeyError: pass else: - self.hostname = value + self.hostname: Union[str, tuple] = value - def __repr__(self): + def __repr__(self) -> str: return '' % \ (self.hostname, self.username, self.password, self.port) - def encoded(self, server_id, master_id=0): + def encoded(self, server_id: int, master_id: int = 0) -> ByteString: """ - server_id: the slave server-id - master_id: usually 0. Appears as "master id" in SHOW SLAVE HOSTS - on the master. Unknown what else it impacts. + :ivar server_id: int - the slave server-id + :ivar master_id: int - usually 0. Appears as "master id" in SHOW SLAVE HOSTS on the master. + Unknown what else it impacts. """ # 1 [15] COM_REGISTER_SLAVE @@ -91,23 +95,23 @@ def encoded(self, server_id, master_id=0): # 4 replication rank # 4 master-id - lhostname = len(self.hostname.encode()) - lusername = len(self.username.encode()) - lpassword = len(self.password.encode()) + lhostname: int = len(self.hostname.encode()) + lusername: int = len(self.username.encode()) + lpassword: int = len(self.password.encode()) - packet_len = (1 + # command - 4 + # server-id - 1 + # hostname length - lhostname + - 1 + # username length - lusername + - 1 + # password length - lpassword + - 2 + # slave mysql port - 4 + # replication rank - 4) # master-id + packet_len: int = (1 + # command + 4 + # server-id + 1 + # hostname length + lhostname + + 1 + # username length + lusername + + 1 + # password length + lpassword + + 2 + # slave mysql port + 4 + # replication rank + 4) # master-id - MAX_STRING_LEN = 257 # one byte for length + 256 chars + MAX_STRING_LEN: int = 257 # one byte for length + 256 chars return (struct.pack(' None: """ Attributes: - ctl_connection_settings: Connection settings for cluster holding + ctl_connection_settings[Dict]: Connection settings for cluster holding schema information - resume_stream: Start for event from position or the latest event of + resume_stream[bool]: Start for event from position or the latest event of binlog or from older available event - blocking: When master has finished reading/sending binlog it will + blocking[bool]: When master has finished reading/sending binlog it will send EOF instead of blocking connection. 
- only_events: Array of allowed events - ignored_events: Array of ignored events - log_file: Set replication start log file - log_pos: Set replication start log pos (resume_stream should be + only_events[List[str]]: Array of allowed events + ignored_events[List[str]]: Array of ignored events + log_file[str]: Set replication start log file + log_pos[int]: Set replication start log pos (resume_stream should be true) - end_log_pos: Set replication end log pos - auto_position: Use master_auto_position gtid to set position - only_tables: An array with the tables you want to watch (only works + end_log_pos[int]: Set replication end log pos + auto_position[str]: Use master_auto_position gtid to set position + only_tables[List[str]]: An array with the tables you want to watch (only works in binlog_format ROW) - ignored_tables: An array with the tables you want to skip - only_schemas: An array with the schemas you want to watch - ignored_schemas: An array with the schemas you want to skip - freeze_schema: If true do not support ALTER TABLE. It's faster. - skip_to_timestamp: Ignore all events until reaching specified - timestamp. - report_slave: Report slave in SHOW SLAVE HOSTS. - slave_uuid: Report slave_uuid or replica_uuid in SHOW SLAVE HOSTS(MySQL 8.0.21-) or + ignored_tables[List[str]]: An array with the tables you want to skip + only_schemas[List[str]]: An array with the schemas you want to watch + ignored_schemas[List[str]]: An array with the schemas you want to skip + freeze_schema[bool]: If true do not support ALTER TABLE. It's faster. + skip_to_timestamp[float]: Ignore all events until reaching specified timestamp. + report_slave[ReportSlave]: Report slave in SHOW SLAVE HOSTS. + slave_uuid[str]: Report slave_uuid or replica_uuid in SHOW SLAVE HOSTS(MySQL 8.0.21-) or SHOW REPLICAS(MySQL 8.0.22+) depends on your MySQL version. - fail_on_table_metadata_unavailable: Should raise exception if we - can't get table information on - row_events - slave_heartbeat: (seconds) Should master actively send heartbeat on + fail_on_table_metadata_unavailable[bool]: Should raise exception if we + can't get table information on row_events + slave_heartbeat[float]: (seconds) Should master actively send heartbeat on connection. This also reduces traffic in GTID replication on replication resumption (in case many event to skip in binlog). See MASTER_HEARTBEAT_PERIOD in mysql documentation for semantics - is_mariadb: Flag to indicate it's a MariaDB server, used with auto_position + is_mariadb[bool]: Flag to indicate it's a MariaDB server, used with auto_position to point to Mariadb specific GTID. - annotate_rows_event: Parameter value to enable annotate rows event in mariadb, + annotate_rows_event[bool]: Parameter value to enable annotate rows event in mariadb, used with 'is_mariadb' - ignore_decode_errors: If true, any decode errors encountered + ignore_decode_errors[bool]: If true, any decode errors encountered when reading column data will be ignored. - verify_checksum: If true, verify events read from the binary log by examining checksums. + verify_checksum[bool]: If true, verify events read from the binary log by examining checksums. 
""" - self.__connection_settings = connection_settings + self.__connection_settings: Dict = connection_settings self.__connection_settings.setdefault("charset", "utf8") - self.__connected_stream = False - self.__connected_ctl = False - self.__resume_stream = resume_stream - self.__blocking = blocking - self._ctl_connection_settings = ctl_connection_settings + self.__connected_stream: bool = False + self.__connected_ctl: bool = False + self.__resume_stream: bool = resume_stream + self.__blocking: bool = blocking + self._ctl_connection_settings: Dict = ctl_connection_settings if ctl_connection_settings: self._ctl_connection_settings.setdefault("charset", "utf8") - self.__only_tables = only_tables - self.__ignored_tables = ignored_tables - self.__only_schemas = only_schemas - self.__ignored_schemas = ignored_schemas - self.__freeze_schema = freeze_schema - self.__allowed_events = self._allowed_event_list( + self.__only_tables: Optional[List[str]] = only_tables + self.__ignored_tables: Optional[List[str]] = ignored_tables + self.__only_schemas: Optional[List[str]] = only_schemas + self.__ignored_schemas: Optional[List[str]] = ignored_schemas + self.__freeze_schema: bool = freeze_schema + self.__allowed_events: FrozenSet[str] = self._allowed_event_list( only_events, ignored_events, filter_non_implemented_events) - self.__fail_on_table_metadata_unavailable = fail_on_table_metadata_unavailable - self.__ignore_decode_errors = ignore_decode_errors - self.__verify_checksum = verify_checksum + self.__fail_on_table_metadata_unavailable: bool = fail_on_table_metadata_unavailable + self.__ignore_decode_errors: bool = ignore_decode_errors + self.__verify_checksum: bool = verify_checksum # We can't filter on packet level TABLE_MAP and rotate event because # we need them for handling other operations - self.__allowed_events_in_packet = frozenset( + self.__allowed_events_in_packet: FrozenSet[str] = frozenset( [TableMapEvent, RotateEvent]).union(self.__allowed_events) - self.__server_id = server_id - self.__use_checksum = False + self.__server_id: int = server_id + self.__use_checksum: bool = False # Store table meta information - self.table_map = {} - self.log_pos = log_pos - self.end_log_pos = end_log_pos - self.log_file = log_file - self.auto_position = auto_position - self.skip_to_timestamp = skip_to_timestamp - self.is_mariadb = is_mariadb - self.__annotate_rows_event = annotate_rows_event + self.table_map: Dict = {} + self.log_pos: Optional[int] = log_pos + self.end_log_pos: Optional[int] = end_log_pos + self.log_file: Optional[str] = log_file + self.auto_position: Optional[str] = auto_position + self.skip_to_timestamp: Optional[float] = skip_to_timestamp + self.is_mariadb: bool = is_mariadb + self.__annotate_rows_event: bool = annotate_rows_event if end_log_pos: - self.is_past_end_log_pos = False + self.is_past_end_log_pos: bool = False if report_slave: - self.report_slave = ReportSlave(report_slave) - self.slave_uuid = slave_uuid - self.slave_heartbeat = slave_heartbeat + self.report_slave: ReportSlave = ReportSlave(report_slave) + self.slave_uuid: Optional[str] = slave_uuid + self.slave_heartbeat: Optional[float] = slave_heartbeat if pymysql_wrapper: - self.pymysql_wrapper = pymysql_wrapper + self.pymysql_wrapper: Connection = pymysql_wrapper else: - self.pymysql_wrapper = pymysql.connect - self.mysql_version = (0, 0, 0) + self.pymysql_wrapper: Optional[Union[Connection, Type[Connection]]] = pymysql.connect + self.mysql_version: Tuple = (0, 0, 0) - def close(self): + def close(self) -> None: if 
self.__connected_stream: self._stream_connection.close() - self.__connected_stream = False + self.__connected_stream: bool = False if self.__connected_ctl: # break reference cycle between stream reader and underlying # mysql connection object self._ctl_connection._get_table_information = None self._ctl_connection.close() - self.__connected_ctl = False + self.__connected_ctl: bool = False - def __connect_to_ctl(self): + def __connect_to_ctl(self) -> None: if not self._ctl_connection_settings: - self._ctl_connection_settings = dict(self.__connection_settings) + self._ctl_connection_settings: Dict[str, Any] = dict(self.__connection_settings) self._ctl_connection_settings["db"] = "information_schema" self._ctl_connection_settings["cursorclass"] = DictCursor self._ctl_connection_settings["autocommit"] = True - self._ctl_connection = self.pymysql_wrapper(**self._ctl_connection_settings) + self._ctl_connection: Connection = self.pymysql_wrapper(**self._ctl_connection_settings) self._ctl_connection._get_table_information = self.__get_table_information - self.__connected_ctl = True + self.__connected_ctl: bool = True - def __checksum_enabled(self): - """Return True if binlog-checksum = CRC32. Only for MySQL > 5.6""" - cur = self._stream_connection.cursor() + def __checksum_enabled(self) -> bool: + """ + Return True if binlog-checksum = CRC32. Only for MySQL > 5.6 + """ + cur: Cursor = self._stream_connection.cursor() cur.execute("SHOW GLOBAL VARIABLES LIKE 'BINLOG_CHECKSUM'") - result = cur.fetchone() + result: Optional[Tuple[str, str]] = cur.fetchone() cur.close() if result is None: @@ -278,11 +283,11 @@ def __checksum_enabled(self): return False return True - def _register_slave(self): + def _register_slave(self) -> None: if not self.report_slave: return - packet = self.report_slave.encoded(self.__server_id) + packet: bytes = self.report_slave.encoded(self.__server_id) if pymysql.__version__ < LooseVersion("0.6"): self._stream_connection.wfile.write(packet) @@ -293,40 +298,40 @@ def _register_slave(self): self._stream_connection._next_seq_id = 1 self._stream_connection._read_packet() - def __connect_to_stream(self): + def __connect_to_stream(self) -> None: # log_pos (4) -- position in the binlog-file to start the stream with # flags (2) BINLOG_DUMP_NON_BLOCK (0 or 1) # server_id (4) -- server id of this slave # log_file (string.EOF) -- filename of the binlog on the master - self._stream_connection = self.pymysql_wrapper(**self.__connection_settings) + self._stream_connection: Connection = self.pymysql_wrapper(**self.__connection_settings) - self.__use_checksum = self.__checksum_enabled() + self.__use_checksum: bool = self.__checksum_enabled() # If checksum is enabled we need to inform the server about the that # we support it if self.__use_checksum: - cur = self._stream_connection.cursor() + cur: Cursor = self._stream_connection.cursor() cur.execute("SET @master_binlog_checksum= @@global.binlog_checksum") cur.close() if self.slave_uuid: - cur = self._stream_connection.cursor() + cur: Cursor = self._stream_connection.cursor() cur.execute("SET @slave_uuid = %s, @replica_uuid = %s", (self.slave_uuid, self.slave_uuid)) cur.close() if self.slave_heartbeat: # 4294967 is documented as the max value for heartbeats - net_timeout = float(self.__connection_settings.get('read_timeout', + net_timeout: float = float(self.__connection_settings.get('read_timeout', 4294967)) # If heartbeat is too low, the connection will disconnect before, # this is also the behavior in mysql - heartbeat = float(min(net_timeout / 
2., self.slave_heartbeat)) + heartbeat: float = float(min(net_timeout / 2., self.slave_heartbeat)) if heartbeat > 4294967: heartbeat = 4294967 # master_heartbeat_period is nanoseconds - heartbeat = int(heartbeat * 1000000000) - cur = self._stream_connection.cursor() + heartbeat: int = int(heartbeat * 1000000000) + cur: Cursor = self._stream_connection.cursor() cur.execute("SET @master_heartbeat_period= %d" % heartbeat) cur.close() @@ -334,7 +339,7 @@ def __connect_to_stream(self): # Mariadb, when it tries to replace GTID events with dummy ones. Given that this library understands GTID # events, setting the capability to 4 circumvents this error. # If the DB is mysql, this won't have any effect so no need to run this in a condition - cur = self._stream_connection.cursor() + cur: Cursor = self._stream_connection.cursor() cur.execute("SET @mariadb_slave_capability=4") cur.close() @@ -347,15 +352,15 @@ def __connect_to_stream(self): # only when log_file and log_pos both provided, the position info is # valid, if not, get the current position from master if self.log_file is None or self.log_pos is None: - cur = self._stream_connection.cursor() + cur: Cursor = self._stream_connection.cursor() cur.execute("SHOW MASTER STATUS") - master_status = cur.fetchone() + master_status: Optional[Tuple[str, int, Any]] = cur.fetchone() if master_status is None: raise BinLogNotEnabled() self.log_file, self.log_pos = master_status[:2] cur.close() - prelude = struct.pack(' bytes: # https://mariadb.com/kb/en/5-slave-registration/ - cur = self._stream_connection.cursor() + cur: Cursor = self._stream_connection.cursor() if self.auto_position != None: cur.execute("SET @slave_connect_state='%s'" % self.auto_position) cur.execute("SET @slave_gtid_strict_mode=1") @@ -465,19 +470,19 @@ def __set_mariadb_settings(self): cur.close() # https://mariadb.com/kb/en/com_binlog_dump/ - header_size = ( + header_size: int = ( 4 + # binlog pos 2 + # binlog flags 4 + # slave server_id, 4 # requested binlog file name , set it to empty ) - prelude = struct.pack(' Union[BinLogPacketWrapper, None]: while True: if self.end_log_pos and self.is_past_end_log_pos: return None @@ -510,9 +515,9 @@ def fetchone(self): try: if pymysql.__version__ < LooseVersion("0.6"): - pkt = self._stream_connection.read_packet() + pkt: MysqlPacket = self._stream_connection.read_packet() else: - pkt = self._stream_connection._read_packet() + pkt: MysqlPacket = self._stream_connection._read_packet() except pymysql.OperationalError as error: code, message = error.args if code in MYSQL_EXPECTED_ERROR_CODES: @@ -528,7 +533,7 @@ def fetchone(self): if not pkt.is_ok_packet(): continue - binlog_event = BinLogPacketWrapper(pkt, self.table_map, + binlog_event: BinLogPacketWrapper = BinLogPacketWrapper(pkt, self.table_map, self._ctl_connection, self.mysql_version, self.__use_checksum, @@ -554,7 +559,7 @@ def fetchone(self): # invalidates all our cached table id to schema mappings. 
This means we have to load them all # again for each logfile which is potentially wasted effort but we can't really do much better # without being broken in restart case - self.table_map = {} + self.table_map: Dict = {} elif binlog_event.log_pos: self.log_pos = binlog_event.log_pos @@ -604,8 +609,8 @@ def fetchone(self): return binlog_event.event - def _allowed_event_list(self, only_events, ignored_events, - filter_non_implemented_events): + def _allowed_event_list(self, only_events: Optional[List[str]], ignored_events: Optional[List[str]], + filter_non_implemented_events: bool) -> FrozenSet[str]: if only_events is not None: events = set(only_events) else: @@ -645,13 +650,13 @@ def _allowed_event_list(self, only_events, ignored_events, pass return frozenset(events) - def __get_table_information(self, schema, table): + def __get_table_information(self, schema: str, table: str) -> List[Dict[str, Any]]: for i in range(1, 3): try: if not self.__connected_ctl: self.__connect_to_ctl() - cur = self._ctl_connection.cursor() + cur: Cursor = self._ctl_connection.cursor() cur.execute(""" SELECT COLUMN_NAME, COLLATION_NAME, CHARACTER_SET_NAME, @@ -662,7 +667,7 @@ def __get_table_information(self, schema, table): WHERE table_schema = %s AND table_name = %s """, (schema, table)) - result = sorted(cur.fetchall(), key=lambda x: x['ORDINAL_POSITION']) + result: List = sorted(cur.fetchall(), key=lambda x: x['ORDINAL_POSITION']) cur.close() return result @@ -674,5 +679,5 @@ def __get_table_information(self, schema, table): else: raise error - def __iter__(self): + def __iter__(self) -> Iterator[Union[BinLogPacketWrapper, None]]: return iter(self.fetchone, None) From b26525949f28a716844f7b5d6c2a1b72eaca5062 Mon Sep 17 00:00:00 2001 From: Heeseon Cheon Date: Tue, 5 Sep 2023 10:51:37 +0900 Subject: [PATCH 4/7] feat: add packet.py typing Co-authored-by: sean-k1 --- pymysqlreplication/packet.py | 255 ++++++++++++++++++----------------- 1 file changed, 133 insertions(+), 122 deletions(-) diff --git a/pymysqlreplication/packet.py b/pymysqlreplication/packet.py index 665caebe..ecc38607 100644 --- a/pymysqlreplication/packet.py +++ b/pymysqlreplication/packet.py @@ -4,6 +4,9 @@ from pymysqlreplication import constants, event, row_event +from typing import List, Tuple, Dict, Optional, Union, FrozenSet, Type +from pymysql.connections import MysqlPacket, Connection + # Constants from PyMYSQL source code NULL_COLUMN = 251 UNSIGNED_CHAR_COLUMN = 251 @@ -15,7 +18,6 @@ UNSIGNED_INT24_LENGTH = 3 UNSIGNED_INT64_LENGTH = 8 - JSONB_TYPE_SMALL_OBJECT = 0x0 JSONB_TYPE_LARGE_OBJECT = 0x1 JSONB_TYPE_SMALL_ARRAY = 0x2 @@ -36,24 +38,10 @@ JSONB_LITERAL_FALSE = 0x2 -def read_offset_or_inline(packet, large): - t = packet.read_uint8() - - if t in (JSONB_TYPE_LITERAL, - JSONB_TYPE_INT16, JSONB_TYPE_UINT16): - return (t, None, packet.read_binary_json_type_inlined(t, large)) - if large and t in (JSONB_TYPE_INT32, JSONB_TYPE_UINT32): - return (t, None, packet.read_binary_json_type_inlined(t, large)) - - if large: - return (t, packet.read_uint32(), None) - return (t, packet.read_uint16(), None) - - class BinLogPacketWrapper(object): """ - Bin Log Packet Wrapper. It uses an existing packet object, and wraps - around it, exposing useful variables while still providing access + Bin Log Packet Wrapper uses an existing packet object and wraps around it, + exposing useful variables while still providing access to the original packet objects variables and methods. 
""" @@ -83,7 +71,7 @@ class BinLogPacketWrapper(object): constants.DELETE_ROWS_EVENT_V2: row_event.DeleteRowsEvent, constants.TABLE_MAP_EVENT: row_event.TableMapEvent, - #5.6 GTID enabled replication events + # 5.6 GTID enabled replication events constants.ANONYMOUS_GTID_LOG_EVENT: event.NotImplementedEvent, # MariaDB GTID constants.MARIADB_ANNOTATE_ROWS_EVENT: event.MariadbAnnotateRowsEvent, @@ -93,26 +81,28 @@ class BinLogPacketWrapper(object): constants.MARIADB_START_ENCRYPTION_EVENT: event.MariadbStartEncryptionEvent } - def __init__(self, from_packet, table_map, - ctl_connection, - mysql_version, - use_checksum, - allowed_events, - only_tables, - ignored_tables, - only_schemas, - ignored_schemas, - freeze_schema, - fail_on_table_metadata_unavailable, - ignore_decode_errors, - verify_checksum,): + def __init__(self, + from_packet: MysqlPacket, + table_map: dict, + ctl_connection: Connection, + mysql_version: Tuple[int, int, int], + use_checksum: bool, + allowed_events: FrozenSet[Type[event.BinLogEvent]], + only_tables: Optional[List[str]], + ignored_tables: Optional[List[str]], + only_schemas: Optional[List[str]], + ignored_schemas: Optional[List[str]], + freeze_schema: bool, + fail_on_table_metadata_unavailable: bool, + ignore_decode_errors: bool, + verify_checksum: bool) -> None: # -1 because we ignore the ok byte self.read_bytes = 0 # Used when we want to override a value in the data buffer self.__data_buffer = b'' - self.packet = from_packet - self.charset = ctl_connection.charset + self.packet: MysqlPacket = from_packet + self.charset: str = ctl_connection.charset # OK value # timestamp @@ -123,13 +113,13 @@ def __init__(self, from_packet, table_map, unpack = struct.unpack(' bytes: size = int(size) self.read_bytes += size if len(self.__data_buffer) > 0: @@ -169,14 +163,15 @@ def read(self, size): return data + self.packet.read(size - len(data)) return self.packet.read(size) - def unread(self, data): - '''Push again data in data buffer. It's use when you want - to extract a bit from a value a let the rest of the code normally - read the datas''' + def unread(self, data: bytes) -> None: + """ + Push again data in data buffer. + Use to extract a bit from a value and ensure that the rest of the code reads data normally + """ self.read_bytes -= len(data) self.__data_buffer += data - def advance(self, size): + def advance(self, size: int) -> None: size = int(size) self.read_bytes += size buffer_len = len(self.__data_buffer) @@ -187,13 +182,11 @@ def advance(self, size): else: self.packet.advance(size) - def read_length_coded_binary(self): - """Read a 'Length Coded Binary' number from the data buffer. - + def read_length_coded_binary(self) -> Optional[int]: + """ + Read a 'Length Coded Binary' number from the data buffer. Length coded numbers can be anywhere from 1 to 9 bytes depending - on the value of the first byte. - - From PyMYSQL source code + on the value of the first byte. (From PyMYSQL source code) """ c = struct.unpack("!B", self.read(1))[0] if c == NULL_COLUMN: @@ -207,14 +200,12 @@ def read_length_coded_binary(self): elif c == UNSIGNED_INT64_COLUMN: return self.unpack_int64(self.read(UNSIGNED_INT64_LENGTH)) - def read_length_coded_string(self): - """Read a 'Length Coded String' from the data buffer. - - A 'Length Coded String' consists first of a length coded - (unsigned, positive) integer represented in 1-9 bytes followed by - that many bytes of binary data. (For example "cat" would be "3cat".) 
- - From PyMYSQL source code + def read_length_coded_string(self) -> Optional[str]: + """ + Read a 'Length Coded String' from the data buffer. + A 'Length Coded String' consists first of a length coded (unsigned, positive) integer + represented in 1-9 bytes followed by that many bytes of binary data. + (For example, "cat" would be "3cat". - From PyMYSQL source code) """ length = self.read_length_coded_binary() if length is None: @@ -228,8 +219,10 @@ def __getattr__(self, key): raise AttributeError("%s instance has no attribute '%s'" % (self.__class__, key)) - def read_int_be_by_size(self, size): - '''Read a big endian integer values based on byte number''' + def read_int_be_by_size(self, size: int) -> int: + """ + Read a big endian integer values based on byte number + """ if size == 1: return struct.unpack('>b', self.read(size))[0] elif size == 2: @@ -243,8 +236,10 @@ def read_int_be_by_size(self, size): elif size == 8: return struct.unpack('>l', self.read(size))[0] - def read_uint_by_size(self, size): - '''Read a little endian integer values based on byte number''' + def read_uint_by_size(self, size: int) -> int: + """ + Read a little endian integer values based on byte number + """ if size == 1: return self.read_uint8() elif size == 2: @@ -262,19 +257,18 @@ def read_uint_by_size(self, size): elif size == 8: return self.read_uint64() - def read_length_coded_pascal_string(self, size): - """Read a string with length coded using pascal style. + def read_length_coded_pascal_string(self, size: int) -> bytes: + """ + Read a string with length coded using pascal style. The string start by the size of the string """ length = self.read_uint_by_size(size) return self.read(length) - def read_variable_length_string(self): - """Read a variable length string where the first 1-5 bytes stores the - length of the string. - - For each byte, the first bit being high indicates another byte must be - read. + def read_variable_length_string(self) -> bytes: + """ + Read a variable length string where the first 1-5 bytes stores the length of the string. + For each byte, the first bit being high indicates another byte must be read. 
""" byte = 0x80 length = 0 @@ -285,82 +279,82 @@ def read_variable_length_string(self): bits_read = bits_read + 7 return self.read(length) - def read_int24(self): + def read_int24(self) -> int: a, b, c = struct.unpack("BBB", self.read(3)) res = a | (b << 8) | (c << 16) if res >= 0x800000: res -= 0x1000000 return res - def read_int24_be(self): + def read_int24_be(self) -> int: a, b, c = struct.unpack('BBB', self.read(3)) res = (a << 16) | (b << 8) | c if res >= 0x800000: res -= 0x1000000 return res - def read_uint8(self): + def read_uint8(self) -> int: return struct.unpack(' int: return struct.unpack(' int: return struct.unpack(' int: a, b, c = struct.unpack(" int: return struct.unpack(' int: return struct.unpack(' int: a, b = struct.unpack(" int: a, b = struct.unpack(">IB", self.read(5)) return b + (a << 8) - def read_uint48(self): + def read_uint48(self) -> int: a, b, c = struct.unpack(" int: a, b, c = struct.unpack(" int: return struct.unpack(' int: return struct.unpack(' int: return struct.unpack(' Optional[Union[int, Tuple[str, int]]]: try: - return struct.unpack('B', n[0])[0] \ - + (struct.unpack('B', n[1])[0] << 8) \ - + (struct.unpack('B', n[2])[0] << 16) + return struct.unpack('B', n[0:1])[0] \ + + (struct.unpack('B', n[1:2])[0] << 8) \ + + (struct.unpack('B', n[2:3])[0] << 16) except TypeError: return n[0] + (n[1] << 8) + (n[2] << 16) - def unpack_int32(self, n): + def unpack_int32(self, n: bytes) -> Optional[Union[int, Tuple[str, int]]]: try: - return struct.unpack('B', n[0])[0] \ - + (struct.unpack('B', n[1])[0] << 8) \ - + (struct.unpack('B', n[2])[0] << 16) \ - + (struct.unpack('B', n[3])[0] << 24) + return struct.unpack('B', n[0:1])[0] \ + + (struct.unpack('B', n[1:2])[0] << 8) \ + + (struct.unpack('B', n[2:3])[0] << 16) \ + + (struct.unpack('B', n[3:4])[0] << 24) except TypeError: return n[0] + (n[1] << 8) + (n[2] << 16) + (n[3] << 24) - def read_binary_json(self, size): + def read_binary_json(self, size: int) -> Optional[str]: length = self.read_uint_by_size(size) if length == 0: # handle NULL value @@ -371,7 +365,10 @@ def read_binary_json(self, size): return self.read_binary_json_type(t, length) - def read_binary_json_type(self, t, length): + def read_binary_json_type(self, t: int, length: int) \ + -> Optional[Union[ + Dict[bytes, Union[bool, str, None]], + List[int], bool, int, bytes]]: large = (t in (JSONB_TYPE_LARGE_OBJECT, JSONB_TYPE_LARGE_ARRAY)) if t in (JSONB_TYPE_SMALL_OBJECT, JSONB_TYPE_LARGE_OBJECT): return self.read_binary_json_object(length - 1, large) @@ -404,7 +401,7 @@ def read_binary_json_type(self, t, length): raise ValueError('Json type %d is not handled' % t) - def read_binary_json_type_inlined(self, t, large): + def read_binary_json_type_inlined(self, t: int, large: bool) -> Optional[Union[bool, int]]: if t == JSONB_TYPE_LITERAL: value = self.read_uint32() if large else self.read_uint16() if value == JSONB_LITERAL_NULL: @@ -424,7 +421,8 @@ def read_binary_json_type_inlined(self, t, large): raise ValueError('Json type %d is not handled' % t) - def read_binary_json_object(self, length, large): + def read_binary_json_object(self, length: int, large: bool) \ + -> Dict[bytes, Union[bool, str, None]]: if large: elements = self.read_uint32() size = self.read_uint32() @@ -438,13 +436,13 @@ def read_binary_json_object(self, length, large): if large: key_offset_lengths = [( self.read_uint32(), # offset (we don't actually need that) - self.read_uint16() # size of the key - ) for _ in range(elements)] + self.read_uint16() # size of the key + ) for _ in 
range(elements)] else: key_offset_lengths = [( self.read_uint16(), # offset (we don't actually need that) - self.read_uint16() # size of key - ) for _ in range(elements)] + self.read_uint16() # size of key + ) for _ in range(elements)] value_type_inlined_lengths = [read_offset_or_inline(self, large) for _ in range(elements)] @@ -462,7 +460,7 @@ def read_binary_json_object(self, length, large): return out - def read_binary_json_array(self, length, large): + def read_binary_json_array(self, length: int, large: bool) -> List[int]: if large: elements = self.read_uint32() size = self.read_uint32() @@ -477,20 +475,18 @@ def read_binary_json_array(self, length, large): read_offset_or_inline(self, large) for _ in range(elements)] - def _read(x): + def _read(x: Tuple[int, Optional[bytes], Optional[Union[bool, int]]]) -> int: if x[1] is None: return x[2] return self.read_binary_json_type(x[0], length) return [_read(x) for x in values_type_offset_inline] - def read_string(self): - """Read a 'Length Coded String' from the data buffer. - + def read_string(self) -> bytes: + """ + Read a 'Length Coded String' from the data buffer. Read __data_buffer until NULL character (0 = \0 = \x00) - - Returns: - Binary string parsed from __data_buffer + :return string: Binary string parsed from __data_buffer """ string = b'' while True: @@ -500,3 +496,18 @@ def read_string(self): string += char return string + + +def read_offset_or_inline(packet: Union[MysqlPacket, BinLogPacketWrapper], large: bool) \ + -> Tuple[int, Optional[bytes], Optional[Union[bool, int]]]: + t = packet.read_uint8() + + if t in (JSONB_TYPE_LITERAL, + JSONB_TYPE_INT16, JSONB_TYPE_UINT16): + return t, None, packet.read_binary_json_type_inlined(t, large) + if large and t in (JSONB_TYPE_INT32, JSONB_TYPE_UINT32): + return t, None, packet.read_binary_json_type_inlined(t, large) + + if large: + return t, packet.read_uint32(), None + return t, packet.read_uint16(), None From 8db60c07d0111d333ca5d2f042aee1a8ba8c0440 Mon Sep 17 00:00:00 2001 From: starcat37 Date: Wed, 6 Sep 2023 22:10:57 +0900 Subject: [PATCH 5/7] fix: modify binlogstream.py typing using mypy --- pymysqlreplication/binlogstream.py | 104 ++++++++++++----------------- 1 file changed, 42 insertions(+), 62 deletions(-) diff --git a/pymysqlreplication/binlogstream.py b/pymysqlreplication/binlogstream.py index ef59803a..19c47ea8 100644 --- a/pymysqlreplication/binlogstream.py +++ b/pymysqlreplication/binlogstream.py @@ -10,7 +10,7 @@ from .constants.BINLOG import TABLE_MAP_EVENT, ROTATE_EVENT, FORMAT_DESCRIPTION_EVENT from .event import ( - QueryEvent, RotateEvent, FormatDescriptionEvent, + BinLogEvent, QueryEvent, RotateEvent, FormatDescriptionEvent, XidEvent, GtidEvent, StopEvent, XAPrepareEvent, BeginLoadQueryEvent, ExecuteLoadQueryEvent, HeartbeatLogEvent, NotImplementedEvent, MariadbGtidEvent, @@ -22,7 +22,7 @@ from .packet import BinLogPacketWrapper from .row_event import ( UpdateRowsEvent, WriteRowsEvent, DeleteRowsEvent, TableMapEvent) -from typing import ByteString, Union, Optional, List, Tuple, Dict, Any, Iterator, FrozenSet, Type +from typing import ByteString, Union, Optional, List, Tuple, Dict, Any, Iterator, FrozenSet, Type, Set, Iterable try: from pymysql.constants.COMMAND import COM_BINLOG_DUMP_GTID @@ -57,10 +57,10 @@ def __init__(self, value: Union[str, Tuple[str, str, str, int], Dict[str, Union[ if isinstance(value, (tuple, list)): try: - self.hostname: str = value[0] - self.username: str = value[1] - self.password: str = value[2] - self.port: int = int(value[3]) + 
self.hostname = value[0] + self.username = value[1] + self.password = value[2] + self.port = int(value[3]) except IndexError: pass elif isinstance(value, dict): @@ -70,7 +70,7 @@ def __init__(self, value: Union[str, Tuple[str, str, str, int], Dict[str, Union[ except KeyError: pass else: - self.hostname: Union[str, tuple] = value + self.hostname = value def __repr__(self) -> str: return '' % \ @@ -172,7 +172,7 @@ def __init__(self, connection_settings: Dict, server_id: int, ignored_schemas[List[str]]: An array with the schemas you want to skip freeze_schema[bool]: If true do not support ALTER TABLE. It's faster. skip_to_timestamp[float]: Ignore all events until reaching specified timestamp. - report_slave[ReportSlave]: Report slave in SHOW SLAVE HOSTS. + report_slave[Union[str, Tuple[str, str, str, int]]]: Report slave in SHOW SLAVE HOSTS. slave_uuid[str]: Report slave_uuid or replica_uuid in SHOW SLAVE HOSTS(MySQL 8.0.21-) or SHOW REPLICAS(MySQL 8.0.22+) depends on your MySQL version. fail_on_table_metadata_unavailable[bool]: Should raise exception if we @@ -199,7 +199,7 @@ def __init__(self, connection_settings: Dict, server_id: int, self.__connected_ctl: bool = False self.__resume_stream: bool = resume_stream self.__blocking: bool = blocking - self._ctl_connection_settings: Dict = ctl_connection_settings + self._ctl_connection_settings: Optional[Dict] = ctl_connection_settings if ctl_connection_settings: self._ctl_connection_settings.setdefault("charset", "utf8") @@ -208,7 +208,7 @@ def __init__(self, connection_settings: Dict, server_id: int, self.__only_schemas: Optional[List[str]] = only_schemas self.__ignored_schemas: Optional[List[str]] = ignored_schemas self.__freeze_schema: bool = freeze_schema - self.__allowed_events: FrozenSet[str] = self._allowed_event_list( + self.__allowed_events: FrozenSet[type[BinLogEvent]] = self._allowed_event_list( only_events, ignored_events, filter_non_implemented_events) self.__fail_on_table_metadata_unavailable: bool = fail_on_table_metadata_unavailable self.__ignore_decode_errors: bool = ignore_decode_errors @@ -216,7 +216,7 @@ def __init__(self, connection_settings: Dict, server_id: int, # We can't filter on packet level TABLE_MAP and rotate event because # we need them for handling other operations - self.__allowed_events_in_packet: FrozenSet[str] = frozenset( + self.__allowed_events_in_packet: FrozenSet[type[BinLogEvent]] = frozenset( [TableMapEvent, RotateEvent]).union(self.__allowed_events) self.__server_id: int = server_id @@ -244,28 +244,28 @@ def __init__(self, connection_settings: Dict, server_id: int, self.pymysql_wrapper: Connection = pymysql_wrapper else: self.pymysql_wrapper: Optional[Union[Connection, Type[Connection]]] = pymysql.connect - self.mysql_version: Tuple = (0, 0, 0) + self.mysql_version: Tuple[int, int, int] = (0, 0, 0) def close(self) -> None: if self.__connected_stream: self._stream_connection.close() - self.__connected_stream: bool = False + self.__connected_stream = False if self.__connected_ctl: # break reference cycle between stream reader and underlying # mysql connection object self._ctl_connection._get_table_information = None self._ctl_connection.close() - self.__connected_ctl: bool = False + self.__connected_ctl = False def __connect_to_ctl(self) -> None: if not self._ctl_connection_settings: - self._ctl_connection_settings: Dict[str, Any] = dict(self.__connection_settings) + self._ctl_connection_settings: Dict = dict(self.__connection_settings) self._ctl_connection_settings["db"] = "information_schema" 
self._ctl_connection_settings["cursorclass"] = DictCursor self._ctl_connection_settings["autocommit"] = True self._ctl_connection: Connection = self.pymysql_wrapper(**self._ctl_connection_settings) self._ctl_connection._get_table_information = self.__get_table_information - self.__connected_ctl: bool = True + self.__connected_ctl = True def __checksum_enabled(self) -> bool: """ @@ -273,7 +273,7 @@ def __checksum_enabled(self) -> bool: """ cur: Cursor = self._stream_connection.cursor() cur.execute("SHOW GLOBAL VARIABLES LIKE 'BINLOG_CHECKSUM'") - result: Optional[Tuple[str, str]] = cur.fetchone() + result: Optional[Tuple[Any]] = cur.fetchone() cur.close() if result is None: @@ -305,7 +305,7 @@ def __connect_to_stream(self) -> None: # log_file (string.EOF) -- filename of the binlog on the master self._stream_connection: Connection = self.pymysql_wrapper(**self.__connection_settings) - self.__use_checksum: bool = self.__checksum_enabled() + self.__use_checksum = self.__checksum_enabled() # If checksum is enabled we need to inform the server about the that # we support it @@ -315,23 +315,23 @@ def __connect_to_stream(self) -> None: cur.close() if self.slave_uuid: - cur: Cursor = self._stream_connection.cursor() + cur = self._stream_connection.cursor() cur.execute("SET @slave_uuid = %s, @replica_uuid = %s", (self.slave_uuid, self.slave_uuid)) cur.close() if self.slave_heartbeat: # 4294967 is documented as the max value for heartbeats - net_timeout: float = float(self.__connection_settings.get('read_timeout', + net_timeout = float(self.__connection_settings.get('read_timeout', 4294967)) # If heartbeat is too low, the connection will disconnect before, # this is also the behavior in mysql - heartbeat: float = float(min(net_timeout / 2., self.slave_heartbeat)) + heartbeat = float(min(net_timeout / 2., self.slave_heartbeat)) if heartbeat > 4294967: heartbeat = 4294967 # master_heartbeat_period is nanoseconds - heartbeat: int = int(heartbeat * 1000000000) - cur: Cursor = self._stream_connection.cursor() + heartbeat = int(heartbeat * 1000000000) + cur = self._stream_connection.cursor() cur.execute("SET @master_heartbeat_period= %d" % heartbeat) cur.close() @@ -339,7 +339,7 @@ def __connect_to_stream(self) -> None: # Mariadb, when it tries to replace GTID events with dummy ones. Given that this library understands GTID # events, setting the capability to 4 circumvents this error. 
# If the DB is mysql, this won't have any effect so no need to run this in a condition - cur: Cursor = self._stream_connection.cursor() + cur = self._stream_connection.cursor() cur.execute("SET @mariadb_slave_capability=4") cur.close() @@ -347,14 +347,14 @@ def __connect_to_stream(self) -> None: if not self.auto_position: if self.is_mariadb: - prelude = self.__set_mariadb_settings() + prelude: bytes = self.__set_mariadb_settings() else: # only when log_file and log_pos both provided, the position info is # valid, if not, get the current position from master if self.log_file is None or self.log_pos is None: - cur: Cursor = self._stream_connection.cursor() + cur = self._stream_connection.cursor() cur.execute("SHOW MASTER STATUS") - master_status: Optional[Tuple[str, int, Any]] = cur.fetchone() + master_status: Optional[Tuple[Any]] = cur.fetchone() if master_status is None: raise BinLogNotEnabled() self.log_file, self.log_pos = master_status[:2] @@ -368,7 +368,7 @@ def __connect_to_stream(self) -> None: else: prelude += struct.pack(' None: prelude += self.log_file.encode() else: if self.is_mariadb: - prelude = self.__set_mariadb_settings() + prelude: bytes = self.__set_mariadb_settings() else: # Format for mysql packet master_auto_position # @@ -425,10 +425,10 @@ def __connect_to_stream(self) -> None: 8 + # binlog_pos_info_size 4) # encoded_data_size - prelude: ByteString = b'' + struct.pack(' None: else: self._stream_connection._write_bytes(prelude) self._stream_connection._next_seq_id = 1 - self.__connected_stream: bool = True + self.__connected_stream = True def __set_mariadb_settings(self) -> bytes: # https://mariadb.com/kb/en/5-slave-registration/ @@ -517,7 +517,7 @@ def fetchone(self) -> Union[BinLogPacketWrapper, None]: if pymysql.__version__ < LooseVersion("0.6"): pkt: MysqlPacket = self._stream_connection.read_packet() else: - pkt: MysqlPacket = self._stream_connection._read_packet() + pkt = self._stream_connection._read_packet() except pymysql.OperationalError as error: code, message = error.args if code in MYSQL_EXPECTED_ERROR_CODES: @@ -559,7 +559,7 @@ def fetchone(self) -> Union[BinLogPacketWrapper, None]: # invalidates all our cached table id to schema mappings. 
This means we have to load them all # again for each logfile which is potentially wasted effort but we can't really do much better # without being broken in restart case - self.table_map: Dict = {} + self.table_map = {} elif binlog_event.log_pos: self.log_pos = binlog_event.log_pos @@ -609,37 +609,17 @@ def fetchone(self) -> Union[BinLogPacketWrapper, None]: return binlog_event.event - def _allowed_event_list(self, only_events: Optional[List[str]], ignored_events: Optional[List[str]], - filter_non_implemented_events: bool) -> FrozenSet[str]: + def _allowed_event_list(self, only_events: Optional[Union[Set[type(BinLogEvent)], List[str]]], + ignored_events: Optional[List[str]], filter_non_implemented_events: bool) \ + -> FrozenSet[type[BinLogEvent]]: if only_events is not None: events = set(only_events) else: - events = set(( - QueryEvent, - RotateEvent, - StopEvent, - FormatDescriptionEvent, - XAPrepareEvent, - XidEvent, - GtidEvent, - BeginLoadQueryEvent, - ExecuteLoadQueryEvent, - UpdateRowsEvent, - WriteRowsEvent, - DeleteRowsEvent, - TableMapEvent, - HeartbeatLogEvent, - NotImplementedEvent, - MariadbGtidEvent, - RowsQueryLogEvent, - MariadbAnnotateRowsEvent, - RandEvent, - MariadbStartEncryptionEvent, - MariadbGtidListEvent, - MariadbBinLogCheckPointEvent, - UserVarEvent, - PreviousGtidsEvent - )) + events = {QueryEvent, RotateEvent, StopEvent, FormatDescriptionEvent, XAPrepareEvent, XidEvent, GtidEvent, + BeginLoadQueryEvent, ExecuteLoadQueryEvent, UpdateRowsEvent, WriteRowsEvent, DeleteRowsEvent, + TableMapEvent, HeartbeatLogEvent, NotImplementedEvent, MariadbGtidEvent, RowsQueryLogEvent, + MariadbAnnotateRowsEvent, RandEvent, MariadbStartEncryptionEvent, MariadbGtidListEvent, + MariadbBinLogCheckPointEvent, UserVarEvent, PreviousGtidsEvent} if ignored_events is not None: for e in ignored_events: events.remove(e) From 2ea58573c132975cacde918b859bc84ddc5026a8 Mon Sep 17 00:00:00 2001 From: starcat37 Date: Wed, 6 Sep 2023 22:15:14 +0900 Subject: [PATCH 6/7] fix: fix 'type' object is not subscriptable error --- pymysqlreplication/binlogstream.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pymysqlreplication/binlogstream.py b/pymysqlreplication/binlogstream.py index 19c47ea8..d5654e6b 100644 --- a/pymysqlreplication/binlogstream.py +++ b/pymysqlreplication/binlogstream.py @@ -208,7 +208,7 @@ def __init__(self, connection_settings: Dict, server_id: int, self.__only_schemas: Optional[List[str]] = only_schemas self.__ignored_schemas: Optional[List[str]] = ignored_schemas self.__freeze_schema: bool = freeze_schema - self.__allowed_events: FrozenSet[type[BinLogEvent]] = self._allowed_event_list( + self.__allowed_events: FrozenSet[Type[BinLogEvent]] = self._allowed_event_list( only_events, ignored_events, filter_non_implemented_events) self.__fail_on_table_metadata_unavailable: bool = fail_on_table_metadata_unavailable self.__ignore_decode_errors: bool = ignore_decode_errors @@ -216,7 +216,7 @@ def __init__(self, connection_settings: Dict, server_id: int, # We can't filter on packet level TABLE_MAP and rotate event because # we need them for handling other operations - self.__allowed_events_in_packet: FrozenSet[type[BinLogEvent]] = frozenset( + self.__allowed_events_in_packet: FrozenSet[Type[BinLogEvent]] = frozenset( [TableMapEvent, RotateEvent]).union(self.__allowed_events) self.__server_id: int = server_id @@ -609,9 +609,9 @@ def fetchone(self) -> Union[BinLogPacketWrapper, None]: return binlog_event.event - def _allowed_event_list(self, only_events: 
Optional[Union[Set[type(BinLogEvent)], List[str]]], + def _allowed_event_list(self, only_events: Optional[Union[Set[Type[BinLogEvent]], List[str]]], ignored_events: Optional[List[str]], filter_non_implemented_events: bool) \ - -> FrozenSet[type[BinLogEvent]]: + -> FrozenSet[Type[BinLogEvent]]: if only_events is not None: events = set(only_events) else: From 16b713f3e275d6afdfebd37ac8ae1a722bd7a1d1 Mon Sep 17 00:00:00 2001 From: starcat37 Date: Wed, 6 Sep 2023 22:38:34 +0900 Subject: [PATCH 7/7] fix: delete unnecessary typing --- pymysqlreplication/binlogstream.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymysqlreplication/binlogstream.py b/pymysqlreplication/binlogstream.py index d5654e6b..6d65a2cd 100644 --- a/pymysqlreplication/binlogstream.py +++ b/pymysqlreplication/binlogstream.py @@ -192,7 +192,7 @@ def __init__(self, connection_settings: Dict, server_id: int, verify_checksum[bool]: If true, verify events read from the binary log by examining checksums. """ - self.__connection_settings: Dict = connection_settings + self.__connection_settings = connection_settings self.__connection_settings.setdefault("charset", "utf8") self.__connected_stream: bool = False @@ -418,7 +418,7 @@ def __connect_to_stream(self) -> None: gtid_set: GtidSet = GtidSet(self.auto_position) encoded_data_size: int = gtid_set.encoded_length - header_size: int = (2 + # binlog_flags + header_size = (2 + # binlog_flags 4 + # server_id 4 + # binlog_name_info_size 4 + # empty binlog name
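
A quick end-to-end exercise of the typed gtid.py API from patch 1 is sketched
below. It is illustrative only and not part of the patch series: it assumes a
checkout with these patches applied, and the server UUID is a placeholder. It
shows the two behaviours the annotations describe - Gtid.__add__ merging
adjacent right-open intervals for the same sid, and a GtidSet surviving the
little-endian binary round trip through encoded()/decode():

    from io import BytesIO
    from pymysqlreplication.gtid import Gtid, GtidSet

    # Adjacent intervals on the same server UUID merge into a single range.
    gtid = Gtid("57b70f4e-20d3-11e5-a393-4a63946f7eac:1-56")
    gtid = gtid + Gtid("57b70f4e-20d3-11e5-a393-4a63946f7eac:57-59")
    print(gtid)  # 57b70f4e-20d3-11e5-a393-4a63946f7eac:1-59

    # The binary wire format decodes back to an equal GtidSet.
    gtid_set = GtidSet("57b70f4e-20d3-11e5-a393-4a63946f7eac:1-59")
    assert GtidSet.decode(BytesIO(gtid_set.encoded())) == gtid_set

Because __add__ is now annotated as taking a 'Gtid', a type checker reports a
call like gtid + "57b70f4e-...:60" at the call site instead of letting it
fail at runtime.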
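The annotations on BinLogStreamReader (patches 3, 5 and 6) surface the same
way at call sites. The consumer below is a hypothetical, minimal sketch: the
connection settings are placeholders, and only_events is given row event
classes, the usage that patch 6 models with Set[Type[BinLogEvent]]:

    from pymysqlreplication import BinLogStreamReader
    from pymysqlreplication.row_event import (
        DeleteRowsEvent,
        UpdateRowsEvent,
        WriteRowsEvent,
    )

    stream = BinLogStreamReader(
        connection_settings={"host": "127.0.0.1", "port": 3306,
                             "user": "repl", "passwd": "secret"},
        server_id=100,  # int: must be unique among replicas
        only_events=[DeleteRowsEvent, UpdateRowsEvent, WriteRowsEvent],
        blocking=False,  # return at EOF instead of waiting for new events
    )
    for binlogevent in stream:  # iterates fetchone() until it returns None
        binlogevent.dump()
    stream.close()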