diff --git a/doc/source/whatsnew/v2.1.0.rst b/doc/source/whatsnew/v2.1.0.rst index bf1a7cd683a89..bea2ad8c7450c 100644 --- a/doc/source/whatsnew/v2.1.0.rst +++ b/doc/source/whatsnew/v2.1.0.rst @@ -102,6 +102,7 @@ Other enhancements - :meth:`DataFrame.stack` gained the ``sort`` keyword to dictate whether the resulting :class:`MultiIndex` levels are sorted (:issue:`15105`) - :meth:`DataFrame.unstack` gained the ``sort`` keyword to dictate whether the resulting :class:`MultiIndex` levels are sorted (:issue:`15105`) - :meth:`DataFrameGroupby.agg` and :meth:`DataFrameGroupby.transform` now support grouping by multiple keys when the index is not a :class:`MultiIndex` for ``engine="numba"`` (:issue:`53486`) +- :meth:`Series.explode` now supports pyarrow-backed list types (:issue:`53602`) - :meth:`Series.str.join` now supports ``ArrowDtype(pa.string())`` (:issue:`53646`) - :meth:`SeriesGroupby.agg` and :meth:`DataFrameGroupby.agg` now support passing in multiple functions for ``engine="numba"`` (:issue:`53486`) - :meth:`SeriesGroupby.transform` and :meth:`DataFrameGroupby.transform` now support passing in a string as the function for ``engine="numba"`` (:issue:`53579`) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 601b418296e7f..0c1b86440b11d 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -347,9 +347,9 @@ def _box_pa( ------- pa.Array or pa.ChunkedArray or pa.Scalar """ - if is_list_like(value): - return cls._box_pa_array(value, pa_type) - return cls._box_pa_scalar(value, pa_type) + if isinstance(value, pa.Scalar) or not is_list_like(value): + return cls._box_pa_scalar(value, pa_type) + return cls._box_pa_array(value, pa_type) @classmethod def _box_pa_scalar(cls, value, pa_type: pa.DataType | None = None) -> pa.Scalar: @@ -1549,6 +1549,24 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs): return result.as_py() + def _explode(self): + """ + See Series.explode.__doc__. + """ + values = self + counts = pa.compute.list_value_length(values._pa_array) + counts = counts.fill_null(1).to_numpy() + fill_value = pa.scalar([None], type=self._pa_array.type) + mask = counts == 0 + if mask.any(): + values = values.copy() + values[mask] = fill_value + counts = counts.copy() + counts[mask] = 1 + values = values.fillna(fill_value) + values = type(self)(pa.compute.list_flatten(values._pa_array)) + return values, counts + def __setitem__(self, key, value) -> None: """Set one or more values inplace. @@ -1591,10 +1609,10 @@ def __setitem__(self, key, value) -> None: raise IndexError( f"index {key} is out of bounds for axis 0 with size {n}" ) - if is_list_like(value): - raise ValueError("Length of indexer and values mismatch") - elif isinstance(value, pa.Scalar): + if isinstance(value, pa.Scalar): value = value.as_py() + elif is_list_like(value): + raise ValueError("Length of indexer and values mismatch") chunks = [ *self._pa_array[:key].chunks, pa.array([value], type=self._pa_array.type, from_pandas=True), diff --git a/pandas/core/series.py b/pandas/core/series.py index 9c7110cc21082..959c153561572 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -72,7 +72,10 @@ pandas_dtype, validate_all_hashable, ) -from pandas.core.dtypes.dtypes import ExtensionDtype +from pandas.core.dtypes.dtypes import ( + ArrowDtype, + ExtensionDtype, +) from pandas.core.dtypes.generic import ABCDataFrame from pandas.core.dtypes.inference import is_hashable from pandas.core.dtypes.missing import ( @@ -4267,12 +4270,14 @@ def explode(self, ignore_index: bool = False) -> Series: 3 4 dtype: object """ - if not len(self) or not is_object_dtype(self.dtype): + if isinstance(self.dtype, ArrowDtype) and self.dtype.type == list: + values, counts = self._values._explode() + elif len(self) and is_object_dtype(self.dtype): + values, counts = reshape.explode(np.asarray(self._values)) + else: result = self.copy() return result.reset_index(drop=True) if ignore_index else result - values, counts = reshape.explode(np.asarray(self._values)) - if ignore_index: index = default_index(len(values)) else: diff --git a/pandas/tests/series/methods/test_explode.py b/pandas/tests/series/methods/test_explode.py index 886152326cf3e..c8a9eb6f89fde 100644 --- a/pandas/tests/series/methods/test_explode.py +++ b/pandas/tests/series/methods/test_explode.py @@ -141,3 +141,25 @@ def test_explode_scalars_can_ignore_index(): result = s.explode(ignore_index=True) expected = pd.Series([1, 2, 3]) tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("ignore_index", [True, False]) +def test_explode_pyarrow_list_type(ignore_index): + # GH 53602 + pa = pytest.importorskip("pyarrow") + + data = [ + [None, None], + [1], + [], + [2, 3], + None, + ] + ser = pd.Series(data, dtype=pd.ArrowDtype(pa.list_(pa.int64()))) + result = ser.explode(ignore_index=ignore_index) + expected = pd.Series( + data=[None, None, 1, None, 2, 3, None], + index=None if ignore_index else [0, 0, 1, 2, 3, 3, 4], + dtype=pd.ArrowDtype(pa.int64()), + ) + tm.assert_series_equal(result, expected)