Skip to content

typing - gtid.py #76

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 31, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 69 additions & 55 deletions pymysqlreplication/gtid.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@
import binascii
from copy import deepcopy
from io import BytesIO
from typing import List, Optional, Tuple, Union, Set

def overlap(i1, i2):

def overlap(i1: Tuple[int, int], i2: Tuple[int, int]) -> bool:
return i1[0] < i2[1] and i1[1] > i2[0]

def contains(i1, i2):
def contains(i1: Tuple[int, int], i2: Tuple[int, int]) -> bool:
return i2[0] >= i1[0] and i2[1] <= i1[1]

class Gtid(object):
"""A mysql GTID is composed of a server-id and a set of right-open
"""
A mysql GTID is composed of a server-id and a set of right-open
intervals [a,b), and represent all transactions x that happened on
server SID such as
Expand Down Expand Up @@ -49,7 +52,7 @@ class Gtid(object):
Exception: Adding a Gtid with a different SID.
"""
@staticmethod
def parse_interval(interval):
def parse_interval(interval: str) -> Tuple[int, int]:
"""
We parse a human-generated string here. So our end value b
is incremented to conform to the internal representation format.
Expand All @@ -65,8 +68,9 @@ def parse_interval(interval):
return (a, b+1)

@staticmethod
def parse(gtid):
"""Parse a GTID from mysql textual format.
def parse(gtid: str) -> Tuple[str, List[Tuple[int, int]]]:
"""
Parse a GTID from mysql textual format.
Raises:
- ValueError: if GTID format is incorrect.
Expand All @@ -84,15 +88,15 @@ def parse(gtid):

return (sid, intervals_parsed)

def __add_interval(self, itvl):
def __add_interval(self, itvl: Tuple[int, int]) -> None:
"""
Use the internal representation format and add it
to our intervals, merging if required.
Raises:
Exception: if Malformated interval or Overlapping interval
"""
new = []
new: List[Tuple[int, int]] = []

if itvl[0] > itvl[1]:
raise Exception('Malformed interval %s' % (itvl,))
Expand All @@ -114,11 +118,13 @@ def __add_interval(self, itvl):

self.intervals = sorted(new + [itvl])

def __sub_interval(self, itvl):
"""Using the internal representation, remove an interval
def __sub_interval(self, itvl: Tuple[int, int]) -> None:
"""
Using the internal representation, remove an interval
Raises: Exception if itvl malformated"""
new = []
Raises: Exception if itvl malformated
"""
new: List[Tuple[int, int]] = []

if itvl[0] > itvl[1]:
raise Exception('Malformed interval %s' % (itvl,))
Expand All @@ -139,8 +145,9 @@ def __sub_interval(self, itvl):

self.intervals = new

def __contains__(self, other):
"""Test if other is contained within self.
def __contains__(self, other: 'Gtid') -> bool:
"""
Test if other is contained within self.
First we compare sid they must be equals.
Then we search if intervals from other are contained within
Expand All @@ -152,22 +159,22 @@ def __contains__(self, other):
return all(any(contains(me, them) for me in self.intervals)
for them in other.intervals)

def __init__(self, gtid, sid=None, intervals=[]):
if sid:
intervals = intervals
else:
def __init__(self, gtid: str, sid: Optional[str] = None, intervals: Optional[List[Tuple[int, int]]] = None) -> None:
if sid is None:
sid, intervals = Gtid.parse(gtid)

self.sid = sid
self.intervals = []
for itvl in intervals:
self.__add_interval(itvl)

def __add__(self, other):
"""Include the transactions of this gtid.
def __add__(self, other: 'Gtid') -> 'Gtid':
"""
Include the transactions of this gtid.
Raises:
Exception: if the attempted merge has different SID"""
Exception: if the attempted merge has different SID
"""
if self.sid != other.sid:
raise Exception('Attempt to merge different SID'
'%s != %s' % (self.sid, other.sid))
Expand All @@ -179,9 +186,10 @@ def __add__(self, other):

return result

def __sub__(self, other):
"""Remove intervals. Do not raise, if different SID simply
ignore"""
def __sub__(self, other: 'Gtid') -> 'Gtid':
"""
Remove intervals. Do not raise, if different SID simply ignore
"""
result = deepcopy(self)
if self.sid != other.sid:
return result
Expand All @@ -191,27 +199,30 @@ def __sub__(self, other):

return result

def __str__(self):
"""We represent the human value here - a single number
for one transaction, or a closed interval (decrementing b)"""
def __str__(self) -> str:
"""
We represent the human value here - a single number
for one transaction, or a closed interval (decrementing b)
"""
return '%s:%s' % (self.sid,
':'.join(('%d-%d' % (x[0], x[1]-1)) if x[0] +1 != x[1]
else str(x[0])
for x in self.intervals))

def __repr__(self):
def __repr__(self) -> str:
return '<Gtid "%s">' % self

@property
def encoded_length(self):
def encoded_length(self) -> int:
return (16 + # sid
8 + # n_intervals
2 * # stop/start
8 * # stop/start mark encoded as int64
len(self.intervals))

def encode(self):
"""Encode a Gtid in binary
def encode(self) -> bytes:
"""
Encode a Gtid in binary
Bytes are in **little endian**.
Format:
Expand Down Expand Up @@ -251,8 +262,9 @@ def encode(self):
return buffer

@classmethod
def decode(cls, payload):
"""Decode from binary a Gtid
def decode(cls, payload: BytesIO) -> 'Gtid':
"""
Decode from binary a Gtid
:param BytesIO payload to decode
"""
Expand Down Expand Up @@ -281,35 +293,35 @@ def decode(cls, payload):
else '%d' % x
for x in intervals])))

def __eq__(self, other):
def __eq__(self, other: 'Gtid') -> bool:
if other.sid != self.sid:
return False
return self.intervals == other.intervals

def __lt__(self, other):
def __lt__(self, other: 'Gtid') -> bool:
if other.sid != self.sid:
return self.sid < other.sid
return self.intervals < other.intervals

def __le__(self, other):
def __le__(self, other: 'Gtid') -> bool:
if other.sid != self.sid:
return self.sid <= other.sid
return self.intervals <= other.intervals

def __gt__(self, other):
def __gt__(self, other: 'Gtid') -> bool:
if other.sid != self.sid:
return self.sid > other.sid
return self.intervals > other.intervals

def __ge__(self, other):
def __ge__(self, other: 'Gtid') -> bool:
if other.sid != self.sid:
return self.sid >= other.sid
return self.intervals >= other.intervals


class GtidSet(object):
"""Represents a set of Gtid"""
def __init__(self, gtid_set):
def __init__(self, gtid_set: Optional[Union[None, str, Set[Gtid], List[Gtid], Gtid]] = None) -> None:
"""
Construct a GtidSet initial state depends of the nature of `gtid_set` param.
Expand All @@ -325,21 +337,21 @@ def __init__(self, gtid_set):
- Exception: if Gtid interval are either malformated or overlapping
"""

def _to_gtid(element):
def _to_gtid(element: str) -> Gtid:
if isinstance(element, Gtid):
return element
return Gtid(element.strip(' \n'))

if not gtid_set:
self.gtids = []
self.gtids: List[Gtid] = []
elif isinstance(gtid_set, (list, set)):
self.gtids = [_to_gtid(x) for x in gtid_set]
self.gtids: List[Gtid] = [_to_gtid(x) for x in gtid_set]
else:
self.gtids = [Gtid(x.strip(' \n')) for x in gtid_set.split(',')]
self.gtids: List[Gtid] = [Gtid(x.strip(' \n')) for x in gtid_set.split(',')]

def merge_gtid(self, gtid):
def merge_gtid(self, gtid: Gtid) -> None:
"""Insert a Gtid in current GtidSet."""
new_gtids = []
new_gtids: List[Gtid] = []
for existing in self.gtids:
if existing.sid == gtid.sid:
new_gtids.append(existing + gtid)
Expand All @@ -349,7 +361,7 @@ def merge_gtid(self, gtid):
new_gtids.append(gtid)
self.gtids = new_gtids

def __contains__(self, other):
def __contains__(self, other: Union['GtidSet', Gtid]) -> bool:
"""
Test if self contains other, could be a GtidSet or a Gtid.
Expand All @@ -363,7 +375,7 @@ def __contains__(self, other):
return any(other in x for x in self.gtids)
raise NotImplementedError

def __add__(self, other):
def __add__(self, other: Union['GtidSet', Gtid]) -> 'GtidSet':
"""
Merge current instance with an other GtidSet or with a Gtid alone.
Expand All @@ -384,22 +396,23 @@ def __add__(self, other):

raise NotImplementedError

def __str__(self):
def __str__(self) -> str:
"""
Returns a comma separated string of gtids.
"""
return ','.join(str(x) for x in self.gtids)

def __repr__(self):
def __repr__(self) -> str:
return '<GtidSet %r>' % self.gtids

@property
def encoded_length(self):
def encoded_length(self) -> int:
return (8 + # n_sids
sum(x.encoded_length for x in self.gtids))

def encoded(self):
"""Encode a GtidSet in binary
def encoded(self) -> bytes:
"""
Encode a GtidSet in binary
Bytes are in **little endian**.
- `n_sid`: u64 is the number of Gtid to read
Expand All @@ -421,8 +434,9 @@ def encoded(self):
encode = encoded

@classmethod
def decode(cls, payload):
"""Decode a GtidSet from binary.
def decode(cls, payload: BytesIO) -> 'GtidSet':
"""
Decode a GtidSet from binary.
:param BytesIO payload to decode
"""
Expand All @@ -432,5 +446,5 @@ def decode(cls, payload):

return cls([Gtid.decode(payload) for _ in range(0, n_sid)])

def __eq__(self, other):
def __eq__(self, other: 'GtidSet') -> bool:
return self.gtids == other.gtids