diff --git a/pymysqlreplication/binlogstream.py b/pymysqlreplication/binlogstream.py index c153fcda..a2b7cf92 100644 --- a/pymysqlreplication/binlogstream.py +++ b/pymysqlreplication/binlogstream.py @@ -5,7 +5,8 @@ import pymysql from pymysql.constants.COMMAND import COM_BINLOG_DUMP, COM_REGISTER_SLAVE -from pymysql.cursors import DictCursor +from pymysql.cursors import Cursor, DictCursor +from pymysql.connections import Connection, MysqlPacket from .constants.BINLOG import TABLE_MAP_EVENT, ROTATE_EVENT, FORMAT_DESCRIPTION_EVENT from .event import ( @@ -20,6 +21,7 @@ from .packet import BinLogPacketWrapper from .row_event import ( UpdateRowsEvent, WriteRowsEvent, DeleteRowsEvent, TableMapEvent) +from typing import ByteString, Union, Optional, List, Tuple, Dict, Any, Iterator, FrozenSet, Type try: from pymysql.constants.COMMAND import COM_BINLOG_DUMP_GTID @@ -34,28 +36,30 @@ class ReportSlave(object): - """Represent the values that you may report when connecting as a slave - to a master. SHOW SLAVE HOSTS related""" - - hostname = '' - username = '' - password = '' - port = 0 + """ + Represent the values that you may report + when connecting as a slave to a master. SHOW SLAVE HOSTS related. + """ - def __init__(self, value): + def __init__(self, value: Union[str, Tuple[str, str, str, int], Dict[str, Union[str, int]]]) -> None: """ Attributes: - value: string or tuple + value: string, tuple or dict if string, then it will be used hostname if tuple it will be used as (hostname, user, password, port) + if dict, keys 'hostname', 'username', 'password', 'port' will be used. """ + self.hostname: str = '' + self.username: str = '' + self.password: str = '' + self.port: int = 0 if isinstance(value, (tuple, list)): try: - self.hostname = value[0] - self.username = value[1] - self.password = value[2] - self.port = int(value[3]) + self.hostname: str = value[0] + self.username: str = value[1] + self.password: str = value[2] + self.port: int = int(value[3]) except IndexError: pass elif isinstance(value, dict): @@ -65,17 +69,17 @@ def __init__(self, value): except KeyError: pass else: - self.hostname = value + self.hostname: Union[str, tuple] = value - def __repr__(self): + def __repr__(self) -> str: return '' % \ (self.hostname, self.username, self.password, self.port) - def encoded(self, server_id, master_id=0): + def encoded(self, server_id: int, master_id: int = 0) -> ByteString: """ - server_id: the slave server-id - master_id: usually 0. Appears as "master id" in SHOW SLAVE HOSTS - on the master. Unknown what else it impacts. + :ivar server_id: int - the slave server-id + :ivar master_id: int - usually 0. Appears as "master id" in SHOW SLAVE HOSTS on the master. + Unknown what else it impacts. """ # 1 [15] COM_REGISTER_SLAVE @@ -90,23 +94,23 @@ def encoded(self, server_id, master_id=0): # 4 replication rank # 4 master-id - lhostname = len(self.hostname.encode()) - lusername = len(self.username.encode()) - lpassword = len(self.password.encode()) + lhostname: int = len(self.hostname.encode()) + lusername: int = len(self.username.encode()) + lpassword: int = len(self.password.encode()) - packet_len = (1 + # command - 4 + # server-id - 1 + # hostname length - lhostname + - 1 + # username length - lusername + - 1 + # password length - lpassword + - 2 + # slave mysql port - 4 + # replication rank - 4) # master-id + packet_len: int = (1 + # command + 4 + # server-id + 1 + # hostname length + lhostname + + 1 + # username length + lusername + + 1 + # password length + lpassword + + 2 + # slave mysql port + 4 + # replication rank + 4) # master-id - MAX_STRING_LEN = 257 # one byte for length + 256 chars + MAX_STRING_LEN: int = 257 # one byte for length + 256 chars return (struct.pack(' None: """ Attributes: - ctl_connection_settings: Connection settings for cluster holding + ctl_connection_settings[Dict]: Connection settings for cluster holding schema information - resume_stream: Start for event from position or the latest event of + resume_stream[bool]: Start for event from position or the latest event of binlog or from older available event - blocking: When master has finished reading/sending binlog it will + blocking[bool]: When master has finished reading/sending binlog it will send EOF instead of blocking connection. - only_events: Array of allowed events - ignored_events: Array of ignored events - log_file: Set replication start log file - log_pos: Set replication start log pos (resume_stream should be + only_events[List[str]]: Array of allowed events + ignored_events[List[str]]: Array of ignored events + log_file[str]: Set replication start log file + log_pos[int]: Set replication start log pos (resume_stream should be true) - end_log_pos: Set replication end log pos - auto_position: Use master_auto_position gtid to set position - only_tables: An array with the tables you want to watch (only works + end_log_pos[int]: Set replication end log pos + auto_position[str]: Use master_auto_position gtid to set position + only_tables[List[str]]: An array with the tables you want to watch (only works in binlog_format ROW) - ignored_tables: An array with the tables you want to skip - only_schemas: An array with the schemas you want to watch - ignored_schemas: An array with the schemas you want to skip - freeze_schema: If true do not support ALTER TABLE. It's faster. - skip_to_timestamp: Ignore all events until reaching specified - timestamp. - report_slave: Report slave in SHOW SLAVE HOSTS. - slave_uuid: Report slave_uuid or replica_uuid in SHOW SLAVE HOSTS(MySQL 8.0.21-) or + ignored_tables[List[str]]: An array with the tables you want to skip + only_schemas[List[str]]: An array with the schemas you want to watch + ignored_schemas[List[str]]: An array with the schemas you want to skip + freeze_schema[bool]: If true do not support ALTER TABLE. It's faster. + skip_to_timestamp[float]: Ignore all events until reaching specified timestamp. + report_slave[ReportSlave]: Report slave in SHOW SLAVE HOSTS. + slave_uuid[str]: Report slave_uuid or replica_uuid in SHOW SLAVE HOSTS(MySQL 8.0.21-) or SHOW REPLICAS(MySQL 8.0.22+) depends on your MySQL version. - fail_on_table_metadata_unavailable: Should raise exception if we - can't get table information on - row_events - slave_heartbeat: (seconds) Should master actively send heartbeat on + fail_on_table_metadata_unavailable[bool]: Should raise exception if we + can't get table information on row_events + slave_heartbeat[float]: (seconds) Should master actively send heartbeat on connection. This also reduces traffic in GTID replication on replication resumption (in case many event to skip in binlog). See MASTER_HEARTBEAT_PERIOD in mysql documentation for semantics - is_mariadb: Flag to indicate it's a MariaDB server, used with auto_position + is_mariadb[bool]: Flag to indicate it's a MariaDB server, used with auto_position to point to Mariadb specific GTID. - annotate_rows_event: Parameter value to enable annotate rows event in mariadb, + annotate_rows_event[bool]: Parameter value to enable annotate rows event in mariadb, used with 'is_mariadb' - ignore_decode_errors: If true, any decode errors encountered + ignore_decode_errors[bool]: If true, any decode errors encountered when reading column data will be ignored. """ - self.__connection_settings = connection_settings + self.__connection_settings: Dict = connection_settings self.__connection_settings.setdefault("charset", "utf8") - self.__connected_stream = False - self.__connected_ctl = False - self.__resume_stream = resume_stream - self.__blocking = blocking - self._ctl_connection_settings = ctl_connection_settings + self.__connected_stream: bool = False + self.__connected_ctl: bool = False + self.__resume_stream: bool = resume_stream + self.__blocking: bool = blocking + self._ctl_connection_settings: Dict = ctl_connection_settings if ctl_connection_settings: self._ctl_connection_settings.setdefault("charset", "utf8") - self.__only_tables = only_tables - self.__ignored_tables = ignored_tables - self.__only_schemas = only_schemas - self.__ignored_schemas = ignored_schemas - self.__freeze_schema = freeze_schema - self.__allowed_events = self._allowed_event_list( + self.__only_tables: Optional[List[str]] = only_tables + self.__ignored_tables: Optional[List[str]] = ignored_tables + self.__only_schemas: Optional[List[str]] = only_schemas + self.__ignored_schemas: Optional[List[str]] = ignored_schemas + self.__freeze_schema: bool = freeze_schema + self.__allowed_events: FrozenSet[str] = self._allowed_event_list( only_events, ignored_events, filter_non_implemented_events) - self.__fail_on_table_metadata_unavailable = fail_on_table_metadata_unavailable - self.__ignore_decode_errors = ignore_decode_errors + self.__fail_on_table_metadata_unavailable: bool = fail_on_table_metadata_unavailable + self.__ignore_decode_errors: bool = ignore_decode_errors # We can't filter on packet level TABLE_MAP and rotate event because # we need them for handling other operations - self.__allowed_events_in_packet = frozenset( + self.__allowed_events_in_packet: FrozenSet[str] = frozenset( [TableMapEvent, RotateEvent]).union(self.__allowed_events) - self.__server_id = server_id - self.__use_checksum = False + self.__server_id: int = server_id + self.__use_checksum: bool = False # Store table meta information - self.table_map = {} - self.log_pos = log_pos - self.end_log_pos = end_log_pos - self.log_file = log_file - self.auto_position = auto_position - self.skip_to_timestamp = skip_to_timestamp - self.is_mariadb = is_mariadb - self.__annotate_rows_event = annotate_rows_event + self.table_map: Dict = {} + self.log_pos: Optional[int] = log_pos + self.end_log_pos: Optional[int] = end_log_pos + self.log_file: Optional[str] = log_file + self.auto_position: Optional[str] = auto_position + self.skip_to_timestamp: Optional[float] = skip_to_timestamp + self.is_mariadb: bool = is_mariadb + self.__annotate_rows_event: bool = annotate_rows_event if end_log_pos: - self.is_past_end_log_pos = False + self.is_past_end_log_pos: bool = False if report_slave: - self.report_slave = ReportSlave(report_slave) - self.slave_uuid = slave_uuid - self.slave_heartbeat = slave_heartbeat + self.report_slave: ReportSlave = ReportSlave(report_slave) + self.slave_uuid: Optional[str] = slave_uuid + self.slave_heartbeat: Optional[float] = slave_heartbeat if pymysql_wrapper: - self.pymysql_wrapper = pymysql_wrapper + self.pymysql_wrapper: Connection = pymysql_wrapper else: - self.pymysql_wrapper = pymysql.connect - self.mysql_version = (0, 0, 0) + self.pymysql_wrapper: Optional[Union[Connection, Type[Connection]]] = pymysql.connect + self.mysql_version: Tuple = (0, 0, 0) - def close(self): + def close(self) -> None: if self.__connected_stream: self._stream_connection.close() - self.__connected_stream = False + self.__connected_stream: bool = False if self.__connected_ctl: # break reference cycle between stream reader and underlying # mysql connection object self._ctl_connection._get_table_information = None self._ctl_connection.close() - self.__connected_ctl = False + self.__connected_ctl: bool = False - def __connect_to_ctl(self): + def __connect_to_ctl(self) -> None: if not self._ctl_connection_settings: - self._ctl_connection_settings = dict(self.__connection_settings) + self._ctl_connection_settings: Dict[str, Any] = dict(self.__connection_settings) self._ctl_connection_settings["db"] = "information_schema" self._ctl_connection_settings["cursorclass"] = DictCursor self._ctl_connection_settings["autocommit"] = True - self._ctl_connection = self.pymysql_wrapper(**self._ctl_connection_settings) + self._ctl_connection: Connection = self.pymysql_wrapper(**self._ctl_connection_settings) self._ctl_connection._get_table_information = self.__get_table_information - self.__connected_ctl = True + self.__connected_ctl: bool = True - def __checksum_enabled(self): - """Return True if binlog-checksum = CRC32. Only for MySQL > 5.6""" - cur = self._stream_connection.cursor() + def __checksum_enabled(self) -> bool: + """ + Return True if binlog-checksum = CRC32. Only for MySQL > 5.6 + """ + cur: Cursor = self._stream_connection.cursor() cur.execute("SHOW GLOBAL VARIABLES LIKE 'BINLOG_CHECKSUM'") - result = cur.fetchone() + result: Optional[Tuple[str, str]] = cur.fetchone() cur.close() if result is None: @@ -274,11 +279,11 @@ def __checksum_enabled(self): return False return True - def _register_slave(self): + def _register_slave(self) -> None: if not self.report_slave: return - packet = self.report_slave.encoded(self.__server_id) + packet: bytes = self.report_slave.encoded(self.__server_id) if pymysql.__version__ < LooseVersion("0.6"): self._stream_connection.wfile.write(packet) @@ -289,40 +294,40 @@ def _register_slave(self): self._stream_connection._next_seq_id = 1 self._stream_connection._read_packet() - def __connect_to_stream(self): + def __connect_to_stream(self) -> None: # log_pos (4) -- position in the binlog-file to start the stream with # flags (2) BINLOG_DUMP_NON_BLOCK (0 or 1) # server_id (4) -- server id of this slave # log_file (string.EOF) -- filename of the binlog on the master - self._stream_connection = self.pymysql_wrapper(**self.__connection_settings) + self._stream_connection: Connection = self.pymysql_wrapper(**self.__connection_settings) - self.__use_checksum = self.__checksum_enabled() + self.__use_checksum: bool = self.__checksum_enabled() # If checksum is enabled we need to inform the server about the that # we support it if self.__use_checksum: - cur = self._stream_connection.cursor() + cur: Cursor = self._stream_connection.cursor() cur.execute("SET @master_binlog_checksum= @@global.binlog_checksum") cur.close() if self.slave_uuid: - cur = self._stream_connection.cursor() + cur: Cursor = self._stream_connection.cursor() cur.execute("SET @slave_uuid = %s, @replica_uuid = %s", (self.slave_uuid, self.slave_uuid)) cur.close() if self.slave_heartbeat: # 4294967 is documented as the max value for heartbeats - net_timeout = float(self.__connection_settings.get('read_timeout', + net_timeout: float = float(self.__connection_settings.get('read_timeout', 4294967)) # If heartbeat is too low, the connection will disconnect before, # this is also the behavior in mysql - heartbeat = float(min(net_timeout / 2., self.slave_heartbeat)) + heartbeat: float = float(min(net_timeout / 2., self.slave_heartbeat)) if heartbeat > 4294967: heartbeat = 4294967 # master_heartbeat_period is nanoseconds - heartbeat = int(heartbeat * 1000000000) - cur = self._stream_connection.cursor() + heartbeat: int = int(heartbeat * 1000000000) + cur: Cursor = self._stream_connection.cursor() cur.execute("SET @master_heartbeat_period= %d" % heartbeat) cur.close() @@ -330,7 +335,7 @@ def __connect_to_stream(self): # Mariadb, when it tries to replace GTID events with dummy ones. Given that this library understands GTID # events, setting the capability to 4 circumvents this error. # If the DB is mysql, this won't have any effect so no need to run this in a condition - cur = self._stream_connection.cursor() + cur: Cursor = self._stream_connection.cursor() cur.execute("SET @mariadb_slave_capability=4") cur.close() @@ -343,15 +348,15 @@ def __connect_to_stream(self): # only when log_file and log_pos both provided, the position info is # valid, if not, get the current position from master if self.log_file is None or self.log_pos is None: - cur = self._stream_connection.cursor() + cur: Cursor = self._stream_connection.cursor() cur.execute("SHOW MASTER STATUS") - master_status = cur.fetchone() + master_status: Optional[Tuple[str, int, Any]] = cur.fetchone() if master_status is None: raise BinLogNotEnabled() self.log_file, self.log_pos = master_status[:2] cur.close() - prelude = struct.pack(' bytes: # https://mariadb.com/kb/en/5-slave-registration/ - cur = self._stream_connection.cursor() + cur: Cursor = self._stream_connection.cursor() if self.auto_position != None: cur.execute("SET @slave_connect_state='%s'" % self.auto_position) cur.execute("SET @slave_gtid_strict_mode=1") @@ -461,21 +466,21 @@ def __set_mariadb_settings(self): cur.close() # https://mariadb.com/kb/en/com_binlog_dump/ - header_size = ( + header_size: int = ( 4 + # binlog pos 2 + # binlog flags 4 + # slave server_id, 4 # requested binlog file name , set it to empty ) - prelude = struct.pack(' Union[BinLogPacketWrapper, None]: while True: if self.end_log_pos and self.is_past_end_log_pos: return None @@ -506,9 +511,9 @@ def fetchone(self): try: if pymysql.__version__ < LooseVersion("0.6"): - pkt = self._stream_connection.read_packet() + pkt: MysqlPacket = self._stream_connection.read_packet() else: - pkt = self._stream_connection._read_packet() + pkt: MysqlPacket = self._stream_connection._read_packet() except pymysql.OperationalError as error: code, message = error.args if code in MYSQL_EXPECTED_ERROR_CODES: @@ -524,7 +529,7 @@ def fetchone(self): if not pkt.is_ok_packet(): continue - binlog_event = BinLogPacketWrapper(pkt, self.table_map, + binlog_event: BinLogPacketWrapper = BinLogPacketWrapper(pkt, self.table_map, self._ctl_connection, self.mysql_version, self.__use_checksum, @@ -549,7 +554,7 @@ def fetchone(self): # invalidates all our cached table id to schema mappings. This means we have to load them all # again for each logfile which is potentially wasted effort but we can't really do much better # without being broken in restart case - self.table_map = {} + self.table_map: Dict = {} elif binlog_event.log_pos: self.log_pos = binlog_event.log_pos @@ -599,8 +604,8 @@ def fetchone(self): return binlog_event.event - def _allowed_event_list(self, only_events, ignored_events, - filter_non_implemented_events): + def _allowed_event_list(self, only_events: Optional[List[str]], ignored_events: Optional[List[str]], + filter_non_implemented_events: bool) -> FrozenSet[str]: if only_events is not None: events = set(only_events) else: @@ -638,13 +643,13 @@ def _allowed_event_list(self, only_events, ignored_events, pass return frozenset(events) - def __get_table_information(self, schema, table): + def __get_table_information(self, schema: str, table: str) -> List[Dict[str, Any]]: for i in range(1, 3): try: if not self.__connected_ctl: self.__connect_to_ctl() - cur = self._ctl_connection.cursor() + cur: Cursor = self._ctl_connection.cursor() cur.execute(""" SELECT COLUMN_NAME, COLLATION_NAME, CHARACTER_SET_NAME, @@ -655,7 +660,7 @@ def __get_table_information(self, schema, table): WHERE table_schema = %s AND table_name = %s """, (schema, table)) - result = sorted(cur.fetchall(), key=lambda x: x['ORDINAL_POSITION']) + result: List = sorted(cur.fetchall(), key=lambda x: x['ORDINAL_POSITION']) cur.close() return result @@ -667,5 +672,5 @@ def __get_table_information(self, schema, table): else: raise error - def __iter__(self): + def __iter__(self) -> Iterator[Union[BinLogPacketWrapper, None]]: return iter(self.fetchone, None)