Skip to content

Commit c4c43d0

Browse files
jorisvandenbossche authored and proost committed
ENH: Support arrow/parquet roundtrip for nullable integer / string extension dtypes (pandas-dev#29483)
* Add __from_arrow__ support for IntegerArray, StringArray
1 parent b8bdd73 commit c4c43d0

File tree

8 files changed

+133
-4
lines changed

8 files changed

+133
-4
lines changed

doc/source/development/extending.rst

+42
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,48 @@ To use a test, subclass it:
251251
See https://github.com/pandas-dev/pandas/blob/master/pandas/tests/extension/base/__init__.py
252252
for a list of all the tests available.
253253

254+
.. _extending.extension.arrow:
255+
256+
Compatibility with Apache Arrow
257+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
258+
259+
An ``ExtensionArray`` can support conversion to / from ``pyarrow`` arrays
260+
(and thus support for example serialization to the Parquet file format)
261+
by implementing two methods: ``ExtensionArray.__arrow_array__`` and
262+
``ExtensionDtype.__from_arrow__``.
263+
264+
The ``ExtensionArray.__arrow_array__`` ensures that ``pyarrow`` knows how
265+
to convert the specific extension array into a ``pyarrow.Array`` (also when
266+
included as a column in a pandas DataFrame):
267+
268+
.. code-block:: python
269+
270+
class MyExtensionArray(ExtensionArray):
271+
...
272+
273+
def __arrow_array__(self, type=None):
274+
# convert the underlying array values to a pyarrow Array
275+
import pyarrow
276+
return pyarrow.array(..., type=type)
277+
278+
The ``ExtensionDtype.__from_arrow__`` method then controls the conversion
279+
back from pyarrow to a pandas ExtensionArray. This method receives a pyarrow
280+
``Array`` or ``ChunkedArray`` as only argument and is expected to return the
281+
appropriate pandas ``ExtensionArray`` for this dtype and the passed values:
282+
283+
.. code-block:: none
284+
285+
class ExtensionDtype:
286+
...
287+
288+
def __from_arrow__(self, array: pyarrow.Array/ChunkedArray) -> ExtensionArray:
289+
...
290+
291+
See more in the `Arrow documentation <https://arrow.apache.org/docs/python/extending_types.html>`__.
292+
293+
Those methods have been implemented for the nullable integer and string extension
294+
dtypes included in pandas, and ensure roundtrip to pyarrow and the Parquet file format.
295+
254296
.. _extension dtype dtypes: https://github.com/pandas-dev/pandas/blob/master/pandas/core/dtypes/dtypes.py
255297
.. _extension dtype source: https://github.com/pandas-dev/pandas/blob/master/pandas/core/dtypes/base.py
256298
.. _extension array source: https://github.com/pandas-dev/pandas/blob/master/pandas/core/arrays/base.py

doc/source/user_guide/io.rst

+3
Original file line numberDiff line numberDiff line change
@@ -4716,6 +4716,9 @@ Several caveats.
47164716
* The ``pyarrow`` engine preserves the ``ordered`` flag of categorical dtypes with string types. ``fastparquet`` does not preserve the ``ordered`` flag.
47174717
* Non supported types include ``Period`` and actual Python object types. These will raise a helpful error message
47184718
on an attempt at serialization.
4719+
* The ``pyarrow`` engine preserves extension data types such as the nullable integer and string data
4720+
type (requiring pyarrow >= 1.0.0, and requiring the extension type to implement the needed protocols,
4721+
see the :ref:`extension types documentation <extending.extension.arrow>`).
47194722

47204723
You can specify an ``engine`` to direct the serialization. This can be one of ``pyarrow``, or ``fastparquet``, or ``auto``.
47214724
If the engine is NOT specified, then the ``pd.options.io.parquet.engine`` option is checked; if this is also ``auto``,

doc/source/whatsnew/v1.0.0.rst

+3
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,9 @@ Other enhancements
114114
- Added ``encoding`` argument to :meth:`DataFrame.to_string` for non-ascii text (:issue:`28766`)
115115
- Added ``encoding`` argument to :func:`DataFrame.to_html` for non-ascii text (:issue:`28663`)
116116
- :meth:`Styler.background_gradient` now accepts ``vmin`` and ``vmax`` arguments (:issue:`12145`)
117+
- Roundtripping DataFrames with nullable integer or string data types to parquet
118+
(:meth:`~DataFrame.to_parquet` / :func:`read_parquet`) using the `'pyarrow'` engine
119+
now preserve those data types with pyarrow >= 1.0.0 (:issue:`20612`).
117120

118121
Build Changes
119122
^^^^^^^^^^^^^

pandas/core/arrays/integer.py

+29
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,35 @@ def construct_array_type(cls):
8585
"""
8686
return IntegerArray
8787

88+
def __from_arrow__(self, array):
    """
    Construct IntegerArray from a pyarrow Array/ChunkedArray.

    Parameters
    ----------
    array : pyarrow.Array or pyarrow.ChunkedArray
        Arrow array of an integer type matching this dtype.

    Returns
    -------
    IntegerArray
    """
    import pyarrow

    if isinstance(array, pyarrow.Array):
        chunks = [array]
    else:
        # pyarrow.ChunkedArray
        chunks = array.chunks

    results = []
    for arr in chunks:
        buflist = arr.buffers()
        # buffer 1 holds the values; slice by the arrow offset so a
        # sliced array reads its own window of the shared buffer
        data = np.frombuffer(buflist[1], dtype=self.type)[
            arr.offset : arr.offset + len(arr)
        ]
        bitmask = buflist[0]
        if bitmask is not None:
            # decode the validity bitmap as booleans; pass arr.offset so a
            # sliced array reads the correct bits (the bitmap buffer is
            # shared and not itself sliced)
            mask = pyarrow.BooleanArray.from_buffers(
                pyarrow.bool_(), len(arr), [None, bitmask], offset=arr.offset
            )
            mask = np.asarray(mask)
        else:
            # no validity buffer -> all values are valid
            mask = np.ones(len(arr), dtype=bool)
        # IntegerArray wants a "missing" mask; arrow provides a "valid" one
        int_arr = IntegerArray(data.copy(), ~mask, copy=False)
        results.append(int_arr)

    if not results:
        # zero-chunk ChunkedArray: _concat_same_type cannot concatenate
        # an empty list, so build the empty array directly
        return IntegerArray(
            np.array([], dtype=self.type), np.array([], dtype=bool), copy=False
        )
    return IntegerArray._concat_same_type(results)
116+
88117

89118
def integer_array(values, dtype=None, copy=False):
90119
"""

pandas/core/arrays/string_.py

+18
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,24 @@ def construct_array_type(cls) -> "Type[StringArray]":
8585
def __repr__(self) -> str:
8686
return "StringDtype"
8787

88+
def __from_arrow__(self, array):
    """
    Construct StringArray from a pyarrow Array/ChunkedArray.

    Parameters
    ----------
    array : pyarrow.Array or pyarrow.ChunkedArray
        Arrow string array.

    Returns
    -------
    StringArray
    """
    import pyarrow

    if isinstance(array, pyarrow.Array):
        chunks = [array]
    else:
        # pyarrow.ChunkedArray
        chunks = array.chunks

    results = []
    for arr in chunks:
        # using _from_sequence to ensure None is converted to np.nan
        str_arr = StringArray._from_sequence(np.array(arr))
        results.append(str_arr)

    if not results:
        # zero-chunk ChunkedArray: _concat_same_type cannot concatenate
        # an empty list, so build the empty array directly
        return StringArray._from_sequence([])
    return StringArray._concat_same_type(results)
105+
88106

89107
class StringArray(PandasArray):
90108
"""

pandas/tests/arrays/string_/test_string.py

+16
Original file line numberDiff line numberDiff line change
@@ -171,3 +171,19 @@ def test_arrow_array():
171171
arr = pa.array(data)
172172
expected = pa.array(list(data), type=pa.string(), from_pandas=True)
173173
assert arr.equals(expected)
174+
175+
176+
@td.skip_if_no("pyarrow", min_version="0.15.1.dev")
def test_arrow_roundtrip():
    # roundtrip possible from arrow 1.0.0
    import pyarrow as pa

    data = pd.array(["a", "b", None], dtype="string")
    original = pd.DataFrame({"a": data})
    table = pa.table(original)
    assert table.field("a").type == "string"

    roundtripped = table.to_pandas()
    assert isinstance(roundtripped["a"].dtype, pd.StringDtype)
    tm.assert_frame_equal(roundtripped, original)
    # ensure the missing value is represented by NaN and not None
    assert np.isnan(roundtripped.loc[2, "a"])

pandas/tests/arrays/test_integer.py

+12
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,18 @@ def test_arrow_array(data):
829829
assert arr.equals(expected)
830830

831831

832+
@td.skip_if_no("pyarrow", min_version="0.15.1.dev")
def test_arrow_roundtrip(data):
    # roundtrip possible from arrow 1.0.0
    import pyarrow as pa

    original = pd.DataFrame({"a": data})
    table = pa.table(original)
    # the arrow schema should carry the underlying numpy integer type
    assert table.field("a").type == str(data.dtype.numpy_dtype)

    roundtripped = table.to_pandas()
    tm.assert_frame_equal(roundtripped, original)
842+
843+
832844
@pytest.mark.parametrize(
833845
"pandasmethname, kwargs",
834846
[

pandas/tests/io/test_parquet.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -514,13 +514,19 @@ def test_additional_extension_arrays(self, pa):
514514
"b": pd.Series(["a", None, "c"], dtype="string"),
515515
}
516516
)
517-
# currently de-serialized as plain int / object
518-
expected = df.assign(a=df.a.astype("int64"), b=df.b.astype("object"))
517+
if LooseVersion(pyarrow.__version__) >= LooseVersion("0.15.1.dev"):
518+
expected = df
519+
else:
520+
# de-serialized as plain int / object
521+
expected = df.assign(a=df.a.astype("int64"), b=df.b.astype("object"))
519522
check_round_trip(df, pa, expected=expected)
520523

521524
df = pd.DataFrame({"a": pd.Series([1, 2, 3, None], dtype="Int64")})
522-
# if missing values in integer, currently de-serialized as float
523-
expected = df.assign(a=df.a.astype("float64"))
525+
if LooseVersion(pyarrow.__version__) >= LooseVersion("0.15.1.dev"):
526+
expected = df
527+
else:
528+
# if missing values in integer, currently de-serialized as float
529+
expected = df.assign(a=df.a.astype("float64"))
524530
check_round_trip(df, pa, expected=expected)
525531

526532

0 commit comments

Comments
 (0)