Commit f2639df

José Duarte authored
ENH: allow JSON (de)serialization of ExtensionDtypes (#44722)

1 parent bd5ecc3 · commit f2639df

11 files changed: +484 -11 lines changed


doc/source/development/developer.rst (+1 -1)

@@ -180,7 +180,7 @@ As an example of fully-formed metadata:
      'numpy_type': 'int64',
      'metadata': None}
     ],
-    'pandas_version': '0.20.0',
+    'pandas_version': '1.4.0',
     'creator': {
         'library': 'pyarrow',
         'version': '0.13.0'

doc/source/user_guide/io.rst (+11 -1)

@@ -1903,6 +1903,7 @@ with optional parameters:
   ``index``; dict like {index -> {column -> value}}
   ``columns``; dict like {column -> {index -> value}}
   ``values``; just the values array
+  ``table``; adhering to the JSON `Table Schema`_

 * ``date_format`` : string, type of date conversion, 'epoch' for timestamp, 'iso' for ISO8601.
 * ``double_precision`` : The number of decimal places to use when encoding floating point values, default 10.

@@ -2477,7 +2478,6 @@ A few notes on the generated table schema:
 * For ``MultiIndex``, ``mi.names`` is used. If any level has no name,
   then ``level_<i>`` is used.

-
 ``read_json`` also accepts ``orient='table'`` as an argument. This allows for
 the preservation of metadata such as dtypes and index names in a
 round-trippable manner.

@@ -2519,8 +2519,18 @@ indicate missing values and the subsequent read cannot distinguish the intent.

     os.remove("test.json")

+When using ``orient='table'`` along with a user-defined ``ExtensionArray``,
+the generated schema will contain an additional ``extDtype`` key in the respective
+``fields`` element. This extra key is not standard but does enable JSON roundtrips
+for extension types (e.g. ``read_json(df.to_json(orient="table"), orient="table")``).
+
+The ``extDtype`` key carries the name of the extension. If you have properly registered
+the ``ExtensionDtype``, pandas will use that name to look it up in the registry
+and re-convert the serialized data into your custom dtype.
+
 .. _Table Schema: https://specs.frictionlessdata.io/table-schema/

+
 HTML
 ----
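For illustration, a minimal sketch of the round trip described in the added documentation, using the test-only ``DateDtype``/``DateArray`` that this commit adds under ``pandas/tests/extension/date`` (the column name and variables below are illustrative only; other registered extension dtypes that map to the generic ``"any"`` table type behave similarly):

# Sketch only: assumes the test-only DateDtype/DateArray added in this commit
# are importable (i.e. pandas is installed with its test modules).
import datetime as dt
import json

import pandas as pd
from pandas.tests.extension.date import DateArray

df = pd.DataFrame({"when": DateArray([dt.date(2021, 10, 1), dt.date(2021, 10, 2)])})
payload = df.to_json(orient="table")

# The schema field for "when" carries the non-standard "extDtype" key.
fields = json.loads(payload)["schema"]["fields"]
print([f for f in fields if f["name"] == "when"])
# e.g. [{'name': 'when', 'type': 'any', 'extDtype': 'DateDtype'}]

# Reading it back looks "DateDtype" up in the dtype registry.
restored = pd.read_json(payload, orient="table")
print(restored["when"].dtype)  # DateDtype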

doc/source/whatsnew/v1.4.0.rst (+1)

@@ -231,6 +231,7 @@ Other enhancements
 - :meth:`UInt64Index.map` now retains ``dtype`` where possible (:issue:`44609`)
 - :meth:`read_json` can now parse unsigned long long integers (:issue:`26068`)
 - :meth:`DataFrame.take` now raises a ``TypeError`` when passed a scalar for the indexer (:issue:`42875`)
+- :class:`ExtensionDtype` and :class:`ExtensionArray` are now (de)serialized when exporting a :class:`DataFrame` with :meth:`DataFrame.to_json` using ``orient='table'`` (:issue:`20612`, :issue:`44705`)
 -
pandas/core/generic.py (+1 -1)

@@ -2568,7 +2568,7 @@ def to_json(
             "primaryKey": [
                 "index"
             ],
-            "pandas_version": "0.20.0"
+            "pandas_version": "1.4.0"
         }},
         "data": [
         {{

pandas/io/json/_json.py (+1 -3)

@@ -68,8 +68,6 @@
 loads = json.loads
 dumps = json.dumps

-TABLE_SCHEMA_VERSION = "0.20.0"
-

 # interface to/from
 def to_json(

@@ -565,7 +563,7 @@ def read_json(
     {{"name":"col 1","type":"string"}},\
     {{"name":"col 2","type":"string"}}],\
     "primaryKey":["index"],\
-    "pandas_version":"0.20.0"}},\
+    "pandas_version":"1.4.0"}},\
     "data":[\
     {{"index":"row 1","col 1":"a","col 2":"b"}},\
     {{"index":"row 2","col 1":"c","col 2":"d"}}]\

pandas/io/json/_table_schema.py (+12 -2)

@@ -18,11 +18,13 @@
     JSONSerializable,
 )

+from pandas.core.dtypes.base import _registry as registry
 from pandas.core.dtypes.common import (
     is_bool_dtype,
     is_categorical_dtype,
     is_datetime64_dtype,
     is_datetime64tz_dtype,
+    is_extension_array_dtype,
     is_integer_dtype,
     is_numeric_dtype,
     is_period_dtype,

@@ -40,6 +42,8 @@
 loads = json.loads

+TABLE_SCHEMA_VERSION = "1.4.0"
+

 def as_json_table_type(x: DtypeObj) -> str:
     """

@@ -83,6 +87,8 @@ def as_json_table_type(x: DtypeObj) -> str:
         return "duration"
     elif is_categorical_dtype(x):
         return "any"
+    elif is_extension_array_dtype(x):
+        return "any"
     elif is_string_dtype(x):
         return "string"
     else:

@@ -130,6 +136,8 @@ def convert_pandas_type_to_json_field(arr):
         field["freq"] = dtype.freq.freqstr
     elif is_datetime64tz_dtype(dtype):
         field["tz"] = dtype.tz.zone
+    elif is_extension_array_dtype(dtype):
+        field["extDtype"] = dtype.name
     return field


@@ -195,6 +203,8 @@ def convert_json_field_to_pandas_type(field):
         return CategoricalDtype(
             categories=field["constraints"]["enum"], ordered=field["ordered"]
         )
+    elif "extDtype" in field:
+        return registry.find(field["extDtype"])
     else:
         return "object"

@@ -253,7 +263,7 @@ def build_table_schema(
         {'name': 'B', 'type': 'string'}, \
         {'name': 'C', 'type': 'datetime'}], \
         'primaryKey': ['idx'], \
-        'pandas_version': '0.20.0'}
+        'pandas_version': '1.4.0'}
     """
     if index is True:
         data = set_default_names(data)

@@ -287,7 +297,7 @@ def build_table_schema(
         schema["primaryKey"] = primary_key

     if version:
-        schema["pandas_version"] = "0.20.0"
+        schema["pandas_version"] = TABLE_SCHEMA_VERSION
     return schema
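To make the wiring concrete, a short sketch of the two private helpers changed above (internal API, not a public contract): the field builder records ``dtype.name`` under ``extDtype``, and the reverse conversion resolves that name through the registered-dtype lookup added here.

# Sketch only: relies on pandas-internal helpers and the test-only DateDtype
# registered by this commit; the column name is illustrative.
import datetime as dt

import pandas as pd
from pandas.io.json._table_schema import (
    convert_json_field_to_pandas_type,
    convert_pandas_type_to_json_field,
)
from pandas.tests.extension.date import DateArray

ser = pd.Series(DateArray([dt.date(2021, 12, 1)]), name="when")
print(convert_pandas_type_to_json_field(ser))
# e.g. {'name': 'when', 'type': 'any', 'extDtype': 'DateDtype'}

field = {"name": "when", "type": "any", "extDtype": "DateDtype"}
print(convert_json_field_to_pandas_type(field))
# DateDtype, resolved via the registry lookup added above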

pandas/tests/extension/date/__init__.py (new file, +6)

from pandas.tests.extension.date.array import (
    DateArray,
    DateDtype,
)

__all__ = ["DateArray", "DateDtype"]

pandas/tests/extension/date/array.py (new file, +180)

import datetime as dt
from typing import (
    Any,
    Optional,
    Sequence,
    Tuple,
    Union,
    cast,
)

import numpy as np

from pandas._typing import (
    Dtype,
    PositionalIndexer,
)

from pandas.core.dtypes.dtypes import register_extension_dtype

from pandas.api.extensions import (
    ExtensionArray,
    ExtensionDtype,
)
from pandas.api.types import pandas_dtype


@register_extension_dtype
class DateDtype(ExtensionDtype):
    @property
    def type(self):
        return dt.date

    @property
    def name(self):
        return "DateDtype"

    @classmethod
    def construct_from_string(cls, string: str):
        if not isinstance(string, str):
            raise TypeError(
                f"'construct_from_string' expects a string, got {type(string)}"
            )

        if string == cls.__name__:
            return cls()
        else:
            raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'")

    @classmethod
    def construct_array_type(cls):
        return DateArray

    @property
    def na_value(self):
        return dt.date.min

    def __repr__(self) -> str:
        return self.name


class DateArray(ExtensionArray):
    def __init__(
        self,
        dates: Union[
            dt.date,
            Sequence[dt.date],
            Tuple[np.ndarray, np.ndarray, np.ndarray],
            np.ndarray,
        ],
    ) -> None:
        if isinstance(dates, dt.date):
            self._year = np.array([dates.year])
            self._month = np.array([dates.month])
            self._day = np.array([dates.day])
            return

        ldates = len(dates)
        if isinstance(dates, list):
            # pre-allocate the arrays since we know the size beforehand
            self._year = np.zeros(ldates, dtype=np.uint16)  # 65535 (0, 9999)
            self._month = np.zeros(ldates, dtype=np.uint8)  # 255 (1, 12)
            self._day = np.zeros(ldates, dtype=np.uint8)  # 255 (1, 31)
            # populate them
            for i, (y, m, d) in enumerate(
                map(lambda date: (date.year, date.month, date.day), dates)
            ):
                self._year[i] = y
                self._month[i] = m
                self._day[i] = d

        elif isinstance(dates, tuple):
            # only support triples
            if ldates != 3:
                raise ValueError("only triples are valid")
            # check if all elements have the same type
            if any(map(lambda x: not isinstance(x, np.ndarray), dates)):
                raise TypeError("invalid type")
            ly, lm, ld = (len(cast(np.ndarray, d)) for d in dates)
            if not ly == lm == ld:
                raise ValueError(
                    f"tuple members must have the same length: {(ly, lm, ld)}"
                )
            self._year = dates[0].astype(np.uint16)
            self._month = dates[1].astype(np.uint8)
            self._day = dates[2].astype(np.uint8)

        elif isinstance(dates, np.ndarray) and dates.dtype == "U10":
            self._year = np.zeros(ldates, dtype=np.uint16)  # 65535 (0, 9999)
            self._month = np.zeros(ldates, dtype=np.uint8)  # 255 (1, 12)
            self._day = np.zeros(ldates, dtype=np.uint8)  # 255 (1, 31)

            for (i,), (y, m, d) in np.ndenumerate(np.char.split(dates, sep="-")):
                self._year[i] = int(y)
                self._month[i] = int(m)
                self._day[i] = int(d)

        else:
            raise TypeError(f"{type(dates)} is not supported")

    @property
    def dtype(self) -> ExtensionDtype:
        return DateDtype()

    def astype(self, dtype, copy=True):
        dtype = pandas_dtype(dtype)

        if isinstance(dtype, DateDtype):
            data = self.copy() if copy else self
        else:
            data = self.to_numpy(dtype=dtype, copy=copy, na_value=dt.date.min)

        return data

    @property
    def nbytes(self) -> int:
        return self._year.nbytes + self._month.nbytes + self._day.nbytes

    def __len__(self) -> int:
        return len(self._year)  # all 3 arrays are enforced to have the same length

    def __getitem__(self, item: PositionalIndexer):
        if isinstance(item, int):
            return dt.date(self._year[item], self._month[item], self._day[item])
        else:
            raise NotImplementedError("only ints are supported as indexes")

    def __setitem__(self, key: Union[int, slice, np.ndarray], value: Any):
        if not isinstance(key, int):
            raise NotImplementedError("only ints are supported as indexes")

        if not isinstance(value, dt.date):
            raise TypeError("you can only set datetime.date types")

        self._year[key] = value.year
        self._month[key] = value.month
        self._day[key] = value.day

    def __repr__(self) -> str:
        return f"DateArray{list(zip(self._year, self._month, self._day))}"

    def copy(self) -> "DateArray":
        return DateArray((self._year.copy(), self._month.copy(), self._day.copy()))

    def isna(self) -> np.ndarray:
        return np.logical_and(
            np.logical_and(
                self._year == dt.date.min.year, self._month == dt.date.min.month
            ),
            self._day == dt.date.min.day,
        )

    @classmethod
    def _from_sequence(cls, scalars, *, dtype: Optional[Dtype] = None, copy=False):
        if isinstance(scalars, dt.date):
            pass
        elif isinstance(scalars, DateArray):
            pass
        elif isinstance(scalars, np.ndarray):
            scalars = scalars.astype("U10")  # 10 chars for yyyy-mm-dd
        return DateArray(scalars)
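A brief usage sketch of the constructor paths supported above (test-only class; the values are illustrative):

# Sketch only: exercises the main construction paths of the test-only
# DateArray defined in this commit.
import datetime as dt

import numpy as np

from pandas.tests.extension.date import DateArray

# list of datetime.date objects
DateArray([dt.date(2021, 1, 1), dt.date(2021, 1, 2)])

# triple of (year, month, day) numpy arrays of equal length
DateArray(
    (
        np.array([2021, 2021], dtype=np.uint16),
        np.array([1, 2], dtype=np.uint8),
        np.array([15, 28], dtype=np.uint8),
    )
)

# "U10" string array in yyyy-mm-dd form (the shape produced by a JSON round trip)
DateArray(np.array(["2021-01-15", "2021-02-28"], dtype="U10"))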

pandas/tests/extension/decimal/array.py (+5 -2)

@@ -67,8 +67,11 @@ class DecimalArray(OpsMixin, ExtensionScalarOpsMixin, ExtensionArray):

     def __init__(self, values, dtype=None, copy=False, context=None):
         for i, val in enumerate(values):
-            if is_float(val) and np.isnan(val):
-                values[i] = DecimalDtype.na_value
+            if is_float(val):
+                if np.isnan(val):
+                    values[i] = DecimalDtype.na_value
+                else:
+                    values[i] = DecimalDtype.type(val)
             elif not isinstance(val, decimal.Decimal):
                 raise TypeError("All values must be of type " + str(decimal.Decimal))
         values = np.asarray(values, dtype=object)
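A minimal sketch of the behavioural change above: non-NaN floats passed to the test fixture are now coerced to ``Decimal`` instead of raising a ``TypeError`` (test-only class, not public API):

# Sketch only: DecimalArray is a pandas test fixture.
import decimal

import numpy as np

from pandas.tests.extension.decimal import DecimalArray

arr = DecimalArray([decimal.Decimal("1.5"), 2.5, np.nan])
print(arr[1])  # Decimal('2.5') -- previously this float raised a TypeError
print(arr[2])  # Decimal('NaN') -- NaN still maps to DecimalDtype.na_value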
