Skip to content

Commit e54f02e

Browse files
committed
Fix issues with inet type I/O
The inet codec is confusing the network prefix length with the network mask length, which causes incorrect wire encoding. This was masked by a symmetrical bug in the decoder. Fixes #37.
1 parent 57c9ffd commit e54f02e

File tree

2 files changed

+83
-17
lines changed

2 files changed

+83
-17
lines changed

asyncpg/protocol/codecs/network.pyx

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,37 @@ _ipaddr = ipaddress.ip_address
1818
_ipnet = ipaddress.ip_network
1919

2020

21-
cdef inline _net_encode(WriteBuffer buf, int32_t version, uint8_t bits,
21+
cdef inline uint8_t _ip_max_prefix_len(int32_t family):
22+
# Maximum number of bits in the network prefix of the specified
23+
# IP protocol version.
24+
if family == PGSQL_AF_INET:
25+
return 32
26+
else:
27+
return 128
28+
29+
30+
cdef inline int32_t _ip_addr_len(int32_t family):
31+
# Length of address in bytes for the specified IP protocol version.
32+
if family == PGSQL_AF_INET:
33+
return 4
34+
else:
35+
return 16
36+
37+
38+
cdef inline int8_t _ver_to_family(int32_t version):
39+
if version == 4:
40+
return PGSQL_AF_INET
41+
else:
42+
return PGSQL_AF_INET6
43+
44+
45+
cdef inline _net_encode(WriteBuffer buf, int8_t family, uint32_t bits,
2246
int8_t is_cidr, bytes addr):
2347

2448
cdef:
2549
char *addrbytes
2650
ssize_t addrlen
27-
int8_t family
2851

29-
family = PGSQL_AF_INET if version == 4 else PGSQL_AF_INET6
3052
cpython.PyBytes_AsStringAndSize(addr, &addrbytes, &addrlen)
3153

3254
buf.write_int32(4 + <int32_t>addrlen)
@@ -41,28 +63,31 @@ cdef net_decode(ConnectionSettings settings, FastReadBuffer buf):
4163
cdef:
4264
int32_t family = <int32_t>buf.read(1)[0]
4365
uint8_t bits = <uint8_t>buf.read(1)[0]
44-
uint32_t is_cidr = <uint32_t>buf.read(1)[0]
45-
uint32_t addrlen = <uint32_t>buf.read(1)[0]
66+
int32_t is_cidr = <int32_t>buf.read(1)[0]
67+
int32_t addrlen = <int32_t>buf.read(1)[0]
4668
bytes addr
69+
uint8_t max_prefix_len = _ip_max_prefix_len(family)
4770

4871
if family != PGSQL_AF_INET and family != PGSQL_AF_INET6:
4972
raise ValueError('invalid address family in "{}" value'.format(
5073
'cidr' if is_cidr else 'inet'
5174
))
5275

53-
if bits > (32 if family == PGSQL_AF_INET else 128):
54-
raise ValueError('invalid bits in "{}" value'.format(
76+
max_prefix_len = _ip_max_prefix_len(family)
77+
78+
if bits > max_prefix_len:
79+
raise ValueError('invalid network prefix length in "{}" value'.format(
5580
'cidr' if is_cidr else 'inet'
5681
))
5782

58-
if addrlen != (4 if family == PGSQL_AF_INET else 16):
59-
raise ValueError('invalid length in "{}" value'.format(
83+
if addrlen != _ip_addr_len(family):
84+
raise ValueError('invalid address length in "{}" value'.format(
6085
'cidr' if is_cidr else 'inet'
6186
))
6287

6388
addr = cpython.PyBytes_FromStringAndSize(buf.read(addrlen), addrlen)
6489

65-
if is_cidr or bits > 0:
90+
if is_cidr or bits != max_prefix_len:
6691
return _ipnet(addr).supernet(new_prefix=cpython.PyLong_FromLong(bits))
6792
else:
6893
return _ipaddr(addr)
@@ -71,15 +96,17 @@ cdef net_decode(ConnectionSettings settings, FastReadBuffer buf):
7196
cdef cidr_encode(ConnectionSettings settings, WriteBuffer buf, obj):
7297
cdef:
7398
object ipnet
99+
int8_t family
74100

75101
ipnet = _ipnet(obj)
76-
_net_encode(buf, ipnet.version, ipnet.prefixlen, 1,
77-
ipnet.network_address.packed)
102+
family = _ver_to_family(ipnet.version)
103+
_net_encode(buf, family, ipnet.prefixlen, 1, ipnet.network_address.packed)
78104

79105

80106
cdef inet_encode(ConnectionSettings settings, WriteBuffer buf, obj):
81107
cdef:
82108
object ipaddr
109+
int8_t family
83110

84111
try:
85112
ipaddr = _ipaddr(obj)
@@ -88,7 +115,8 @@ cdef inet_encode(ConnectionSettings settings, WriteBuffer buf, obj):
88115
# for the host datatype.
89116
cidr_encode(settings, buf, obj)
90117
else:
91-
_net_encode(buf, ipaddr.version, 0, 0, ipaddr.packed)
118+
family = _ver_to_family(ipaddr.version)
119+
_net_encode(buf, family, _ip_max_prefix_len(family), 0, ipaddr.packed)
92120

93121

94122
cdef init_network_codecs():

tests/test_codecs.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,32 @@ def _timezone(offset):
302302
output=ipaddress.IPv4Network('127.0.0.0/8')),
303303
dict(
304304
input='127.0.0.1/32',
305-
output=ipaddress.IPv4Network('127.0.0.1/32')),
305+
output=ipaddress.IPv4Address('127.0.0.1')),
306+
# Postgres appends /32 when casting to text explicitly, but
307+
# *not* in inet_out.
308+
dict(
309+
input='10.11.12.13',
310+
textoutput='10.11.12.13/32'
311+
),
312+
dict(
313+
input=ipaddress.IPv4Address('10.11.12.13'),
314+
textoutput='10.11.12.13/32'
315+
),
316+
dict(
317+
input=ipaddress.IPv4Network('10.11.12.13'),
318+
textoutput='10.11.12.13/32'
319+
),
320+
dict(
321+
textinput='10.11.12.13',
322+
output=ipaddress.IPv4Address('10.11.12.13'),
323+
),
324+
dict(
325+
# Non-zero address bits after the network prefix are permitted
326+
# by postgres, but are invalid in Python
327+
# (and zeroed out by supernet()).
328+
textinput='10.11.12.13/0',
329+
output=ipaddress.IPv4Network('0.0.0.0/0'),
330+
),
306331
]),
307332
('macaddr', 'macaddr', [
308333
'00:00:00:00:00:00',
@@ -369,20 +394,33 @@ async def test_standard_codecs(self):
369394
"SELECT $1::" + typname
370395
)
371396

372-
textst = await self.con.prepare(
397+
text_in = await self.con.prepare(
373398
"SELECT $1::text::" + typname
374399
)
375400

401+
text_out = await self.con.prepare(
402+
"SELECT $1::" + typname + "::text"
403+
)
404+
376405
for sample in sample_data:
377406
with self.subTest(sample=sample, typname=typname):
378407
stmt = st
379408
if isinstance(sample, dict):
380409
if 'textinput' in sample:
381410
inputval = sample['textinput']
382-
stmt = textst
411+
stmt = text_in
383412
else:
384413
inputval = sample['input']
385-
outputval = sample['output']
414+
415+
if 'textoutput' in sample:
416+
outputval = sample['textoutput']
417+
if stmt is text_in:
418+
raise ValueError(
419+
'cannot test "textin" and'
420+
' "textout" simultaneously')
421+
stmt = text_out
422+
else:
423+
outputval = sample['output']
386424
else:
387425
inputval = outputval = sample
388426

0 commit comments

Comments
 (0)