From 826bac51d7c14415fe3052feb23d69bb97cd596e Mon Sep 17 00:00:00 2001 From: mjs Date: Wed, 30 Aug 2023 22:42:45 +0900 Subject: [PATCH 1/3] docs: add typing gtid.py & refactor Gtid.__init__ --- pymysqlreplication/gtid.py | 87 +++++++++++++++++++++----------------- 1 file changed, 49 insertions(+), 38 deletions(-) diff --git a/pymysqlreplication/gtid.py b/pymysqlreplication/gtid.py index 3b2554da..cc7f5c2f 100644 --- a/pymysqlreplication/gtid.py +++ b/pymysqlreplication/gtid.py @@ -5,15 +5,17 @@ import binascii from copy import deepcopy from io import BytesIO +from typing import List, Optional, Tuple -def overlap(i1, i2): +def overlap(i1: Tuple[int, int], i2: Tuple[int, int]) -> bool: return i1[0] < i2[1] and i1[1] > i2[0] -def contains(i1, i2): +def contains(i1: Tuple[int, int], i2: Tuple[int, int]) -> bool: return i2[0] >= i1[0] and i2[1] <= i1[1] class Gtid(object): - """A mysql GTID is composed of a server-id and a set of right-open + """ + A mysql GTID is composed of a server-id and a set of right-open intervals [a,b), and represent all transactions x that happened on server SID such as @@ -49,7 +51,7 @@ class Gtid(object): Exception: Adding a Gtid with a different SID. """ @staticmethod - def parse_interval(interval): + def parse_interval(interval: str) -> Tuple[int, int]: """ We parse a human-generated string here. So our end value b is incremented to conform to the internal representation format. @@ -65,8 +67,9 @@ def parse_interval(interval): return (a, b+1) @staticmethod - def parse(gtid): - """Parse a GTID from mysql textual format. + def parse(gtid: str) -> Tuple[str, List[Tuple[int, int]]]: + """ + Parse a GTID from mysql textual format. Raises: - ValueError: if GTID format is incorrect. @@ -84,7 +87,7 @@ def parse(gtid): return (sid, intervals_parsed) - def __add_interval(self, itvl): + def __add_interval(self, itvl: Tuple[int, int]) -> None: """ Use the internal representation format and add it to our intervals, merging if required. @@ -92,7 +95,7 @@ def __add_interval(self, itvl): Raises: Exception: if Malformated interval or Overlapping interval """ - new = [] + new: List[Tuple[int, int]] = [] if itvl[0] > itvl[1]: raise Exception('Malformed interval %s' % (itvl,)) @@ -114,11 +117,13 @@ def __add_interval(self, itvl): self.intervals = sorted(new + [itvl]) - def __sub_interval(self, itvl): - """Using the internal representation, remove an interval + def __sub_interval(self, itvl: Tuple[int, int]) -> None: + """ + Using the internal representation, remove an interval - Raises: Exception if itvl malformated""" - new = [] + Raises: Exception if itvl malformated + """ + new: List[Tuple[int, int]] = [] if itvl[0] > itvl[1]: raise Exception('Malformed interval %s' % (itvl,)) @@ -139,8 +144,9 @@ def __sub_interval(self, itvl): self.intervals = new - def __contains__(self, other): - """Test if other is contained within self. + def __contains__(self, other: 'Gtid') -> bool: + """ + Test if other is contained within self. First we compare sid they must be equals. Then we search if intervals from other are contained within @@ -152,10 +158,8 @@ def __contains__(self, other): return all(any(contains(me, them) for me in self.intervals) for them in other.intervals) - def __init__(self, gtid, sid=None, intervals=[]): - if sid: - intervals = intervals - else: + def __init__(self, gtid: str, sid: Optional[str] = None, intervals: Optional[List[Tuple[int, int]]] = None) -> None: + if sid is None: sid, intervals = Gtid.parse(gtid) self.sid = sid @@ -163,11 +167,13 @@ def __init__(self, gtid, sid=None, intervals=[]): for itvl in intervals: self.__add_interval(itvl) - def __add__(self, other): - """Include the transactions of this gtid. + def __add__(self, other: 'Gtid') -> 'Gtid': + """ + Include the transactions of this gtid. Raises: - Exception: if the attempted merge has different SID""" + Exception: if the attempted merge has different SID + """ if self.sid != other.sid: raise Exception('Attempt to merge different SID' '%s != %s' % (self.sid, other.sid)) @@ -179,9 +185,10 @@ def __add__(self, other): return result - def __sub__(self, other): - """Remove intervals. Do not raise, if different SID simply - ignore""" + def __sub__(self, other: 'Gtid') -> 'Gtid': + """ + Remove intervals. Do not raise, if different SID simply ignore + """ result = deepcopy(self) if self.sid != other.sid: return result @@ -191,27 +198,30 @@ def __sub__(self, other): return result - def __str__(self): - """We represent the human value here - a single number - for one transaction, or a closed interval (decrementing b)""" + def __str__(self) -> str: + """ + We represent the human value here - a single number + for one transaction, or a closed interval (decrementing b) + """ return '%s:%s' % (self.sid, ':'.join(('%d-%d' % (x[0], x[1]-1)) if x[0] +1 != x[1] else str(x[0]) for x in self.intervals)) - def __repr__(self): + def __repr__(self) -> str: return '' % self @property - def encoded_length(self): + def encoded_length(self) -> int: return (16 + # sid 8 + # n_intervals 2 * # stop/start 8 * # stop/start mark encoded as int64 len(self.intervals)) - def encode(self): - """Encode a Gtid in binary + def encode(self) -> bytes: + """ + Encode a Gtid in binary Bytes are in **little endian**. Format: @@ -251,8 +261,9 @@ def encode(self): return buffer @classmethod - def decode(cls, payload): - """Decode from binary a Gtid + def decode(cls, payload: BytesIO) -> 'Gtid': + """ + Decode from binary a Gtid :param BytesIO payload to decode """ @@ -281,27 +292,27 @@ def decode(cls, payload): else '%d' % x for x in intervals]))) - def __eq__(self, other): + def __eq__(self, other: 'Gtid') -> bool: if other.sid != self.sid: return False return self.intervals == other.intervals - def __lt__(self, other): + def __lt__(self, other: 'Gtid') -> bool: if other.sid != self.sid: return self.sid < other.sid return self.intervals < other.intervals - def __le__(self, other): + def __le__(self, other: 'Gtid') -> bool: if other.sid != self.sid: return self.sid <= other.sid return self.intervals <= other.intervals - def __gt__(self, other): + def __gt__(self, other: 'Gtid') -> bool: if other.sid != self.sid: return self.sid > other.sid return self.intervals > other.intervals - def __ge__(self, other): + def __ge__(self, other: 'Gtid') -> bool: if other.sid != self.sid: return self.sid >= other.sid return self.intervals >= other.intervals From ee0cb3ecd118a112a34771e1a6d19160ea6f1aef Mon Sep 17 00:00:00 2001 From: mikaniz Date: Thu, 31 Aug 2023 08:02:10 +0900 Subject: [PATCH 2/3] docs: add typing gtid.py --- pymysqlreplication/gtid.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/pymysqlreplication/gtid.py b/pymysqlreplication/gtid.py index cc7f5c2f..a5a979d2 100644 --- a/pymysqlreplication/gtid.py +++ b/pymysqlreplication/gtid.py @@ -5,7 +5,8 @@ import binascii from copy import deepcopy from io import BytesIO -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union, Set + def overlap(i1: Tuple[int, int], i2: Tuple[int, int]) -> bool: return i1[0] < i2[1] and i1[1] > i2[0] @@ -320,7 +321,7 @@ def __ge__(self, other: 'Gtid') -> bool: class GtidSet(object): """Represents a set of Gtid""" - def __init__(self, gtid_set): + def __init__(self, gtid_set: Optional[Union[None, str, Set[Gtid], List[Gtid], Gtid]] = None) -> None: """ Construct a GtidSet initial state depends of the nature of `gtid_set` param. @@ -336,21 +337,21 @@ def __init__(self, gtid_set): - Exception: if Gtid interval are either malformated or overlapping """ - def _to_gtid(element): + def _to_gtid(element: str) -> Gtid: if isinstance(element, Gtid): return element return Gtid(element.strip(' \n')) if not gtid_set: - self.gtids = [] + self.gtids: List[Gtid] = [] elif isinstance(gtid_set, (list, set)): - self.gtids = [_to_gtid(x) for x in gtid_set] + self.gtids: List[Gtid] = [_to_gtid(x) for x in gtid_set] else: - self.gtids = [Gtid(x.strip(' \n')) for x in gtid_set.split(',')] + self.gtids: List[Gtid] = [Gtid(x.strip(' \n')) for x in gtid_set.split(',')] - def merge_gtid(self, gtid): + def merge_gtid(self, gtid: Gtid) -> None: """Insert a Gtid in current GtidSet.""" - new_gtids = [] + new_gtids: List[Gtid] = [] for existing in self.gtids: if existing.sid == gtid.sid: new_gtids.append(existing + gtid) @@ -360,7 +361,7 @@ def merge_gtid(self, gtid): new_gtids.append(gtid) self.gtids = new_gtids - def __contains__(self, other): + def __contains__(self, other: Union['GtidSet', Gtid]) -> bool: """ Test if self contains other, could be a GtidSet or a Gtid. @@ -374,7 +375,7 @@ def __contains__(self, other): return any(other in x for x in self.gtids) raise NotImplementedError - def __add__(self, other): + def __add__(self, other: Union['GtidSet', Gtid]) -> 'GtidSet': """ Merge current instance with an other GtidSet or with a Gtid alone. @@ -395,21 +396,21 @@ def __add__(self, other): raise NotImplementedError - def __str__(self): + def __str__(self) -> str: """ Returns a comma separated string of gtids. """ return ','.join(str(x) for x in self.gtids) - def __repr__(self): + def __repr__(self) -> str: return '' % self.gtids @property - def encoded_length(self): + def encoded_length(self) -> int: return (8 + # n_sids sum(x.encoded_length for x in self.gtids)) - def encoded(self): + def encoded(self) -> bytes: """Encode a GtidSet in binary Bytes are in **little endian**. @@ -432,7 +433,7 @@ def encoded(self): encode = encoded @classmethod - def decode(cls, payload): + def decode(cls, payload: BytesIO) -> 'GtidSet': """Decode a GtidSet from binary. :param BytesIO payload to decode @@ -443,5 +444,5 @@ def decode(cls, payload): return cls([Gtid.decode(payload) for _ in range(0, n_sid)]) - def __eq__(self, other): + def __eq__(self, other: 'GtidSet') -> bool: return self.gtids == other.gtids From d4220ddfbf0f3919d6a47e3af7025209190f9c21 Mon Sep 17 00:00:00 2001 From: mjs Date: Thu, 31 Aug 2023 21:21:37 +0900 Subject: [PATCH 3/3] refactor: update docstring format --- pymysqlreplication/gtid.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pymysqlreplication/gtid.py b/pymysqlreplication/gtid.py index a5a979d2..df80aac2 100644 --- a/pymysqlreplication/gtid.py +++ b/pymysqlreplication/gtid.py @@ -174,7 +174,7 @@ def __add__(self, other: 'Gtid') -> 'Gtid': Raises: Exception: if the attempted merge has different SID - """ + """ if self.sid != other.sid: raise Exception('Attempt to merge different SID' '%s != %s' % (self.sid, other.sid)) @@ -411,7 +411,8 @@ def encoded_length(self) -> int: sum(x.encoded_length for x in self.gtids)) def encoded(self) -> bytes: - """Encode a GtidSet in binary + """ + Encode a GtidSet in binary Bytes are in **little endian**. - `n_sid`: u64 is the number of Gtid to read @@ -434,7 +435,8 @@ def encoded(self) -> bytes: @classmethod def decode(cls, payload: BytesIO) -> 'GtidSet': - """Decode a GtidSet from binary. + """ + Decode a GtidSet from binary. :param BytesIO payload to decode """