diff --git a/asyncpg/protocol/codecs/network.pyx b/asyncpg/protocol/codecs/network.pyx index 509a4e0f..fb9d23d0 100644 --- a/asyncpg/protocol/codecs/network.pyx +++ b/asyncpg/protocol/codecs/network.pyx @@ -18,15 +18,37 @@ _ipaddr = ipaddress.ip_address _ipnet = ipaddress.ip_network -cdef inline _net_encode(WriteBuffer buf, int32_t version, uint8_t bits, +cdef inline uint8_t _ip_max_prefix_len(int32_t family): + # Maximum number of bits in the network prefix of the specified + # IP protocol version. + if family == PGSQL_AF_INET: + return 32 + else: + return 128 + + +cdef inline int32_t _ip_addr_len(int32_t family): + # Length of address in bytes for the specified IP protocol version. + if family == PGSQL_AF_INET: + return 4 + else: + return 16 + + +cdef inline int8_t _ver_to_family(int32_t version): + if version == 4: + return PGSQL_AF_INET + else: + return PGSQL_AF_INET6 + + +cdef inline _net_encode(WriteBuffer buf, int8_t family, uint32_t bits, int8_t is_cidr, bytes addr): cdef: char *addrbytes ssize_t addrlen - int8_t family - family = PGSQL_AF_INET if version == 4 else PGSQL_AF_INET6 cpython.PyBytes_AsStringAndSize(addr, &addrbytes, &addrlen) buf.write_int32(4 + addrlen) @@ -41,28 +63,31 @@ cdef net_decode(ConnectionSettings settings, FastReadBuffer buf): cdef: int32_t family = buf.read(1)[0] uint8_t bits = buf.read(1)[0] - uint32_t is_cidr = buf.read(1)[0] - uint32_t addrlen = buf.read(1)[0] + int32_t is_cidr = buf.read(1)[0] + int32_t addrlen = buf.read(1)[0] bytes addr + uint8_t max_prefix_len = _ip_max_prefix_len(family) if family != PGSQL_AF_INET and family != PGSQL_AF_INET6: raise ValueError('invalid address family in "{}" value'.format( 'cidr' if is_cidr else 'inet' )) - if bits > (32 if family == PGSQL_AF_INET else 128): - raise ValueError('invalid bits in "{}" value'.format( + max_prefix_len = _ip_max_prefix_len(family) + + if bits > max_prefix_len: + raise ValueError('invalid network prefix length in "{}" value'.format( 'cidr' if is_cidr else 'inet' )) - if addrlen != (4 if family == PGSQL_AF_INET else 16): - raise ValueError('invalid length in "{}" value'.format( + if addrlen != _ip_addr_len(family): + raise ValueError('invalid address length in "{}" value'.format( 'cidr' if is_cidr else 'inet' )) addr = cpython.PyBytes_FromStringAndSize(buf.read(addrlen), addrlen) - if is_cidr or bits > 0: + if is_cidr or bits != max_prefix_len: return _ipnet(addr).supernet(new_prefix=cpython.PyLong_FromLong(bits)) else: return _ipaddr(addr) @@ -71,15 +96,17 @@ cdef net_decode(ConnectionSettings settings, FastReadBuffer buf): cdef cidr_encode(ConnectionSettings settings, WriteBuffer buf, obj): cdef: object ipnet + int8_t family ipnet = _ipnet(obj) - _net_encode(buf, ipnet.version, ipnet.prefixlen, 1, - ipnet.network_address.packed) + family = _ver_to_family(ipnet.version) + _net_encode(buf, family, ipnet.prefixlen, 1, ipnet.network_address.packed) cdef inet_encode(ConnectionSettings settings, WriteBuffer buf, obj): cdef: object ipaddr + int8_t family try: ipaddr = _ipaddr(obj) @@ -88,7 +115,8 @@ cdef inet_encode(ConnectionSettings settings, WriteBuffer buf, obj): # for the host datatype. cidr_encode(settings, buf, obj) else: - _net_encode(buf, ipaddr.version, 0, 0, ipaddr.packed) + family = _ver_to_family(ipaddr.version) + _net_encode(buf, family, _ip_max_prefix_len(family), 0, ipaddr.packed) cdef init_network_codecs(): diff --git a/tests/test_codecs.py b/tests/test_codecs.py index 9542147b..f4bc0eea 100644 --- a/tests/test_codecs.py +++ b/tests/test_codecs.py @@ -302,7 +302,32 @@ def _timezone(offset): output=ipaddress.IPv4Network('127.0.0.0/8')), dict( input='127.0.0.1/32', - output=ipaddress.IPv4Network('127.0.0.1/32')), + output=ipaddress.IPv4Address('127.0.0.1')), + # Postgres appends /32 when casting to text explicitly, but + # *not* in inet_out. + dict( + input='10.11.12.13', + textoutput='10.11.12.13/32' + ), + dict( + input=ipaddress.IPv4Address('10.11.12.13'), + textoutput='10.11.12.13/32' + ), + dict( + input=ipaddress.IPv4Network('10.11.12.13'), + textoutput='10.11.12.13/32' + ), + dict( + textinput='10.11.12.13', + output=ipaddress.IPv4Address('10.11.12.13'), + ), + dict( + # Non-zero address bits after the network prefix are permitted + # by postgres, but are invalid in Python + # (and zeroed out by supernet()). + textinput='10.11.12.13/0', + output=ipaddress.IPv4Network('0.0.0.0/0'), + ), ]), ('macaddr', 'macaddr', [ '00:00:00:00:00:00', @@ -369,20 +394,33 @@ async def test_standard_codecs(self): "SELECT $1::" + typname ) - textst = await self.con.prepare( + text_in = await self.con.prepare( "SELECT $1::text::" + typname ) + text_out = await self.con.prepare( + "SELECT $1::" + typname + "::text" + ) + for sample in sample_data: with self.subTest(sample=sample, typname=typname): stmt = st if isinstance(sample, dict): if 'textinput' in sample: inputval = sample['textinput'] - stmt = textst + stmt = text_in else: inputval = sample['input'] - outputval = sample['output'] + + if 'textoutput' in sample: + outputval = sample['textoutput'] + if stmt is text_in: + raise ValueError( + 'cannot test "textin" and' + ' "textout" simultaneously') + stmt = text_out + else: + outputval = sample['output'] else: inputval = outputval = sample