Skip to content

Commit ee43ae0

Browse files
authored
ENH: Series.explode to support pyarrow-backed list types (pandas-dev#53602)
* ENH: Series.explode to support pyarrow-backed list types
* gh refs
* update test
1 parent abcd440 commit ee43ae0

File tree

4 files changed

+56
-10
lines changed

4 files changed

+56
-10
lines changed

doc/source/whatsnew/v2.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ Other enhancements
102102
- :meth:`DataFrame.stack` gained the ``sort`` keyword to dictate whether the resulting :class:`MultiIndex` levels are sorted (:issue:`15105`)
103103
- :meth:`DataFrame.unstack` gained the ``sort`` keyword to dictate whether the resulting :class:`MultiIndex` levels are sorted (:issue:`15105`)
104104
- :meth:`DataFrameGroupby.agg` and :meth:`DataFrameGroupby.transform` now support grouping by multiple keys when the index is not a :class:`MultiIndex` for ``engine="numba"`` (:issue:`53486`)
105+
- :meth:`Series.explode` now supports pyarrow-backed list types (:issue:`53602`)
105106
- :meth:`Series.str.join` now supports ``ArrowDtype(pa.string())`` (:issue:`53646`)
106107
- :meth:`SeriesGroupby.agg` and :meth:`DataFrameGroupby.agg` now support passing in multiple functions for ``engine="numba"`` (:issue:`53486`)
107108
- :meth:`SeriesGroupby.transform` and :meth:`DataFrameGroupby.transform` now support passing in a string as the function for ``engine="numba"`` (:issue:`53579`)

pandas/core/arrays/arrow/array.py

+24-6
Original file line numberDiff line numberDiff line change
@@ -347,9 +347,9 @@ def _box_pa(
347347
-------
348348
pa.Array or pa.ChunkedArray or pa.Scalar
349349
"""
350-
if is_list_like(value):
351-
return cls._box_pa_array(value, pa_type)
352-
return cls._box_pa_scalar(value, pa_type)
350+
if isinstance(value, pa.Scalar) or not is_list_like(value):
351+
return cls._box_pa_scalar(value, pa_type)
352+
return cls._box_pa_array(value, pa_type)
353353

354354
@classmethod
355355
def _box_pa_scalar(cls, value, pa_type: pa.DataType | None = None) -> pa.Scalar:
@@ -1549,6 +1549,24 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
15491549

15501550
return result.as_py()
15511551

1552+
def _explode(self):
1553+
"""
1554+
See Series.explode.__doc__.
1555+
"""
1556+
values = self
1557+
counts = pa.compute.list_value_length(values._pa_array)
1558+
counts = counts.fill_null(1).to_numpy()
1559+
fill_value = pa.scalar([None], type=self._pa_array.type)
1560+
mask = counts == 0
1561+
if mask.any():
1562+
values = values.copy()
1563+
values[mask] = fill_value
1564+
counts = counts.copy()
1565+
counts[mask] = 1
1566+
values = values.fillna(fill_value)
1567+
values = type(self)(pa.compute.list_flatten(values._pa_array))
1568+
return values, counts
1569+
15521570
def __setitem__(self, key, value) -> None:
15531571
"""Set one or more values inplace.
15541572
@@ -1591,10 +1609,10 @@ def __setitem__(self, key, value) -> None:
15911609
raise IndexError(
15921610
f"index {key} is out of bounds for axis 0 with size {n}"
15931611
)
1594-
if is_list_like(value):
1595-
raise ValueError("Length of indexer and values mismatch")
1596-
elif isinstance(value, pa.Scalar):
1612+
if isinstance(value, pa.Scalar):
15971613
value = value.as_py()
1614+
elif is_list_like(value):
1615+
raise ValueError("Length of indexer and values mismatch")
15981616
chunks = [
15991617
*self._pa_array[:key].chunks,
16001618
pa.array([value], type=self._pa_array.type, from_pandas=True),

pandas/core/series.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,10 @@
7272
pandas_dtype,
7373
validate_all_hashable,
7474
)
75-
from pandas.core.dtypes.dtypes import ExtensionDtype
75+
from pandas.core.dtypes.dtypes import (
76+
ArrowDtype,
77+
ExtensionDtype,
78+
)
7679
from pandas.core.dtypes.generic import ABCDataFrame
7780
from pandas.core.dtypes.inference import is_hashable
7881
from pandas.core.dtypes.missing import (
@@ -4267,12 +4270,14 @@ def explode(self, ignore_index: bool = False) -> Series:
42674270
3 4
42684271
dtype: object
42694272
"""
4270-
if not len(self) or not is_object_dtype(self.dtype):
4273+
if isinstance(self.dtype, ArrowDtype) and self.dtype.type == list:
4274+
values, counts = self._values._explode()
4275+
elif len(self) and is_object_dtype(self.dtype):
4276+
values, counts = reshape.explode(np.asarray(self._values))
4277+
else:
42714278
result = self.copy()
42724279
return result.reset_index(drop=True) if ignore_index else result
42734280

4274-
values, counts = reshape.explode(np.asarray(self._values))
4275-
42764281
if ignore_index:
42774282
index = default_index(len(values))
42784283
else:

pandas/tests/series/methods/test_explode.py

+22
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,25 @@ def test_explode_scalars_can_ignore_index():
141141
result = s.explode(ignore_index=True)
142142
expected = pd.Series([1, 2, 3])
143143
tm.assert_series_equal(result, expected)
144+
145+
146+
@pytest.mark.parametrize("ignore_index", [True, False])
147+
def test_explode_pyarrow_list_type(ignore_index):
148+
# GH 53602
149+
pa = pytest.importorskip("pyarrow")
150+
151+
data = [
152+
[None, None],
153+
[1],
154+
[],
155+
[2, 3],
156+
None,
157+
]
158+
ser = pd.Series(data, dtype=pd.ArrowDtype(pa.list_(pa.int64())))
159+
result = ser.explode(ignore_index=ignore_index)
160+
expected = pd.Series(
161+
data=[None, None, 1, None, 2, 3, None],
162+
index=None if ignore_index else [0, 0, 1, 2, 3, 3, 4],
163+
dtype=pd.ArrowDtype(pa.int64()),
164+
)
165+
tm.assert_series_equal(result, expected)

0 commit comments

Comments (0)