Skip to content

v0.45: bug fix for JSON object and array parsing (#494)

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged 1 commit into the base branch on Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 2 additions & 46 deletions pymysqlreplication/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from pymysqlreplication.constants.STATUS_VAR_KEY import *
from pymysqlreplication.exceptions import StatusVariableMismatch
from pymysqlreplication.util.bytes import parse_decimal_from_bytes
from typing import Union, Optional


Expand Down Expand Up @@ -729,7 +730,7 @@ def _read_decimal(self, buffer: bytes) -> decimal.Decimal:
self.precision = self.temp_value_buffer[0]
self.decimals = self.temp_value_buffer[1]
raw_decimal = self.temp_value_buffer[2:]
return self._parse_decimal_from_bytes(raw_decimal, self.precision, self.decimals)
return parse_decimal_from_bytes(raw_decimal, self.precision, self.decimals)

def _read_default(self) -> bytes:
"""
Expand All @@ -738,51 +739,6 @@ def _read_default(self) -> bytes:
"""
return self.packet.read(self.value_len)

@staticmethod
def _parse_decimal_from_bytes(raw_decimal: bytes, precision: int, decimals: int) -> decimal.Decimal:
    """
    Parse decimal from bytes.

    MySQL packs DECIMAL values big-endian: each run of 9 decimal digits is
    stored as a 4-byte integer, and a leftover group of 1-8 digits is
    stored in 1-4 bytes (see ``compressed_bytes``).

    :param raw_decimal: packed binary representation of the decimal.
    :param precision: total number of decimal digits.
    :param decimals: number of digits after the decimal point.
    :return: the decoded value as :class:`decimal.Decimal`.
    """
    digits_per_integer = 9
    # compressed_bytes[n] = number of bytes used to store n leftover digits.
    compressed_bytes = [0, 1, 1, 2, 2, 3, 3, 4, 4, 4]
    integral = precision - decimals

    # Split integral/fractional digit counts into full 9-digit groups
    # (uncompressed, 4 bytes each) and one leftover ("compressed") group.
    uncomp_integral, comp_integral = divmod(integral, digits_per_integer)
    uncomp_fractional, comp_fractional = divmod(decimals, digits_per_integer)

    # High bit of the first byte is the sign flag: set => non-negative.
    res = "-" if not raw_decimal[0] & 0x80 else ""
    # Negative values are stored bitwise-inverted; XOR with an all-ones
    # mask undoes the inversion (mask is 0 for non-negative values).
    mask = -1 if res == "-" else 0
    # Clear the sign bit before decoding the digit groups.
    raw_decimal = bytearray([raw_decimal[0] ^ 0x80]) + raw_decimal[1:]

    def decode_decimal_decompress_value(comp_indx, data, mask):
        # Decode one leftover digit group of comp_indx digits; returns
        # (bytes consumed, decoded integer value).
        size = compressed_bytes[comp_indx]
        if size > 0:
            databuff = bytearray(data[:size])
            for i in range(size):
                databuff[i] = (databuff[i] ^ mask) & 0xFF
            return size, int.from_bytes(databuff, byteorder='big')
        return 0, 0

    # Leading leftover integral digits.
    pointer, value = decode_decimal_decompress_value(comp_integral, raw_decimal, mask)
    res += str(value)

    # Full 9-digit integral groups, zero-padded to 9 digits each.
    for _ in range(uncomp_integral):
        value = struct.unpack('>i', raw_decimal[pointer:pointer+4])[0] ^ mask
        res += '%09d' % value
        pointer += 4

    res += "."

    # Full 9-digit fractional groups.
    for _ in range(uncomp_fractional):
        value = struct.unpack('>i', raw_decimal[pointer:pointer+4])[0] ^ mask
        res += '%09d' % value
        pointer += 4

    # Trailing leftover fractional digits, zero-padded to their digit count.
    size, value = decode_decimal_decompress_value(comp_fractional, raw_decimal[pointer:], mask)
    if size > 0:
        res += '%0*d' % (comp_fractional, value)
    return decimal.Decimal(res)

def _dump(self) -> None:
super(UserVarEvent, self)._dump()
print("User variable name: %s" % self.name)
Expand Down
271 changes: 139 additions & 132 deletions pymysqlreplication/packet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import struct

from pymysqlreplication import constants, event, row_event
from pymysqlreplication.constants import FIELD_TYPE
from pymysqlreplication.util.bytes import *

# Constants from PyMYSQL source code
NULL_COLUMN = 251
Expand All @@ -15,7 +17,6 @@
UNSIGNED_INT24_LENGTH = 3
UNSIGNED_INT64_LENGTH = 8


JSONB_TYPE_SMALL_OBJECT = 0x0
JSONB_TYPE_LARGE_OBJECT = 0x1
JSONB_TYPE_SMALL_ARRAY = 0x2
Expand All @@ -35,19 +36,141 @@
JSONB_LITERAL_TRUE = 0x1
JSONB_LITERAL_FALSE = 0x2


def read_offset_or_inline(packet, large):
    """
    Read one JSONB value entry from the packet.

    Returns a 3-tuple ``(type, offset, inlined_value)``: exactly one of
    ``offset`` / ``inlined_value`` is non-None, depending on whether the
    value for this type tag is stored inline in the entry or at an offset.
    """
    value_type = packet.read_uint8()

    # Types that always fit inline; 32-bit ints also fit when the large
    # format's wider value entries are in use.
    inline_types = (JSONB_TYPE_LITERAL, JSONB_TYPE_INT16, JSONB_TYPE_UINT16)
    if large:
        inline_types += (JSONB_TYPE_INT32, JSONB_TYPE_UINT32)

    if value_type in inline_types:
        inlined = packet.read_binary_json_type_inlined(value_type, large)
        return (value_type, None, inlined)

    offset = packet.read_uint32() if large else packet.read_uint16()
    return (value_type, offset, None)
JSONB_SMALL_OFFSET_SIZE = 2
JSONB_LARGE_OFFSET_SIZE = 4
JSONB_KEY_ENTRY_SIZE_SMALL = 2 + JSONB_SMALL_OFFSET_SIZE
JSONB_KEY_ENTRY_SIZE_LARGE = 2 + JSONB_LARGE_OFFSET_SIZE
JSONB_VALUE_ENTRY_SIZE_SMALL = 1 + JSONB_SMALL_OFFSET_SIZE
JSONB_VALUE_ENTRY_SIZE_LARGE = 1 + JSONB_LARGE_OFFSET_SIZE


def is_json_inline_value(type: int, is_small: bool) -> bool:
    """
    Return True when a JSONB value of this type tag is stored inline in
    the value entry itself rather than at an offset in the payload.

    16-bit ints and literals always fit inline; 32-bit ints fit inline
    only in the large format, whose value entries carry 4-byte offsets.

    :param type: JSONB type tag (one byte of the binary document;
        NOTE: the parameter shadows the ``type`` builtin).
    :param is_small: True for the small object/array format.
    """
    if type in [JSONB_TYPE_UINT16, JSONB_TYPE_INT16, JSONB_TYPE_LITERAL]:
        return True
    elif type in [JSONB_TYPE_INT32, JSONB_TYPE_UINT32]:
        return not is_small
    return False


def parse_json(type: bytes, data: bytes):
    """
    Decode a single JSONB value into a Python object.

    :param type: JSONB type tag byte (shadows the ``type`` builtin, kept
        for interface compatibility).
    :param data: payload bytes following the type tag.
    :return: the decoded Python value (dict, list, scalar, or None).
    :raises ValueError: if the type tag is not a handled JSONB type.
    """
    if type == JSONB_TYPE_SMALL_OBJECT:
        v = parse_json_object_or_array(data, True, True)
    elif type == JSONB_TYPE_LARGE_OBJECT:
        v = parse_json_object_or_array(data, False, True)
    elif type == JSONB_TYPE_SMALL_ARRAY:
        v = parse_json_object_or_array(data, True, False)
    elif type == JSONB_TYPE_LARGE_ARRAY:
        v = parse_json_object_or_array(data, False, False)
    elif type == JSONB_TYPE_LITERAL:
        v = parse_literal(data)
    elif type == JSONB_TYPE_INT16:
        v = parse_int16(data)
    elif type == JSONB_TYPE_UINT16:
        v = parse_uint16(data)
    elif type == JSONB_TYPE_INT32:
        v = parse_int32(data)
    elif type == JSONB_TYPE_UINT32:
        v = parse_uint32(data)
    elif type == JSONB_TYPE_INT64:
        v = parse_int64(data)
    elif type == JSONB_TYPE_UINT64:
        v = parse_uint64(data)
    elif type == JSONB_TYPE_DOUBLE:
        v = parse_double(data)
    elif type == JSONB_TYPE_STRING:
        length, n = decode_variable_length(data)
        v = parse_string(n, length, data)
    elif type == JSONB_TYPE_OPAQUE:
        v = parse_opaque(data)
    else:
        # BUG FIX: previously formatted with an undefined name ``t``,
        # which raised NameError instead of the intended ValueError.
        raise ValueError("Json type %d is not handled" % type)
    return v


def parse_json_object_or_array(bytes, is_small, is_object):
    """
    Decode a JSONB object or array payload.

    Layout (per the MySQL JSON binary format): element count, total size,
    key entries (objects only), value entries, then the variable-sized
    keys and values themselves.

    :param bytes: payload bytes (NOTE: shadows the ``bytes`` builtin,
        kept for interface compatibility).
    :param is_small: True for the small format (2-byte counts/offsets),
        False for the large format (4-byte counts/offsets).
    :param is_object: True to decode an object, False for an array.
    :return: a dict (object, keys are ``bytes``), a list (array), or
        None when a value offset points past the available data.
    :raises ValueError: for pre-5.7.22 invalid generated-column values or
        a header that claims more entries than the payload can hold.
    """
    offset_size = JSONB_SMALL_OFFSET_SIZE if is_small else JSONB_LARGE_OFFSET_SIZE
    count = decode_count(bytes, is_small)
    size = decode_count(bytes[offset_size:], is_small)
    if is_small:
        key_entry_size = JSONB_KEY_ENTRY_SIZE_SMALL
        value_entry_size = JSONB_VALUE_ENTRY_SIZE_SMALL
    else:
        key_entry_size = JSONB_KEY_ENTRY_SIZE_LARGE
        value_entry_size = JSONB_VALUE_ENTRY_SIZE_LARGE
    if is_data_short(bytes, size):
        raise ValueError(
            "Before MySQL 5.7.22, json type generated column may have invalid value"
        )

    # Fixed-size portion: count + size fields plus one value entry per
    # element (and, for objects, one key entry per element).
    header_size = 2 * offset_size + count * value_entry_size

    if is_object:
        header_size += count * key_entry_size

    if header_size > size:
        raise ValueError("header size > size")

    # BUG FIX: ``keys`` was initialized twice in a row; one init suffices.
    keys = []
    if is_object:
        # Each key entry is an (offset, length) pair into the payload.
        for i in range(count):
            entry_offset = 2 * offset_size + key_entry_size * i
            key_offset = decode_count(bytes[entry_offset:], is_small)
            key_length = decode_uint(bytes[entry_offset + offset_size :])
            keys.append(bytes[key_offset : key_offset + key_length])

    values = {}
    for i in range(count):
        entry_offset = 2 * offset_size + value_entry_size * i
        if is_object:
            # Value entries come after all key entries.
            entry_offset += key_entry_size * count
        json_type = bytes[entry_offset]
        if is_json_inline_value(json_type, is_small):
            # Small scalars are stored inline in the value entry itself.
            values[i] = parse_json(
                json_type, bytes[entry_offset + 1 : entry_offset + value_entry_size]
            )
            continue
        value_offset = decode_count(bytes[entry_offset + 1 :], is_small)
        if is_data_short(bytes, value_offset):
            return None
        values[i] = parse_json(json_type, bytes[value_offset:])
    if not is_object:
        return list(values.values())
    # Pair each decoded key with its value, preserving entry order.
    return {keys[i]: values[i] for i in range(count)}


def parse_literal(data: bytes):
    """
    Decode a JSONB literal byte into ``None``, ``True``, or ``False``.

    :raises ValueError: if the byte is not one of the three literal codes.
    """
    literal_values = {
        JSONB_LITERAL_NULL: None,
        JSONB_LITERAL_TRUE: True,
        JSONB_LITERAL_FALSE: False,
    }
    json_type = data[0]
    if json_type not in literal_values:
        raise ValueError("NOT LITERAL TYPE")
    return literal_values[json_type]


def parse_opaque(data: bytes):
    """
    Decode a JSONB "opaque" value: a one-byte MySQL field type followed
    by a variable-length-encoded payload.

    Decimal, time, and datetime field types are decoded to Python values;
    any other field type is returned as a best-effort decoded string.
    Returns None when the data is too short to hold the type byte.
    """
    if is_data_short(data, 1):
        return None

    field_type = data[0]
    payload = data[1:]

    length, consumed = decode_variable_length(payload)
    payload = payload[consumed : consumed + length]

    if field_type in [FIELD_TYPE.NEWDECIMAL, FIELD_TYPE.DECIMAL]:
        return decode_decimal(payload)
    if field_type in [FIELD_TYPE.TIME, FIELD_TYPE.TIME2]:
        return decode_time(payload)
    if field_type in [FIELD_TYPE.DATE, FIELD_TYPE.DATETIME, FIELD_TYPE.DATETIME2]:
        return decode_datetime(payload)
    return payload.decode(errors="ignore")


class BinLogPacketWrapper(object):
Expand Down Expand Up @@ -365,124 +488,8 @@ def read_binary_json(self, size):
if length == 0:
# handle NULL value
return None
payload = self.read(length)
self.unread(payload)
t = self.read_uint8()

return self.read_binary_json_type(t, length)

def read_binary_json_type(self, t, length):
    """
    Decode a JSONB value of type ``t`` by reading from the packet stream.

    :param t: JSONB type tag byte.
    :param length: remaining payload length; passed (minus the tag byte
        already consumed) to the object/array readers.
    :raises ValueError: if ``t`` is not a handled JSONB type. NOTE: a
        literal byte that is none of NULL/TRUE/FALSE also falls through
        to this raise, which mislabels the error as an unhandled type.
    """
    large = (t in (JSONB_TYPE_LARGE_OBJECT, JSONB_TYPE_LARGE_ARRAY))
    if t in (JSONB_TYPE_SMALL_OBJECT, JSONB_TYPE_LARGE_OBJECT):
        return self.read_binary_json_object(length - 1, large)
    elif t in (JSONB_TYPE_SMALL_ARRAY, JSONB_TYPE_LARGE_ARRAY):
        return self.read_binary_json_array(length - 1, large)
    elif t in (JSONB_TYPE_STRING,):
        return self.read_variable_length_string()
    elif t in (JSONB_TYPE_LITERAL,):
        value = self.read_uint8()
        if value == JSONB_LITERAL_NULL:
            return None
        elif value == JSONB_LITERAL_TRUE:
            return True
        elif value == JSONB_LITERAL_FALSE:
            return False
    elif t == JSONB_TYPE_INT16:
        return self.read_int16()
    elif t == JSONB_TYPE_UINT16:
        return self.read_uint16()
    elif t in (JSONB_TYPE_DOUBLE,):
        # IEEE-754 double, little-endian.
        return struct.unpack('<d', self.read(8))[0]
    elif t == JSONB_TYPE_INT32:
        return self.read_int32()
    elif t == JSONB_TYPE_UINT32:
        return self.read_uint32()
    elif t == JSONB_TYPE_INT64:
        return self.read_int64()
    elif t == JSONB_TYPE_UINT64:
        return self.read_uint64()

    raise ValueError('Json type %d is not handled' % t)

def read_binary_json_type_inlined(self, t, large):
    """
    Decode an inlined JSONB scalar from the packet stream.

    Inlined values occupy the full offset-field width of a value entry,
    hence the 4-byte reads in the large format and 2-byte reads in the
    small format even for 16-bit types.

    :param t: JSONB type tag byte.
    :param large: True when reading the large object/array format.
    :raises ValueError: if ``t`` cannot be inlined. NOTE: a literal byte
        that is none of NULL/TRUE/FALSE also falls through to this raise.
    """
    if t == JSONB_TYPE_LITERAL:
        value = self.read_uint32() if large else self.read_uint16()
        if value == JSONB_LITERAL_NULL:
            return None
        elif value == JSONB_LITERAL_TRUE:
            return True
        elif value == JSONB_LITERAL_FALSE:
            return False
    elif t == JSONB_TYPE_INT16:
        return self.read_int32() if large else self.read_int16()
    elif t == JSONB_TYPE_UINT16:
        return self.read_uint32() if large else self.read_uint16()
    elif t == JSONB_TYPE_INT32:
        return self.read_int32()
    elif t == JSONB_TYPE_UINT32:
        return self.read_uint32()

    raise ValueError('Json type %d is not handled' % t)

def read_binary_json_object(self, length, large):
    """
    Decode a JSONB object by reading sequentially from the packet stream.

    Stream layout: element count, total size, key entries, value entries,
    then the key bytes; values stored out-of-line are read last. Read
    order is significant — do not reorder these reads.

    :param length: payload length excluding the already-consumed type tag.
    :param large: True for the large format (4-byte counts/offsets).
    :return: dict mapping key bytes to decoded values.
    :raises ValueError: if the declared size exceeds ``length``.
    """
    if large:
        elements = self.read_uint32()
        size = self.read_uint32()
    else:
        elements = self.read_uint16()
        size = self.read_uint16()

    if size > length:
        raise ValueError('Json length is larger than packet length')

    if large:
        key_offset_lengths = [(
            self.read_uint32(),  # offset (we don't actually need that)
            self.read_uint16()   # size of the key
        ) for _ in range(elements)]
    else:
        key_offset_lengths = [(
            self.read_uint16(),  # offset (we don't actually need that)
            self.read_uint16()   # size of key
        ) for _ in range(elements)]

    # One (type, offset, inlined_value) triple per element; inlined
    # scalars are consumed from the stream here.
    value_type_inlined_lengths = [read_offset_or_inline(self, large)
                                  for _ in range(elements)]

    keys = [self.read(x[1]) for x in key_offset_lengths]

    out = {}
    for i in range(elements):
        if value_type_inlined_lengths[i][1] is None:
            # Inlined scalar: already decoded above.
            data = value_type_inlined_lengths[i][2]
        else:
            # Out-of-line value: decode it from the stream now.
            t = value_type_inlined_lengths[i][0]
            data = self.read_binary_json_type(t, length)
        out[keys[i]] = data

    return out

def read_binary_json_array(self, length, large):
    """
    Decode a JSONB array by reading sequentially from the packet stream.

    Stream layout: element count, total size, value entries, then any
    out-of-line values. Read order is significant.

    :param length: payload length excluding the already-consumed type tag.
    :param large: True for the large format (4-byte counts/offsets).
    :return: list of decoded values.
    :raises ValueError: if the declared size exceeds ``length``.
    """
    if large:
        elements = self.read_uint32()
        size = self.read_uint32()
    else:
        elements = self.read_uint16()
        size = self.read_uint16()

    if size > length:
        raise ValueError('Json length is larger than packet length')

    # One (type, offset, inlined_value) triple per element; inlined
    # scalars are consumed from the stream here.
    values_type_offset_inline = [
        read_offset_or_inline(self, large)
        for _ in range(elements)]

    def _read(x):
        # Inlined value already decoded; otherwise decode from the stream.
        if x[1] is None:
            return x[2]
        return self.read_binary_json_type(x[0], length)

    return [_read(x) for x in values_type_offset_inline]
data = self.read(length)
return parse_json(data[0], data[1:])

def read_string(self):
"""Read a 'Length Coded String' from the data buffer.
Expand Down
25 changes: 25 additions & 0 deletions pymysqlreplication/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,31 @@ def create_binlog_packet_wrapper(pkt):
self.assertEqual(binlog_event.event._is_event_valid, True)
self.assertNotEqual(wrong_event.event._is_event_valid, True)

def test_json_update(self):
    """
    JSON values in an UpdateRowsEvent must decode for both the before and
    after row images (note: JSON object keys come back as ``bytes``).
    """
    self.stream.close()
    self.stream = BinLogStreamReader(
        self.database, server_id=1024, only_events=[UpdateRowsEvent]
    )
    create_query = (
        "CREATE TABLE setting_table( id SERIAL AUTO_INCREMENT, setting JSON);"
    )
    insert_query = """INSERT INTO setting_table (setting) VALUES ('{"btn": true, "model": false}');"""

    update_query = """ UPDATE setting_table
    SET setting = JSON_REMOVE(setting, '$.model')
    WHERE id=1;
    """
    self.execute(create_query)
    self.execute(insert_query)
    self.execute(update_query)
    self.execute("COMMIT;")
    event = self.stream.fetchone()
    # BUG FIX: stray trailing commas after these assertEqual calls wrapped
    # the (None) results in throwaway one-element tuples; removed.
    self.assertEqual(
        event.rows[0]["before_values"]["setting"],
        {b"btn": True, b"model": False},
    )
    self.assertEqual(event.rows[0]["after_values"]["setting"], {b"btn": True})


class TestMultipleRowBinLogStreamReader(base.PyMySQLReplicationTestCase):
def ignoredEvents(self):
Expand Down
Loading