diff --git a/pymysqlreplication/row_event.py b/pymysqlreplication/row_event.py index fcd138d3..b5ff096e 100644 --- a/pymysqlreplication/row_event.py +++ b/pymysqlreplication/row_event.py @@ -5,6 +5,7 @@ import datetime import json +from typing import List, Dict, Any, Optional from pymysql.charset import charset_by_name from .event import BinLogEvent @@ -15,25 +16,33 @@ from .table import Table from .bitmap import BitCount, BitGet + class RowsEvent(BinLogEvent): - def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs): + def __init__( + self, + from_packet: Any, + event_size: int, + table_map: Dict[int, Any], + ctl_connection: Any, + **kwargs: Any + ) -> None: super().__init__(from_packet, event_size, table_map, - ctl_connection, **kwargs) - self.__rows = None + ctl_connection, **kwargs) + self.__rows: Optional[List[Any]] = None self.__only_tables = kwargs["only_tables"] self.__ignored_tables = kwargs["ignored_tables"] self.__only_schemas = kwargs["only_schemas"] self.__ignored_schemas = kwargs["ignored_schemas"] - #Header - self.table_id = self._read_table_id() + # Header + self.table_id: int = self._read_table_id() # Additional information try: - self.primary_key = table_map[self.table_id].data["primary_key"] - self.schema = self.table_map[self.table_id].schema - self.table = self.table_map[self.table_id].table - except KeyError: #If we have filter the corresponding TableMap Event + self.primary_key: Any = table_map[self.table_id].data["primary_key"] + self.schema: str = self.table_map[self.table_id].schema + self.table: str = self.table_map[self.table_id].table + except KeyError: # If we have filter the corresponding TableMap Event self._processed = False return @@ -51,32 +60,31 @@ def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs) self._processed = False return - - #Event V2 + # Event V2 if self.event_type == BINLOG.WRITE_ROWS_EVENT_V2 or \ self.event_type == BINLOG.DELETE_ROWS_EVENT_V2 or \ self.event_type == BINLOG.UPDATE_ROWS_EVENT_V2: - self.flags, self.extra_data_length = struct.unpack(' 2: - self.extra_data_type = struct.unpack(' 2: + self.extra_data_type = struct.unpack(' int: bit = null_bitmap[int(position / 8)] if type(bit) is str: bit = ord(bit) return bit & (1 << (position % 8)) - def _read_column_data(self, cols_bitmap): + def _read_column_data(self, + cols_bitmap: int + ) -> Dict[str, Any]: """Use for WRITE, UPDATE and DELETE events. Return an array of column data """ - values = {} + values: Dict[str, Any] = {} # null bitmap length = (bits set in 'columns-present-bitmap'+7)/8 # See http://dev.mysql.com/doc/internals/en/rows-event.html @@ -121,8 +133,16 @@ def _read_column_data(self, cols_bitmap): return values - def __read_values_name(self, column, null_bitmap, null_bitmap_index, cols_bitmap, unsigned, zerofill, - fixed_binary_length, i): + def __read_values_name(self, + column: Column, + null_bitmap: Any, + null_bitmap_index: int, + cols_bitmap: Any, + unsigned: bool, + zerofill: bool, + fixed_binary_length: int, + i: int + ): if BitGet(cols_bitmap, i) == 0: return None @@ -231,19 +251,33 @@ def __read_values_name(self, column, null_bitmap, null_bitmap_index, cols_bitmap return self.packet.read_binary_json(column.length_size) else: raise NotImplementedError("Unknown MySQL column type: %d" % - (column.type)) + column.type) + + def __add_fsp_to_time(self, + time: datetime, + column: Column + ) -> datetime: + """ + Read and add the fractional part of time - def __add_fsp_to_time(self, time, column): - """Read and add the fractional part of time For more details about new date format: - http://dev.mysql.com/doc/internals/en/date-and-time-data-type-representation.html + https://dev.mysql.com/doc/refman/8.0/en/date-and-time-types.html + + :param time: The datetime object representing the time. + :type time: datetime.datetime + :param column: The MySQL column containing fractional seconds information. + :type column: Any (actual type should be defined) + :return: The datetime object with added fractional seconds. + :rtype: datetime.datetime """ microsecond = self.__read_fsp(column) if microsecond > 0: time = time.replace(microsecond=microsecond) return time - def __read_fsp(self, column): + def __read_fsp(self, + column: Column + ) -> int: read = 0 if column.fsp == 1 or column.fsp == 2: read = 1 @@ -255,15 +289,18 @@ def __read_fsp(self, column): microsecond = self.packet.read_int_be_by_size(read) if column.fsp % 2: microsecond = int(microsecond / 10) - return microsecond * (10 ** (6-column.fsp)) + return microsecond * (10 ** (6 - column.fsp)) return 0 @staticmethod - def charset_to_encoding(name): + def charset_to_encoding(name: str) -> str: charset = charset_by_name(name) return charset.encoding if charset else name - def __read_string(self, size, column): + def __read_string(self, + size: int, + column: Column + ) -> str: string = self.packet.read_length_coded_pascal_string(size) if column.character_set_name is not None: encoding = self.charset_to_encoding(column.character_set_name) @@ -271,7 +308,9 @@ def __read_string(self, size, column): string = string.decode(encoding, decode_errors) return string - def __read_bit(self, column): + def __read_bit(self, + column: Column + ) -> str: """Read MySQL BIT type""" resp = "" for byte in range(0, column.bytes): @@ -294,7 +333,7 @@ def __read_bit(self, column): resp += current_byte[::-1] return resp - def __read_time(self): + def __read_time(self) -> datetime.timedelta: time = self.packet.read_uint24() date = datetime.timedelta( hours=int(time / 10000), @@ -302,7 +341,9 @@ def __read_time(self): seconds=int(time % 100)) return date - def __read_time2(self, column): + def __read_time2(self, + column: Column + ) -> datetime.timedelta: """TIME encoding for nonfractional part: 1 bit sign (1= non-negative, 0= negative) @@ -329,7 +370,7 @@ def __read_time2(self, column): ) * sign return t - def __read_date(self): + def __read_date(self) -> Optional[datetime.date]: time = self.packet.read_uint24() if time == 0: # nasty mysql 0000-00-00 dates return None @@ -347,7 +388,7 @@ def __read_date(self): ) return date - def __read_datetime(self): + def __read_datetime(self) -> Optional[datetime.date]: value = self.packet.read_uint64() if value == 0: # nasty mysql 0000-00-00 dates return None @@ -370,7 +411,9 @@ def __read_datetime(self): second=int(time % 100)) return date - def __read_datetime2(self, column): + def __read_datetime2(self, + column: Column + ) -> Optional[datetime.datetime]: """DATETIME 1 bit sign (1= non-negative, 0= negative) @@ -397,13 +440,21 @@ def __read_datetime2(self, column): return None return self.__add_fsp_to_time(t, column) - def __read_new_decimal(self, column): - """Read MySQL's new decimal format introduced in MySQL 5""" + def __read_new_decimal(self, + column: Column + ) -> decimal.Decimal: + """ + Read MySQL's new decimal format introduced in MySQL 5 - # This project was a great source of inspiration for - # understanding this storage format. - # https://github.com/jeremycole/mysql_binlog + This project was a great source of inspiration for + understanding this storage format. + https://github.com/jeremycole/mysql_binlog + :param column: The MySQL column containing the new decimal value. + :type column: Any (actual type should be defined) + :return: The Python Decimal object representing the new decimal value. + :rtype: decimal.Decimal + """ digits_per_integer = 9 compressed_bytes = [0, 1, 1, 2, 2, 3, 3, 4, 4, 4] integral = (column.precision - column.decimals) @@ -414,7 +465,7 @@ def __read_new_decimal(self, column): * digits_per_integer) # Support negative - # The sign is encoded in the high bit of the the byte + # The sign is encoded in the high bit of the byte # But this bit can also be used in the value value = self.packet.read_uint8() if value & 0x80 != 0: @@ -447,13 +498,25 @@ def __read_new_decimal(self, column): return decimal.Decimal(res) - def __read_binary_slice(self, binary, start, size, data_length): + def __read_binary_slice(self, + binary: int, + start: int, + size: int, + data_length: int + ) -> int: """ - Read a part of binary data and extract a number - binary: the data - start: From which bit (1 to X) - size: How many bits should be read - data_length: data size + Read a part of binary data and extract a number. + + :param binary: The binary data. + :type binary: int + :param start: From which bit (1 to X). + :type start: int + :param size: How many bits should be read. + :type size: int + :param data_length: Size of the data. + :type data_length: int + :return: Extracted number from binary data. + :rtype: int """ binary = binary >> data_length - (start + size) mask = ((1 << size) - 1) @@ -482,25 +545,30 @@ def rows(self): class DeleteRowsEvent(RowsEvent): - """This event is trigger when a row in the database is removed + """ + This event is trigger when a row in the database is removed For each row you have a hash with a single key: values which contain the data of the removed line. """ - def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs): + def __init__(self, + from_packet: Any, + event_size: int, + table_map: Dict[int, Any], + ctl_connection: Any, + **kwargs: Any + ) -> None: super().__init__(from_packet, event_size, - table_map, ctl_connection, **kwargs) + table_map, ctl_connection, **kwargs) if self._processed: self.columns_present_bitmap = self.packet.read( (self.number_of_columns + 7) / 8) - def _fetch_one_row(self): - row = {} - - row["values"] = self._read_column_data(self.columns_present_bitmap) + def _fetch_one_row(self) -> Dict[str, Any]: + row: Dict[str, Any] = {"values": self._read_column_data(self.columns_present_bitmap)} return row - def _dump(self): + def _dump(self) -> None: super()._dump() print("Values:") for row in self.rows: @@ -510,25 +578,30 @@ def _dump(self): class WriteRowsEvent(RowsEvent): - """This event is triggered when a row in database is added + """ + This event is triggered when a row in database is added For each row you have a hash with a single key: values which contain the data of the new line. """ - def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs): + def __init__(self, + from_packet: Any, + event_size: int, + table_map: Dict[int, Any], + ctl_connection: Any, + **kwargs: Any + ) -> None: super().__init__(from_packet, event_size, - table_map, ctl_connection, **kwargs) + table_map, ctl_connection, **kwargs) if self._processed: self.columns_present_bitmap = self.packet.read( (self.number_of_columns + 7) / 8) - def _fetch_one_row(self): - row = {} - - row["values"] = self._read_column_data(self.columns_present_bitmap) + def _fetch_one_row(self) -> Dict[str, Any]: + row: Dict[str, Any] = {"values": self._read_column_data(self.columns_present_bitmap)} return row - def _dump(self): + def _dump(self) -> None: super()._dump() print("Values:") for row in self.rows: @@ -548,25 +621,28 @@ class UpdateRowsEvent(RowsEvent): http://dev.mysql.com/doc/refman/5.6/en/replication-options-binary-log.html#sysvar_binlog_row_image """ - def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs): + def __init__(self, + from_packet: Any, + event_size: int, + table_map: Dict[int, Any], + ctl_connection: Any, + **kwargs: Any + ) -> None: super().__init__(from_packet, event_size, - table_map, ctl_connection, **kwargs) + table_map, ctl_connection, **kwargs) if self._processed: - #Body + # Body self.columns_present_bitmap = self.packet.read( (self.number_of_columns + 7) / 8) self.columns_present_bitmap2 = self.packet.read( (self.number_of_columns + 7) / 8) - def _fetch_one_row(self): - row = {} - - row["before_values"] = self._read_column_data(self.columns_present_bitmap) - - row["after_values"] = self._read_column_data(self.columns_present_bitmap2) + def _fetch_one_row(self) -> Dict[str, Any]: + row: Dict[str, Any] = {"before_values": self._read_column_data(self.columns_present_bitmap), + "after_values": self._read_column_data(self.columns_present_bitmap2)} return row - def _dump(self): + def _dump(self) -> None: super()._dump() print("Affected columns: %d" % self.number_of_columns) print("Values:") @@ -579,14 +655,21 @@ def _dump(self): class TableMapEvent(BinLogEvent): - """This event describes the structure of a table. + """ + This event describes the structure of a table. It's sent before a change happens on a table. An end user of the lib should have no usage of this """ - def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs): + def __init__(self, + from_packet: Any, + event_size: int, + table_map: Dict[int, Any], + ctl_connection: Any, + **kwargs: Any + ) -> None: super().__init__(from_packet, event_size, - table_map, ctl_connection, **kwargs) + table_map, ctl_connection, **kwargs) self.__only_tables = kwargs["only_tables"] self.__ignored_tables = kwargs["ignored_tables"] self.__only_schemas = kwargs["only_schemas"] @@ -676,12 +759,12 @@ def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs) ## Refer to definition of and call to row.event._is_null() to interpret bitmap corresponding to columns self.null_bitmask = self.packet.read((self.column_count + 7) / 8) - def get_table(self): + def get_table(self) -> Table: return self.table_obj - def _dump(self): + def _dump(self) -> None: super()._dump() - print("Table id: %d" % (self.table_id)) - print("Schema: %s" % (self.schema)) - print("Table: %s" % (self.table)) - print("Columns: %s" % (self.column_count)) + print("Table id: %d" % self.table_id) + print("Schema: %s" % self.schema) + print("Table: %s" % self.table) + print("Columns: %s" % self.column_count)