Skip to content

Commit 1b0fc42

Browse files
Merge pull request #174 from baloo/baloo/features/decode-gtid-from-packet
Decode gtid from network packet
2 parents edb7b74 + 7558f72 commit 1b0fc42

File tree

2 files changed

+80
-12
lines changed

2 files changed

+80
-12
lines changed

pymysqlreplication/gtid.py

+61-12
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import re
44
import struct
55
import binascii
6+
from io import BytesIO
67

78

89
class Gtid(object):
@@ -16,17 +17,18 @@ def parse_interval(interval):
1617
else:
1718
return (int(m.group(1)), int(m.group(2)))
1819

19-
2020
@staticmethod
def parse(gtid):
    """Split a textual GTID into its SID and its parsed intervals.

    The expected shape is a UUID-formatted SID followed by one or more
    ``:interval`` suffixes.

    :param gtid: textual GTID.
    :return: ``(sid, intervals)`` where ``intervals`` is a list holding the
        result of ``Gtid.parse_interval`` for every interval, in order.
    :raises ValueError: if ``gtid`` does not match the expected shape.
    """
    pattern = ('^([0-9a-fA-F]{8}(?:-[0-9a-fA-F]{4}){3}-[0-9a-fA-F]{12})'
               '((?::[0-9-]+)+)$')
    match = re.search(pattern, gtid)
    if match is None:
        raise ValueError('GTID format is incorrect: %r' % (gtid, ))

    # Group 1 is the SID, group 2 is every ':interval' suffix glued together.
    sid, raw_intervals = match.groups()

    # The suffix starts with ':', so the first split item is empty: drop it.
    pieces = raw_intervals.split(':')[1:]
    return (sid, list(map(Gtid.parse_interval, pieces)))
3234

@@ -39,18 +41,18 @@ def __init__(self, gtid):
3941
def __str__(self):
4042
return '%s:%s' % (self.sid,
4143
':'.join(('%d-%s' % x) if isinstance(x, tuple)
42-
else str(x)
43-
for x in self.intervals))
44+
else str(x)
45+
for x in self.intervals))
4446

4547
def __repr__(self):
4648
return '<Gtid "%s">' % self
4749

4850
@property
4951
def encoded_length(self):
50-
return (16 + # sid
51-
8 + # n_intervals
52-
2 * # stop/start
53-
8 * # stop/start mark encoded as int64
52+
return (16 + # sid
53+
8 + # n_intervals
54+
2 * # stop/start
55+
8 * # stop/start mark encoded as int64
5456
len(self.intervals))
5557

5658
def encode(self):
@@ -79,25 +81,72 @@ def encode(self):
7981

8082
return buffer
8183

84+
@classmethod
def decode(cls, payload):
    """Build an instance from the binary encoding read from ``payload``.

    Bytes are consumed in this order: 16 raw SID bytes, an unsigned
    little-endian 64-bit interval count, then two unsigned little-endian
    64-bit values (start, stop) per interval. The result is constructed by
    calling ``cls`` with the textual form ``sid:interval[:interval...]``.

    :param payload: ``BytesIO`` positioned at the start of an encoded GTID;
        its read position is advanced past the consumed bytes.
    :return: whatever ``cls(text)`` returns for the decoded textual form.
    :raises TypeError: if ``payload`` is not a ``BytesIO``.
    :raises struct.error: if the count or an interval cannot be read in full.
    """
    # Explicit check instead of ``assert`` so that validation still happens
    # when Python runs with -O (which strips assert statements).
    if not isinstance(payload, BytesIO):
        raise TypeError('payload is expected to be a BytesIO')

    # The SID is 16 raw bytes rendered as hex groups of 4-2-2-2-6 bytes.
    sid = b'-'.join(binascii.hexlify(payload.read(size))
                    for size in (4, 2, 2, 2, 6))

    (n_intervals,) = struct.unpack('<Q', payload.read(8))
    intervals = []
    for _ in range(n_intervals):
        start, end = struct.unpack('<QQ', payload.read(16))
        if end == start + 1:
            # stop == start + 1 is rendered as a bare transaction number.
            intervals.append('%d' % start)
        else:
            intervals.append('%d-%d' % (start, end))

    return cls('%s:%s' % (sid.decode('ascii'), ':'.join(intervals)))
113+
82114

83115
class GtidSet(object):
    """Collection of ``Gtid`` objects, kept in the list ``self.gtids``.

    Offers a textual form (comma-separated GTIDs) and a binary form (an
    8-byte little-endian count followed by every encoded GTID).
    """

    def __init__(self, gtid_set):
        """Build the set from text or from a sequence.

        :param gtid_set: one of

            * a falsy value (``None``, ``''``, empty sequence): empty set;
            * a list, set, tuple or frozenset whose items are ``Gtid``
              objects or textual GTIDs;
            * a string of comma-separated textual GTIDs.
        """
        def _to_gtid(element):
            # Gtid objects are kept as-is; text is stripped of surrounding
            # spaces/newlines and parsed.
            if isinstance(element, Gtid):
                return element
            return Gtid(element.strip(' \n'))

        if not gtid_set:
            self.gtids = []
        elif isinstance(gtid_set, (list, set, tuple, frozenset)):
            self.gtids = [_to_gtid(x) for x in gtid_set]
        else:
            self.gtids = [_to_gtid(x) for x in gtid_set.split(',')]

    def __str__(self):
        """Return the GTIDs rendered with ``str()`` and joined by commas."""
        return ','.join(str(x) for x in self.gtids)

    def __repr__(self):
        """Return ``<GtidSet [...]>`` using the repr of the GTID list."""
        return '<GtidSet %r>' % self.gtids

    @property
    def encoded_length(self):
        """Size in bytes of the binary form: 8-byte count plus each GTID."""
        return (8 +  # n_sids
                sum(x.encoded_length for x in self.gtids))

    def encoded(self):
        """Return the binary form: ``<Q`` GTID count, then each encoded GTID."""
        return (struct.pack('<Q', len(self.gtids)) +
                b''.join(x.encode() for x in self.gtids))

    # Alias so that sets and single GTIDs share the ``encode`` spelling.
    encode = encoded

    @classmethod
    def decode(cls, payload):
        """Build a set from its binary form.

        Reads an unsigned little-endian 64-bit GTID count from ``payload``,
        then lets ``Gtid.decode`` consume that many GTIDs from the same
        stream.

        :param payload: ``BytesIO`` positioned at the start of the encoded set.
        :return: ``cls`` built from the list of decoded GTIDs.
        :raises TypeError: if ``payload`` is not a ``BytesIO``.
        :raises struct.error: if the 8-byte count cannot be read in full.
        """
        # Explicit check instead of ``assert`` so that validation still
        # happens when Python runs with -O (which strips assert statements).
        if not isinstance(payload, BytesIO):
            raise TypeError('payload is expected to be a BytesIO')
        (n_sid,) = struct.unpack('<Q', payload.read(8))

        return cls([Gtid.decode(payload) for _ in range(0, n_sid)])

pymysqlreplication/tests/test_basic.py

+19
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# -*- coding: utf-8 -*-
22
import time
33
import sys
4+
import io
45
if sys.version_info < (2, 7):
56
import unittest2 as unittest
67
else:
@@ -715,6 +716,24 @@ def test_gtidset_representation_newline(self):
715716
myset = GtidSet(mysql_repr)
716717
self.assertEqual(str(myset), set_repr)
717718

719+
def test_gtidset_representation(self):
    # Round-trip check: a set encoded to its binary payload and decoded
    # again must render to the same text as the original set.
    ranged_first = '57b70f4e-20d3-11e5-a393-4a63946f7eac:1-56,' \
                   '4350f323-7565-4e59-8763-4b1b83a0ce0e:1-20'
    # Same set, but the first GTID carries a single transaction number.
    single_first = '57b70f4e-20d3-11e5-a393-4a63946f7eac:1,' \
                   '4350f323-7565-4e59-8763-4b1b83a0ce0e:1-20'

    for set_repr in (ranged_first, single_first):
        myset = GtidSet(set_repr)
        payload = myset.encode()
        parsedset = myset.decode(io.BytesIO(payload))

        self.assertEqual(str(myset), str(parsedset))
718737

719738
if __name__ == "__main__":
720739
import unittest

0 commit comments

Comments
 (0)