Skip to content

Typing: Gtid and GtidSet. #393

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 60 additions & 41 deletions pymysqlreplication/gtid.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import List, Tuple, Union, Set, Optional

import re
import struct
import binascii
from io import BytesIO

def overlap(i1, i2):
def overlap(i1: Tuple[int, int], i2: Tuple[int, int]) -> bool:
return i1[0] < i2[1] and i1[1] > i2[0]

def contains(i1, i2):
def contains(i1: Tuple[int, int], i2: Tuple[int, int]) -> bool:
return i2[0] >= i1[0] and i2[1] <= i1[1]

class Gtid(object):
Expand Down Expand Up @@ -47,8 +49,19 @@ class Gtid(object):
Exception: Adding an already present transaction number (one that overlaps).
Exception: Adding a Gtid with a different SID.
"""
def __init__(self, gtid: str, sid=None, intervals: List[Tuple[int, int]]=[]):
if sid:
intervals = intervals
else:
sid, intervals = Gtid.parse(gtid)

self.sid: str = sid
self.intervals: List[Tuple[int, int]] = []
for itvl in intervals:
self.__add_interval(itvl)

@staticmethod
def parse_interval(interval):
def parse_interval(interval: str) -> Tuple[int, int]:
"""
We parse a human-generated string here. So our end value b
is incremented to conform to the internal representation format.
Expand All @@ -64,7 +77,7 @@ def parse_interval(interval):
return (a, b+1)

@staticmethod
def parse(gtid):
def parse(gtid: str) -> Tuple[str, List[Tuple[int, int]]]:
"""Parse a GTID from mysql textual format.

Raises:
Expand All @@ -83,7 +96,7 @@ def parse(gtid):

return (sid, intervals_parsed)

def __add_interval(self, itvl):
def __add_interval(self, itvl: Tuple[int, int]):
"""
Use the internal representation format and add it
to our intervals, merging if required.
Expand Down Expand Up @@ -113,7 +126,7 @@ def __add_interval(self, itvl):

self.intervals = sorted(new + [itvl])

def __sub_interval(self, itvl):
def __sub_interval(self, itvl: Tuple[int, int]):
"""Using the internal representation, remove an interval

Raises: Exception if itvl malformated"""
Expand All @@ -138,25 +151,16 @@ def __sub_interval(self, itvl):

self.intervals = new

def __contains__(self, other):
def __contains__(self, other: Gtid) -> bool:
if other.sid != self.sid:
return False

return all(any(contains(me, them) for me in self.intervals)
for them in other.intervals)

def __init__(self, gtid, sid=None, intervals=[]):
if sid:
intervals = intervals
else:
sid, intervals = Gtid.parse(gtid)

self.sid = sid
self.intervals = []
for itvl in intervals:
self.__add_interval(itvl)

def __add__(self, other):
def __add__(self, other: Gtid) -> Gtid:
"""Include the transactions of this gtid.

Raises:
Expand All @@ -172,7 +176,7 @@ def __add__(self, other):

return result

def __sub__(self, other):
def __sub__(self, other: Gtid) -> Gtid:
"""Remove intervals. Do not raise, if different SID simply
ignore"""
result = Gtid(str(self))
Expand All @@ -184,26 +188,26 @@ def __sub__(self, other):

return result

def __str__(self):
def __str__(self) -> str:
"""We represent the human value here - a single number
for one transaction, or a closed interval (decrementing b)"""
return '%s:%s' % (self.sid,
':'.join(('%d-%d' % (x[0], x[1]-1)) if x[0] +1 != x[1]
else str(x[0])
for x in self.intervals))

def __repr__(self):
def __repr__(self) -> str:
return '<Gtid "%s">' % self

@property
def encoded_length(self):
def encoded_length(self) -> int:
return (16 + # sid
8 + # n_intervals
2 * # stop/start
8 * # stop/start mark encoded as int64
len(self.intervals))

def encode(self):
def encode(self) -> bytes:
buffer = b''
# sid
buffer += binascii.unhexlify(self.sid.replace('-', ''))
Expand All @@ -219,7 +223,7 @@ def encode(self):
return buffer

@classmethod
def decode(cls, payload):
def decode(cls, payload: BytesIO) -> Gtid:
assert isinstance(payload, BytesIO), \
'payload is expected to be a BytesIO'
sid = b''
Expand All @@ -236,49 +240,58 @@ def decode(cls, payload):
(n_intervals,) = struct.unpack('<Q', payload.read(8))
intervals = []
for i in range(0, n_intervals):
start, end = struct.unpack('<QQ', payload.read(16))
(start, end) = struct.unpack('<QQ', payload.read(16))
intervals.append((start, end-1))

return cls('%s:%s' % (sid.decode('ascii'), ':'.join([
'%d-%d' % x
if isinstance(x, tuple)
else '%d' % x
for x in intervals])))

def __cmp__(self, other):
if other.sid != self.sid:
return cmp(self.sid, other.sid)
return cmp(self.intervals, other.intervals)

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
"""Equality between a Gtid and an other Gtid.

Raise: NotImplemented if compared with anything else.
"""
if not isinstance(other, Gtid):
raise NotImplemented

if other.sid != self.sid:
return False
return self.intervals == other.intervals

def __lt__(self, other):
def __lt__(self, other: Gtid) -> bool:
"""Check if a Gtid is lesser an other Gtid.
"""
if other.sid != self.sid:
return self.sid < other.sid
return self.intervals < other.intervals

def __le__(self, other):
def __le__(self, other: Gtid) -> bool:
"""Check if a Gtid is lesser or equal an other Gtid.
"""
if other.sid != self.sid:
return self.sid <= other.sid
return self.intervals <= other.intervals

def __gt__(self, other):
def __gt__(self, other: Gtid) -> bool:
if other.sid != self.sid:
return self.sid > other.sid
return self.intervals > other.intervals

def __ge__(self, other):
def __ge__(self, other: Gtid) -> bool:
if other.sid != self.sid:
return self.sid >= other.sid
return self.intervals >= other.intervals


class GtidSet(object):
"""Represents a set of Gtid"""
def __init__(self, gtid_set):
def __init__(self, gtid_set: Optional[Union[str, Union[Set[Gtid], Set[str], List[Gtid], List[str]]]]):
"""
Construct a GtidSet initial state depends of the nature of `gtid_set` param.

Expand All @@ -300,7 +313,7 @@ def _to_gtid(element):
return Gtid(element.strip(' \n'))

if not gtid_set:
self.gtids = []
self.gtids: List[Gtid] = []
elif isinstance(gtid_set, (list, set)):
self.gtids = [_to_gtid(x) for x in gtid_set]
else:
Expand All @@ -317,7 +330,7 @@ def merge_gtid(self, gtid):
new_gtids.append(gtid)
self.gtids = new_gtids

def __contains__(self, other):
def __contains__(self, other: Union[GtidSet, Gtid]) -> bool:
"""
Raises:
- NotImplementedError other is not a GtidSet neither a Gtid,
Expand All @@ -329,7 +342,7 @@ def __contains__(self, other):
return any(other in x for x in self.gtids)
raise NotImplementedError

def __add__(self, other):
def __add__(self, other: Union[GtidSet, Gtid]) -> GtidSet:
"""
Merge current instance with an other GtidSet or with a Gtid alone.

Expand All @@ -350,33 +363,39 @@ def __add__(self, other):

raise NotImplementedError

def __str__(self):
def __str__(self) -> str:
"""
Returns a comma separated string of gtids.
"""
return ','.join(str(x) for x in self.gtids)

def __repr__(self):
def __repr__(self) -> str:
return '<GtidSet %r>' % self.gtids

@property
def encoded_length(self):
def encoded_length(self) -> int:
return (8 + # n_sids
sum(x.encoded_length for x in self.gtids))

def encoded(self):
def encoded(self) -> bytes:
return b'' + (struct.pack('<Q', len(self.gtids)) +
b''.join(x.encode() for x in self.gtids))

encode = encoded

@classmethod
def decode(cls, payload):
def decode(cls, payload: BytesIO) -> GtidSet:
assert isinstance(payload, BytesIO), \
'payload is expected to be a BytesIO'
(n_sid,) = struct.unpack('<Q', payload.read(8))

return cls([Gtid.decode(payload) for _ in range(0, n_sid)])

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
"""Equality between a GtidSet and an other GtidSet.

Raise: NotImplemented if compared with anything else.
"""
if not isinstance(other, GtidSet):
raise NotImplemented
return self.gtids == other.gtids