Skip to content

Commit 1d9c6fc

Browse files
douglasdavismroeschke
authored andcommitted
ENH: add ExtensionArray._explode method; adjust pyarrow extension for use of new interface (pandas-dev#54834)
* add ExtensionArray._explode method; adjust pyarrow extension for use * black * add to whatsnew 2.1.0 * pre-commit fix * add _explode to docs * Update pandas/core/arrays/arrow/array.py Co-authored-by: Matthew Roeschke <[email protected]> * switch whatsnew files * adjust docstring * fix docstring --------- Co-authored-by: Matthew Roeschke <[email protected]>
1 parent 388de53 commit 1d9c6fc

File tree

6 files changed

+54
-5
lines changed

6 files changed

+54
-5
lines changed

doc/source/reference/extensions.rst

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ objects.
3434
3535
api.extensions.ExtensionArray._accumulate
3636
api.extensions.ExtensionArray._concat_same_type
37+
api.extensions.ExtensionArray._explode
3738
api.extensions.ExtensionArray._formatter
3839
api.extensions.ExtensionArray._from_factorized
3940
api.extensions.ExtensionArray._from_sequence

doc/source/whatsnew/v2.2.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ enhancement2
7373

7474
Other enhancements
7575
^^^^^^^^^^^^^^^^^^
76+
- :meth:`ExtensionArray._explode` interface method added to allow extension type implementations of the ``explode`` method (:issue:`54833`)
7677
- DataFrame.apply now allows the usage of numba (via ``engine="numba"``) to JIT compile the passed function, allowing for potential speedups (:issue:`54666`)
7778
-
7879

pandas/core/arrays/arrow/array.py

+4
Original file line numberDiff line numberDiff line change
@@ -1609,6 +1609,10 @@ def _explode(self):
16091609
"""
16101610
See Series.explode.__doc__.
16111611
"""
1612+
# child class explode method supports only list types; return
1613+
# default implementation for non list types.
1614+
if not pa.types.is_list(self.dtype.pyarrow_dtype):
1615+
return super()._explode()
16121616
values = self
16131617
counts = pa.compute.list_value_length(values._pa_array)
16141618
counts = counts.fill_null(1).to_numpy()

pandas/core/arrays/base.py

+36
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ class ExtensionArray:
142142
view
143143
_accumulate
144144
_concat_same_type
145+
_explode
145146
_formatter
146147
_from_factorized
147148
_from_sequence
@@ -1924,6 +1925,41 @@ def _hash_pandas_object(
19241925
values, encoding=encoding, hash_key=hash_key, categorize=categorize
19251926
)
19261927

1928+
def _explode(self) -> tuple[Self, npt.NDArray[np.uint64]]:
1929+
"""
1930+
Transform each element of list-like to a row.
1931+
1932+
For arrays that do not contain list-like elements the default
1933+
implementation of this method just returns a copy and an array
1934+
of ones (unchanged index).
1935+
1936+
Returns
1937+
-------
1938+
ExtensionArray
1939+
Array with the exploded values.
1940+
np.ndarray[uint64]
1941+
The original lengths of each list-like for determining the
1942+
resulting index.
1943+
1944+
See Also
1945+
--------
1946+
Series.explode : The method on the ``Series`` object that this
1947+
extension array method is meant to support.
1948+
1949+
Examples
1950+
--------
1951+
>>> import pyarrow as pa
1952+
>>> a = pd.array([[1, 2, 3], [4], [5, 6]],
1953+
... dtype=pd.ArrowDtype(pa.list_(pa.int64())))
1954+
>>> a._explode()
1955+
(<ArrowExtensionArray>
1956+
[1, 2, 3, 4, 5, 6]
1957+
Length: 6, dtype: int64[pyarrow], array([3, 1, 2], dtype=int32))
1958+
"""
1959+
values = self.copy()
1960+
counts = np.ones(shape=(len(self),), dtype=np.uint64)
1961+
return values, counts
1962+
19271963
def tolist(self) -> list:
19281964
"""
19291965
Return a list of the values.

pandas/core/series.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,7 @@
7676
pandas_dtype,
7777
validate_all_hashable,
7878
)
79-
from pandas.core.dtypes.dtypes import (
80-
ArrowDtype,
81-
ExtensionDtype,
82-
)
79+
from pandas.core.dtypes.dtypes import ExtensionDtype
8380
from pandas.core.dtypes.generic import ABCDataFrame
8481
from pandas.core.dtypes.inference import is_hashable
8582
from pandas.core.dtypes.missing import (
@@ -4390,7 +4387,7 @@ def explode(self, ignore_index: bool = False) -> Series:
43904387
3 4
43914388
dtype: object
43924389
"""
4393-
if isinstance(self.dtype, ArrowDtype) and self.dtype.type == list:
4390+
if isinstance(self.dtype, ExtensionDtype):
43944391
values, counts = self._values._explode()
43954392
elif len(self) and is_object_dtype(self.dtype):
43964393
values, counts = reshape.explode(np.asarray(self._values))

pandas/tests/series/methods/test_explode.py

+10
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,13 @@ def test_explode_pyarrow_list_type(ignore_index):
163163
dtype=pd.ArrowDtype(pa.int64()),
164164
)
165165
tm.assert_series_equal(result, expected)
166+
167+
168+
@pytest.mark.parametrize("ignore_index", [True, False])
169+
def test_explode_pyarrow_non_list_type(ignore_index):
170+
pa = pytest.importorskip("pyarrow")
171+
data = [1, 2, 3]
172+
ser = pd.Series(data, dtype=pd.ArrowDtype(pa.int64()))
173+
result = ser.explode(ignore_index=ignore_index)
174+
expected = pd.Series([1, 2, 3], dtype="int64[pyarrow]", index=[0, 1, 2])
175+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)