Skip to content

Untangle custom codec confusion #662

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,15 @@ async def set_type_codec(self, typename, *,
.. versionchanged:: 0.13.0
The ``binary`` keyword argument was removed in favor of
``format``.

.. note::

It is recommended to use the ``'binary'`` or ``'tuple'`` *format*
whenever possible and if the underlying type supports it. Asyncpg
currently does not support text I/O for composite and range types,
and some other functionality, such as
:meth:`Connection.copy_to_table`, does not support types with text
codecs.
"""
self._check_open()
typeinfo = await self._introspect_type(typename, schema)
Expand Down
31 changes: 10 additions & 21 deletions asyncpg/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,9 @@

ELSE NULL
END) AS basetype,
t.typreceive::oid != 0 AND t.typsend::oid != 0
AS has_bin_io,
t.typelem AS elemtype,
elem_t.typdelim AS elemdelim,
range_t.rngsubtype AS range_subtype,
(CASE WHEN t.typtype = 'r' THEN
(SELECT
range_elem_t.typreceive::oid != 0 AND
range_elem_t.typsend::oid != 0
FROM
pg_catalog.pg_type AS range_elem_t
WHERE
range_elem_t.oid = range_t.rngsubtype)
ELSE
elem_t.typreceive::oid != 0 AND
elem_t.typsend::oid != 0
END) AS elem_has_bin_io,
(CASE WHEN t.typtype = 'c' THEN
(SELECT
array_agg(ia.atttypid ORDER BY ia.attnum)
Expand Down Expand Up @@ -98,12 +84,12 @@

INTRO_LOOKUP_TYPES = '''\
WITH RECURSIVE typeinfo_tree(
oid, ns, name, kind, basetype, has_bin_io, elemtype, elemdelim,
range_subtype, elem_has_bin_io, attrtypoids, attrnames, depth)
oid, ns, name, kind, basetype, elemtype, elemdelim,
range_subtype, attrtypoids, attrnames, depth)
AS (
SELECT
ti.oid, ti.ns, ti.name, ti.kind, ti.basetype, ti.has_bin_io,
ti.elemtype, ti.elemdelim, ti.range_subtype, ti.elem_has_bin_io,
ti.oid, ti.ns, ti.name, ti.kind, ti.basetype,
ti.elemtype, ti.elemdelim, ti.range_subtype,
ti.attrtypoids, ti.attrnames, 0
FROM
{typeinfo} AS ti
Expand All @@ -113,8 +99,8 @@
UNION ALL

SELECT
ti.oid, ti.ns, ti.name, ti.kind, ti.basetype, ti.has_bin_io,
ti.elemtype, ti.elemdelim, ti.range_subtype, ti.elem_has_bin_io,
ti.oid, ti.ns, ti.name, ti.kind, ti.basetype,
ti.elemtype, ti.elemdelim, ti.range_subtype,
ti.attrtypoids, ti.attrnames, tt.depth + 1
FROM
{typeinfo} ti,
Expand All @@ -126,7 +112,10 @@
)

SELECT DISTINCT
*
*,
basetype::regtype::text AS basetype_name,
elemtype::regtype::text AS elemtype_name,
range_subtype::regtype::text AS range_subtype_name
FROM
typeinfo_tree
ORDER BY
Expand Down
3 changes: 2 additions & 1 deletion asyncpg/protocol/codecs/base.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -168,4 +168,5 @@ cdef class DataCodecConfig:

cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format,
bint ignore_custom_codec=*)
cdef inline Codec get_any_local_codec(self, uint32_t oid)
cdef inline Codec get_custom_codec(self, uint32_t oid,
ServerDataFormat format)
137 changes: 64 additions & 73 deletions asyncpg/protocol/codecs/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -440,14 +440,7 @@ cdef class DataCodecConfig:
for ti in types:
oid = ti['oid']

if not ti['has_bin_io']:
format = PG_FORMAT_TEXT
else:
format = PG_FORMAT_BINARY

has_text_elements = False

if self.get_codec(oid, format) is not None:
if self.get_codec(oid, PG_FORMAT_ANY) is not None:
continue

name = ti['name']
Expand All @@ -468,92 +461,79 @@ cdef class DataCodecConfig:
name = name[1:]
name = '{}[]'.format(name)

if ti['elem_has_bin_io']:
elem_format = PG_FORMAT_BINARY
else:
elem_format = PG_FORMAT_TEXT

elem_codec = self.get_codec(array_element_oid, elem_format)
elem_codec = self.get_codec(array_element_oid, PG_FORMAT_ANY)
if elem_codec is None:
elem_format = PG_FORMAT_TEXT
elem_codec = self.declare_fallback_codec(
array_element_oid, name, schema)
array_element_oid, ti['elemtype_name'], schema)

elem_delim = <Py_UCS4>ti['elemdelim'][0]

self._derived_type_codecs[oid, elem_format] = \
self._derived_type_codecs[oid, elem_codec.format] = \
Codec.new_array_codec(
oid, name, schema, elem_codec, elem_delim)

elif ti['kind'] == b'c':
# Composite type

if not comp_type_attrs:
raise exceptions.InternalClientError(
'type record missing field types for '
'composite {}'.format(oid))

# Composite type
f'type record missing field types for composite {oid}')

comp_elem_codecs = []
has_text_elements = False

for typoid in comp_type_attrs:
elem_codec = self.get_codec(typoid, PG_FORMAT_BINARY)
if elem_codec is None:
elem_codec = self.get_codec(typoid, PG_FORMAT_TEXT)
has_text_elements = True
elem_codec = self.get_codec(typoid, PG_FORMAT_ANY)
if elem_codec is None:
raise exceptions.InternalClientError(
'no codec for composite attribute type {}'.format(
typoid))
f'no codec for composite attribute type {typoid}')
if elem_codec.format is PG_FORMAT_TEXT:
has_text_elements = True
comp_elem_codecs.append(elem_codec)

element_names = collections.OrderedDict()
for i, attrname in enumerate(ti['attrnames']):
element_names[attrname] = i

# If at least one element is text-encoded, we must
# encode the whole composite as text.
if has_text_elements:
format = PG_FORMAT_TEXT
elem_format = PG_FORMAT_TEXT
else:
elem_format = PG_FORMAT_BINARY

self._derived_type_codecs[oid, format] = \
self._derived_type_codecs[oid, elem_format] = \
Codec.new_composite_codec(
oid, name, schema, format, comp_elem_codecs,
oid, name, schema, elem_format, comp_elem_codecs,
comp_type_attrs, element_names)

elif ti['kind'] == b'd':
# Domain type

if not base_type:
raise exceptions.InternalClientError(
'type record missing base type for domain {}'.format(
oid))
f'type record missing base type for domain {oid}')

elem_codec = self.get_codec(base_type, format)
elem_codec = self.get_codec(base_type, PG_FORMAT_ANY)
if elem_codec is None:
format = PG_FORMAT_TEXT
elem_codec = self.declare_fallback_codec(
base_type, name, schema)
base_type, ti['basetype_name'], schema)

self._derived_type_codecs[oid, format] = elem_codec
self._derived_type_codecs[oid, elem_codec.format] = elem_codec

elif ti['kind'] == b'r':
# Range type

if not range_subtype_oid:
raise exceptions.InternalClientError(
'type record missing base type for range {}'.format(
oid))
f'type record missing base type for range {oid}')

if ti['elem_has_bin_io']:
elem_format = PG_FORMAT_BINARY
else:
elem_format = PG_FORMAT_TEXT

elem_codec = self.get_codec(range_subtype_oid, elem_format)
elem_codec = self.get_codec(range_subtype_oid, PG_FORMAT_ANY)
if elem_codec is None:
elem_format = PG_FORMAT_TEXT
elem_codec = self.declare_fallback_codec(
range_subtype_oid, name, schema)
range_subtype_oid, ti['range_subtype_name'], schema)

self._derived_type_codecs[oid, elem_format] = \
self._derived_type_codecs[oid, elem_codec.format] = \
Codec.new_range_codec(oid, name, schema, elem_codec)

elif ti['kind'] == b'e':
Expand Down Expand Up @@ -665,10 +645,6 @@ cdef class DataCodecConfig:
def declare_fallback_codec(self, uint32_t oid, str name, str schema):
cdef Codec codec

codec = self.get_codec(oid, PG_FORMAT_TEXT)
if codec is not None:
return codec

if oid <= MAXBUILTINOID:
# This is a BKI type, for which asyncpg has no
# defined codec. This should only happen for newly
Expand Down Expand Up @@ -696,34 +672,49 @@ cdef class DataCodecConfig:
bint ignore_custom_codec=False):
cdef Codec codec

if not ignore_custom_codec:
codec = self.get_any_local_codec(oid)
if codec is not None:
if codec.format != format:
# The codec for this OID has been overridden by
# set_{builtin}_type_codec with a different format.
# We must respect that and not return a core codec.
return None
else:
return codec

codec = get_core_codec(oid, format)
if codec is not None:
if format == PG_FORMAT_ANY:
codec = self.get_codec(
oid, PG_FORMAT_BINARY, ignore_custom_codec)
if codec is None:
codec = self.get_codec(
oid, PG_FORMAT_TEXT, ignore_custom_codec)
return codec
else:
try:
return self._derived_type_codecs[oid, format]
except KeyError:
return None
if not ignore_custom_codec:
codec = self.get_custom_codec(oid, PG_FORMAT_ANY)
if codec is not None:
if codec.format != format:
# The codec for this OID has been overridden by
# set_{builtin}_type_codec with a different format.
# We must respect that and not return a core codec.
return None
else:
return codec

codec = get_core_codec(oid, format)
if codec is not None:
return codec
else:
try:
return self._derived_type_codecs[oid, format]
except KeyError:
return None

cdef inline Codec get_any_local_codec(self, uint32_t oid):
cdef inline Codec get_custom_codec(
self,
uint32_t oid,
ServerDataFormat format
):
cdef Codec codec

codec = self._custom_type_codecs.get((oid, PG_FORMAT_BINARY))
if codec is None:
return self._custom_type_codecs.get((oid, PG_FORMAT_TEXT))
if format == PG_FORMAT_ANY:
codec = self.get_custom_codec(oid, PG_FORMAT_BINARY)
if codec is None:
codec = self.get_custom_codec(oid, PG_FORMAT_TEXT)
else:
return codec
codec = self._custom_type_codecs.get((oid, format))

return codec


cdef inline Codec get_core_codec(
Expand Down
11 changes: 1 addition & 10 deletions asyncpg/protocol/settings.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -89,16 +89,7 @@ cdef class ConnectionSettings(pgproto.CodecContext):
cpdef inline Codec get_data_codec(self, uint32_t oid,
ServerDataFormat format=PG_FORMAT_ANY,
bint ignore_custom_codec=False):
if format == PG_FORMAT_ANY:
codec = self._data_codecs.get_codec(
oid, PG_FORMAT_BINARY, ignore_custom_codec)
if codec is None:
codec = self._data_codecs.get_codec(
oid, PG_FORMAT_TEXT, ignore_custom_codec)
return codec
else:
return self._data_codecs.get_codec(
oid, format, ignore_custom_codec)
return self._data_codecs.get_codec(oid, format, ignore_custom_codec)

def __getattr__(self, name):
if not name.startswith('_'):
Expand Down
37 changes: 37 additions & 0 deletions tests/test_codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1305,6 +1305,34 @@ async def test_custom_codec_on_enum(self):
finally:
await self.con.execute('DROP TYPE custom_codec_t')

async def test_custom_codec_on_enum_array(self):
"""Test encoding/decoding using a custom codec on an enum array.

Bug: https://github.com/MagicStack/asyncpg/issues/590
"""
await self.con.execute('''
CREATE TYPE custom_codec_t AS ENUM ('foo', 'bar', 'baz')
''')

try:
await self.con.set_type_codec(
'custom_codec_t',
encoder=lambda v: str(v).lstrip('enum :'),
decoder=lambda v: 'enum: ' + str(v))

v = await self.con.fetchval(
"SELECT ARRAY['foo', 'bar']::custom_codec_t[]")
self.assertEqual(v, ['enum: foo', 'enum: bar'])

v = await self.con.fetchval(
'SELECT ARRAY[$1]::custom_codec_t[]', 'foo')
self.assertEqual(v, ['enum: foo'])

v = await self.con.fetchval("SELECT 'foo'::custom_codec_t")
self.assertEqual(v, 'enum: foo')
finally:
await self.con.execute('DROP TYPE custom_codec_t')

async def test_custom_codec_override_binary(self):
"""Test overriding core codecs."""
import json
Expand Down Expand Up @@ -1350,6 +1378,14 @@ def _decoder(value):
res = await conn.fetchval('SELECT $1::json', data)
self.assertEqual(data, res)

res = await conn.fetchval('SELECT $1::json[]', [data])
self.assertEqual([data], res)

await conn.execute('CREATE DOMAIN my_json AS json')

res = await conn.fetchval('SELECT $1::my_json', data)
self.assertEqual(data, res)

def _encoder(value):
return value

Expand All @@ -1365,6 +1401,7 @@ def _decoder(value):
res = await conn.fetchval('SELECT $1::uuid', data)
self.assertEqual(res, data)
finally:
await conn.execute('DROP DOMAIN IF EXISTS my_json')
await conn.close()

async def test_custom_codec_override_tuple(self):
Expand Down