Skip to content

typing - gtid.py #76

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 31, 2023
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 65 additions & 53 deletions pymysqlreplication/gtid.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@
import binascii
from copy import deepcopy
from io import BytesIO
from typing import List, Optional, Tuple, Union, Set

def overlap(i1, i2):

def overlap(i1: Tuple[int, int], i2: Tuple[int, int]) -> bool:
return i1[0] < i2[1] and i1[1] > i2[0]

def contains(i1, i2):
def contains(i1: Tuple[int, int], i2: Tuple[int, int]) -> bool:
return i2[0] >= i1[0] and i2[1] <= i1[1]

class Gtid(object):
"""A mysql GTID is composed of a server-id and a set of right-open
"""
A mysql GTID is composed of a server-id and a set of right-open
intervals [a,b), and represent all transactions x that happened on
server SID such as

Expand Down Expand Up @@ -49,7 +52,7 @@ class Gtid(object):
Exception: Adding a Gtid with a different SID.
"""
@staticmethod
def parse_interval(interval):
def parse_interval(interval: str) -> Tuple[int, int]:
"""
We parse a human-generated string here. So our end value b
is incremented to conform to the internal representation format.
Expand All @@ -65,8 +68,9 @@ def parse_interval(interval):
return (a, b+1)

@staticmethod
def parse(gtid):
"""Parse a GTID from mysql textual format.
def parse(gtid: str) -> Tuple[str, List[Tuple[int, int]]]:
"""
Parse a GTID from mysql textual format.

Raises:
- ValueError: if GTID format is incorrect.
Expand All @@ -84,15 +88,15 @@ def parse(gtid):

return (sid, intervals_parsed)

def __add_interval(self, itvl):
def __add_interval(self, itvl: Tuple[int, int]) -> None:
"""
Use the internal representation format and add it
to our intervals, merging if required.

Raises:
Exception: if Malformated interval or Overlapping interval
"""
new = []
new: List[Tuple[int, int]] = []

if itvl[0] > itvl[1]:
raise Exception('Malformed interval %s' % (itvl,))
Expand All @@ -114,11 +118,13 @@ def __add_interval(self, itvl):

self.intervals = sorted(new + [itvl])

def __sub_interval(self, itvl):
"""Using the internal representation, remove an interval
def __sub_interval(self, itvl: Tuple[int, int]) -> None:
"""
Using the internal representation, remove an interval

Raises: Exception if itvl malformated"""
new = []
Raises: Exception if itvl malformated
"""
new: List[Tuple[int, int]] = []

if itvl[0] > itvl[1]:
raise Exception('Malformed interval %s' % (itvl,))
Expand All @@ -139,8 +145,9 @@ def __sub_interval(self, itvl):

self.intervals = new

def __contains__(self, other):
"""Test if other is contained within self.
def __contains__(self, other: 'Gtid') -> bool:
"""
Test if other is contained within self.
First we compare sid they must be equals.

Then we search if intervals from other are contained within
Expand All @@ -152,22 +159,22 @@ def __contains__(self, other):
return all(any(contains(me, them) for me in self.intervals)
for them in other.intervals)

def __init__(self, gtid, sid=None, intervals=[]):
if sid:
intervals = intervals
else:
def __init__(self, gtid: str, sid: Optional[str] = None, intervals: Optional[List[Tuple[int, int]]] = None) -> None:
if sid is None:
sid, intervals = Gtid.parse(gtid)

self.sid = sid
self.intervals = []
for itvl in intervals:
self.__add_interval(itvl)

def __add__(self, other):
"""Include the transactions of this gtid.
def __add__(self, other: 'Gtid') -> 'Gtid':
"""
Include the transactions of this gtid.

Raises:
Exception: if the attempted merge has different SID"""
Exception: if the attempted merge has different SID
"""
Copy link
Member

@heehehe heehehe Aug 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

여기 """ 부분 indentation이 위의 """랑 맞아야할 것 같아요!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@heehehe 수정했습니다!
확인해줘서 감사합니다 👍

if self.sid != other.sid:
raise Exception('Attempt to merge different SID'
'%s != %s' % (self.sid, other.sid))
Expand All @@ -179,9 +186,10 @@ def __add__(self, other):

return result

def __sub__(self, other):
"""Remove intervals. Do not raise, if different SID simply
ignore"""
def __sub__(self, other: 'Gtid') -> 'Gtid':
"""
Remove intervals. Do not raise, if different SID simply ignore
"""
result = deepcopy(self)
if self.sid != other.sid:
return result
Expand All @@ -191,27 +199,30 @@ def __sub__(self, other):

return result

def __str__(self):
"""We represent the human value here - a single number
for one transaction, or a closed interval (decrementing b)"""
def __str__(self) -> str:
"""
We represent the human value here - a single number
for one transaction, or a closed interval (decrementing b)
"""
return '%s:%s' % (self.sid,
':'.join(('%d-%d' % (x[0], x[1]-1)) if x[0] +1 != x[1]
else str(x[0])
for x in self.intervals))

def __repr__(self):
def __repr__(self) -> str:
return '<Gtid "%s">' % self

@property
def encoded_length(self):
def encoded_length(self) -> int:
return (16 + # sid
8 + # n_intervals
2 * # stop/start
8 * # stop/start mark encoded as int64
len(self.intervals))

def encode(self):
"""Encode a Gtid in binary
def encode(self) -> bytes:
"""
Encode a Gtid in binary
Bytes are in **little endian**.

Format:
Expand Down Expand Up @@ -251,8 +262,9 @@ def encode(self):
return buffer

@classmethod
def decode(cls, payload):
"""Decode from binary a Gtid
def decode(cls, payload: BytesIO) -> 'Gtid':
"""
Decode from binary a Gtid

:param BytesIO payload to decode
"""
Expand Down Expand Up @@ -281,35 +293,35 @@ def decode(cls, payload):
else '%d' % x
for x in intervals])))

def __eq__(self, other):
def __eq__(self, other: 'Gtid') -> bool:
if other.sid != self.sid:
return False
return self.intervals == other.intervals

def __lt__(self, other):
def __lt__(self, other: 'Gtid') -> bool:
if other.sid != self.sid:
return self.sid < other.sid
return self.intervals < other.intervals

def __le__(self, other):
def __le__(self, other: 'Gtid') -> bool:
if other.sid != self.sid:
return self.sid <= other.sid
return self.intervals <= other.intervals

def __gt__(self, other):
def __gt__(self, other: 'Gtid') -> bool:
if other.sid != self.sid:
return self.sid > other.sid
return self.intervals > other.intervals

def __ge__(self, other):
def __ge__(self, other: 'Gtid') -> bool:
if other.sid != self.sid:
return self.sid >= other.sid
return self.intervals >= other.intervals


class GtidSet(object):
"""Represents a set of Gtid"""
def __init__(self, gtid_set):
def __init__(self, gtid_set: Optional[Union[None, str, Set[Gtid], List[Gtid], Gtid]] = None) -> None:
"""
Construct a GtidSet initial state depends of the nature of `gtid_set` param.

Expand All @@ -325,21 +337,21 @@ def __init__(self, gtid_set):
- Exception: if Gtid interval are either malformated or overlapping
"""

def _to_gtid(element):
def _to_gtid(element: str) -> Gtid:
if isinstance(element, Gtid):
return element
return Gtid(element.strip(' \n'))

if not gtid_set:
self.gtids = []
self.gtids: List[Gtid] = []
elif isinstance(gtid_set, (list, set)):
self.gtids = [_to_gtid(x) for x in gtid_set]
self.gtids: List[Gtid] = [_to_gtid(x) for x in gtid_set]
else:
self.gtids = [Gtid(x.strip(' \n')) for x in gtid_set.split(',')]
self.gtids: List[Gtid] = [Gtid(x.strip(' \n')) for x in gtid_set.split(',')]

def merge_gtid(self, gtid):
def merge_gtid(self, gtid: Gtid) -> None:
"""Insert a Gtid in current GtidSet."""
new_gtids = []
new_gtids: List[Gtid] = []
for existing in self.gtids:
if existing.sid == gtid.sid:
new_gtids.append(existing + gtid)
Expand All @@ -349,7 +361,7 @@ def merge_gtid(self, gtid):
new_gtids.append(gtid)
self.gtids = new_gtids

def __contains__(self, other):
def __contains__(self, other: Union['GtidSet', Gtid]) -> bool:
"""
Test if self contains other, could be a GtidSet or a Gtid.

Expand All @@ -363,7 +375,7 @@ def __contains__(self, other):
return any(other in x for x in self.gtids)
raise NotImplementedError

def __add__(self, other):
def __add__(self, other: Union['GtidSet', Gtid]) -> 'GtidSet':
"""
Merge current instance with an other GtidSet or with a Gtid alone.

Expand All @@ -384,21 +396,21 @@ def __add__(self, other):

raise NotImplementedError

def __str__(self):
def __str__(self) -> str:
"""
Returns a comma separated string of gtids.
"""
return ','.join(str(x) for x in self.gtids)

def __repr__(self):
def __repr__(self) -> str:
return '<GtidSet %r>' % self.gtids

@property
def encoded_length(self):
def encoded_length(self) -> int:
return (8 + # n_sids
sum(x.encoded_length for x in self.gtids))

def encoded(self):
def encoded(self) -> bytes:
"""Encode a GtidSet in binary
Bytes are in **little endian**.

Expand All @@ -421,7 +433,7 @@ def encoded(self):
encode = encoded

@classmethod
def decode(cls, payload):
def decode(cls, payload: BytesIO) -> 'GtidSet':
"""Decode a GtidSet from binary.

:param BytesIO payload to decode
Expand All @@ -432,5 +444,5 @@ def decode(cls, payload):

return cls([Gtid.decode(payload) for _ in range(0, n_sid)])

def __eq__(self, other):
def __eq__(self, other: 'GtidSet') -> bool:
return self.gtids == other.gtids