diff --git a/pymysqlreplication/gtid.py b/pymysqlreplication/gtid.py index c2a30145..e0f7c1dc 100644 --- a/pymysqlreplication/gtid.py +++ b/pymysqlreplication/gtid.py @@ -1,14 +1,16 @@ # -*- coding: utf-8 -*- +from __future__ import annotations +from typing import List, Tuple, Union, Set, Optional import re import struct import binascii from io import BytesIO -def overlap(i1, i2): +def overlap(i1: Tuple[int, int], i2: Tuple[int, int]) -> bool: return i1[0] < i2[1] and i1[1] > i2[0] -def contains(i1, i2): +def contains(i1: Tuple[int, int], i2: Tuple[int, int]) -> bool: return i2[0] >= i1[0] and i2[1] <= i1[1] class Gtid(object): @@ -47,8 +49,19 @@ class Gtid(object): Exception: Adding an already present transaction number (one that overlaps). Exception: Adding a Gtid with a different SID. """ + def __init__(self, gtid: str, sid=None, intervals: List[Tuple[int, int]]=[]): + if sid: + intervals = intervals + else: + sid, intervals = Gtid.parse(gtid) + + self.sid: str = sid + self.intervals: List[Tuple[int, int]] = [] + for itvl in intervals: + self.__add_interval(itvl) + @staticmethod - def parse_interval(interval): + def parse_interval(interval: str) -> Tuple[int, int]: """ We parse a human-generated string here. So our end value b is incremented to conform to the internal representation format. @@ -64,7 +77,7 @@ def parse_interval(interval): return (a, b+1) @staticmethod - def parse(gtid): + def parse(gtid: str) -> Tuple[str, List[Tuple[int, int]]]: """Parse a GTID from mysql textual format. Raises: @@ -83,7 +96,7 @@ def parse(gtid): return (sid, intervals_parsed) - def __add_interval(self, itvl): + def __add_interval(self, itvl: Tuple[int, int]): """ Use the internal representation format and add it to our intervals, merging if required. @@ -113,7 +126,7 @@ def __add_interval(self, itvl): self.intervals = sorted(new + [itvl]) - def __sub_interval(self, itvl): + def __sub_interval(self, itvl: Tuple[int, int]): """Using the internal representation, remove an interval Raises: Exception if itvl malformated""" @@ -138,25 +151,16 @@ def __sub_interval(self, itvl): self.intervals = new - def __contains__(self, other): + def __contains__(self, other: Gtid) -> bool: if other.sid != self.sid: return False return all(any(contains(me, them) for me in self.intervals) for them in other.intervals) - def __init__(self, gtid, sid=None, intervals=[]): - if sid: - intervals = intervals - else: - sid, intervals = Gtid.parse(gtid) - self.sid = sid - self.intervals = [] - for itvl in intervals: - self.__add_interval(itvl) - def __add__(self, other): + def __add__(self, other: Gtid) -> Gtid: """Include the transactions of this gtid. Raises: @@ -172,7 +176,7 @@ def __add__(self, other): return result - def __sub__(self, other): + def __sub__(self, other: Gtid) -> Gtid: """Remove intervals. Do not raise, if different SID simply ignore""" result = Gtid(str(self)) @@ -184,7 +188,7 @@ def __sub__(self, other): return result - def __str__(self): + def __str__(self) -> str: """We represent the human value here - a single number for one transaction, or a closed interval (decrementing b)""" return '%s:%s' % (self.sid, @@ -192,18 +196,18 @@ def __str__(self): else str(x[0]) for x in self.intervals)) - def __repr__(self): + def __repr__(self) -> str: return '' % self @property - def encoded_length(self): + def encoded_length(self) -> int: return (16 + # sid 8 + # n_intervals 2 * # stop/start 8 * # stop/start mark encoded as int64 len(self.intervals)) - def encode(self): + def encode(self) -> bytes: buffer = b'' # sid buffer += binascii.unhexlify(self.sid.replace('-', '')) @@ -219,7 +223,7 @@ def encode(self): return buffer @classmethod - def decode(cls, payload): + def decode(cls, payload: BytesIO) -> Gtid: assert isinstance(payload, BytesIO), \ 'payload is expected to be a BytesIO' sid = b'' @@ -236,13 +240,11 @@ def decode(cls, payload): (n_intervals,) = struct.unpack(' bool: + """Equality between a Gtid and an other Gtid. + + Raise: NotImplemented if compared with anything else. + """ + if not isinstance(other, Gtid): + raise NotImplemented + if other.sid != self.sid: return False return self.intervals == other.intervals - def __lt__(self, other): + def __lt__(self, other: Gtid) -> bool: + """Check if a Gtid is lesser an other Gtid. + """ if other.sid != self.sid: return self.sid < other.sid return self.intervals < other.intervals - def __le__(self, other): + def __le__(self, other: Gtid) -> bool: + """Check if a Gtid is lesser or equal an other Gtid. + """ if other.sid != self.sid: return self.sid <= other.sid return self.intervals <= other.intervals - def __gt__(self, other): + def __gt__(self, other: Gtid) -> bool: if other.sid != self.sid: return self.sid > other.sid return self.intervals > other.intervals - def __ge__(self, other): + def __ge__(self, other: Gtid) -> bool: if other.sid != self.sid: return self.sid >= other.sid return self.intervals >= other.intervals @@ -278,7 +291,7 @@ def __ge__(self, other): class GtidSet(object): """Represents a set of Gtid""" - def __init__(self, gtid_set): + def __init__(self, gtid_set: Optional[Union[str, Union[Set[Gtid], Set[str], List[Gtid], List[str]]]]): """ Construct a GtidSet initial state depends of the nature of `gtid_set` param. @@ -300,7 +313,7 @@ def _to_gtid(element): return Gtid(element.strip(' \n')) if not gtid_set: - self.gtids = [] + self.gtids: List[Gtid] = [] elif isinstance(gtid_set, (list, set)): self.gtids = [_to_gtid(x) for x in gtid_set] else: @@ -317,7 +330,7 @@ def merge_gtid(self, gtid): new_gtids.append(gtid) self.gtids = new_gtids - def __contains__(self, other): + def __contains__(self, other: Union[GtidSet, Gtid]) -> bool: """ Raises: - NotImplementedError other is not a GtidSet neither a Gtid, @@ -329,7 +342,7 @@ def __contains__(self, other): return any(other in x for x in self.gtids) raise NotImplementedError - def __add__(self, other): + def __add__(self, other: Union[GtidSet, Gtid]) -> GtidSet: """ Merge current instance with an other GtidSet or with a Gtid alone. @@ -350,33 +363,39 @@ def __add__(self, other): raise NotImplementedError - def __str__(self): + def __str__(self) -> str: """ Returns a comma separated string of gtids. """ return ','.join(str(x) for x in self.gtids) - def __repr__(self): + def __repr__(self) -> str: return '' % self.gtids @property - def encoded_length(self): + def encoded_length(self) -> int: return (8 + # n_sids sum(x.encoded_length for x in self.gtids)) - def encoded(self): + def encoded(self) -> bytes: return b'' + (struct.pack(' GtidSet: assert isinstance(payload, BytesIO), \ 'payload is expected to be a BytesIO' (n_sid,) = struct.unpack(' bool: + """Equality between a GtidSet and an other GtidSet. + + Raise: NotImplemented if compared with anything else. + """ + if not isinstance(other, GtidSet): + raise NotImplemented return self.gtids == other.gtids