Skip to content

Commit 2c1d4bb

Browse files
rohanjain101Rohan Jain
and
Rohan Jain
authored
List accessor (#55777)
* initial implementation, no documentation * revert * non list test * docstring wip * add list accessor to series.rst * whatsnew * fix * fix typehint * private * fix docstring * fail on iter * list_slice only impl in pyarrow 11 * fix docstring? * fix * fix test * fix validation msg * fix * fix * remove private * maybe fix * one more remove --------- Co-authored-by: Rohan Jain <[email protected]>
1 parent 8e0411f commit 2c1d4bb

File tree

7 files changed

+389
-31
lines changed

7 files changed

+389
-31
lines changed

doc/source/reference/series.rst

+17
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,23 @@ Sparse-dtype specific methods and attributes are provided under the
526526
Series.sparse.to_coo
527527

528528

529+
.. _api.series.list:
530+
531+
List accessor
532+
~~~~~~~~~~~~~
533+
534+
Arrow list-dtype specific methods and attributes are provided under the
535+
``Series.list`` accessor.
536+
537+
.. autosummary::
538+
:toctree: api/
539+
:template: autosummary/accessor_method.rst
540+
541+
Series.list.flatten
542+
Series.list.len
543+
Series.list.__getitem__
544+
545+
529546
.. _api.series.struct:
530547

531548
Struct accessor

doc/source/whatsnew/v2.2.0.rst

+23-3
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,30 @@ DataFrame. (:issue:`54938`)
6464
)
6565
series.struct.explode()
6666
67-
.. _whatsnew_220.enhancements.enhancement2:
67+
.. _whatsnew_220.enhancements.list_accessor:
6868

69-
enhancement2
70-
^^^^^^^^^^^^
69+
Series.list accessor for PyArrow list data
70+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
71+
72+
The ``Series.list`` accessor provides attributes and methods for processing
73+
data with ``list[pyarrow]`` dtype Series. For example,
74+
:meth:`Series.list.__getitem__` allows indexing pyarrow lists in
75+
a Series. (:issue:`55323`)
76+
77+
.. ipython:: python
78+
79+
import pyarrow as pa
80+
series = pd.Series(
81+
[
82+
[1, 2, 3],
83+
[4, 5],
84+
[6],
85+
],
86+
dtype=pd.ArrowDtype(
87+
pa.list_(pa.int64())
88+
),
89+
)
90+
series.list[0]
7191
7292
.. _whatsnew_220.enhancements.other:
7393

pandas/core/arrays/arrow/__init__.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from pandas.core.arrays.arrow.accessors import StructAccessor
1+
from pandas.core.arrays.arrow.accessors import (
2+
ListAccessor,
3+
StructAccessor,
4+
)
25
from pandas.core.arrays.arrow.array import ArrowExtensionArray
36

4-
__all__ = ["ArrowExtensionArray", "StructAccessor"]
7+
__all__ = ["ArrowExtensionArray", "StructAccessor", "ListAccessor"]

pandas/core/arrays/arrow/accessors.py

+203-21
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,16 @@
22

33
from __future__ import annotations
44

5+
from abc import (
6+
ABCMeta,
7+
abstractmethod,
8+
)
59
from typing import TYPE_CHECKING
610

7-
from pandas.compat import pa_version_under10p1
11+
from pandas.compat import (
12+
pa_version_under10p1,
13+
pa_version_under11p0,
14+
)
815

916
if not pa_version_under10p1:
1017
import pyarrow as pa
@@ -13,13 +20,194 @@
1320
from pandas.core.dtypes.dtypes import ArrowDtype
1421

1522
if TYPE_CHECKING:
23+
from collections.abc import Iterator
24+
1625
from pandas import (
1726
DataFrame,
1827
Series,
1928
)
2029

2130

22-
class StructAccessor:
31+
class ArrowAccessor(metaclass=ABCMeta):
32+
@abstractmethod
33+
def __init__(self, data, validation_msg: str) -> None:
34+
self._data = data
35+
self._validation_msg = validation_msg
36+
self._validate(data)
37+
38+
@abstractmethod
39+
def _is_valid_pyarrow_dtype(self, pyarrow_dtype) -> bool:
40+
pass
41+
42+
def _validate(self, data):
43+
dtype = data.dtype
44+
if not isinstance(dtype, ArrowDtype):
45+
# Raise AttributeError so that inspect can handle Series without an Arrow dtype.
46+
raise AttributeError(self._validation_msg.format(dtype=dtype))
47+
48+
if not self._is_valid_pyarrow_dtype(dtype.pyarrow_dtype):
49+
# Raise AttributeError so that inspect can handle invalid Series.
50+
raise AttributeError(self._validation_msg.format(dtype=dtype))
51+
52+
@property
53+
def _pa_array(self):
54+
return self._data.array._pa_array
55+
56+
57+
class ListAccessor(ArrowAccessor):
58+
"""
59+
Accessor object for list data properties of the Series values.
60+
61+
Parameters
62+
----------
63+
data : Series
64+
Series containing Arrow list data.
65+
"""
66+
67+
def __init__(self, data=None) -> None:
68+
super().__init__(
69+
data,
70+
validation_msg="Can only use the '.list' accessor with "
71+
"'list[pyarrow]' dtype, not {dtype}.",
72+
)
73+
74+
def _is_valid_pyarrow_dtype(self, pyarrow_dtype) -> bool:
75+
return (
76+
pa.types.is_list(pyarrow_dtype)
77+
or pa.types.is_fixed_size_list(pyarrow_dtype)
78+
or pa.types.is_large_list(pyarrow_dtype)
79+
)
80+
81+
def len(self) -> Series:
82+
"""
83+
Return the length of each list in the Series.
84+
85+
Returns
86+
-------
87+
pandas.Series
88+
The length of each list.
89+
90+
Examples
91+
--------
92+
>>> import pyarrow as pa
93+
>>> s = pd.Series(
94+
... [
95+
... [1, 2, 3],
96+
... [3],
97+
... ],
98+
... dtype=pd.ArrowDtype(pa.list_(
99+
... pa.int64()
100+
... ))
101+
... )
102+
>>> s.list.len()
103+
0 3
104+
1 1
105+
dtype: int32[pyarrow]
106+
"""
107+
from pandas import Series
108+
109+
value_lengths = pc.list_value_length(self._pa_array)
110+
return Series(value_lengths, dtype=ArrowDtype(value_lengths.type))
111+
112+
def __getitem__(self, key: int | slice) -> Series:
113+
"""
114+
Index or slice lists in the Series.
115+
116+
Parameters
117+
----------
118+
key : int | slice
119+
Index or slice of indices to access from each list.
120+
121+
Returns
122+
-------
123+
pandas.Series
124+
The element at the requested index, or the requested slice, of each list.
125+
126+
Examples
127+
--------
128+
>>> import pyarrow as pa
129+
>>> s = pd.Series(
130+
... [
131+
... [1, 2, 3],
132+
... [3],
133+
... ],
134+
... dtype=pd.ArrowDtype(pa.list_(
135+
... pa.int64()
136+
... ))
137+
... )
138+
>>> s.list[0]
139+
0 1
140+
1 3
141+
dtype: int64[pyarrow]
142+
"""
143+
from pandas import Series
144+
145+
if isinstance(key, int):
146+
# TODO: Support negative key but pyarrow does not allow
147+
# element index to be an array.
148+
# if key < 0:
149+
# key = pc.add(key, pc.list_value_length(self._pa_array))
150+
element = pc.list_element(self._pa_array, key)
151+
return Series(element, dtype=ArrowDtype(element.type))
152+
elif isinstance(key, slice):
153+
if pa_version_under11p0:
154+
raise NotImplementedError(
155+
f"List slice not supported by pyarrow {pa.__version__}."
156+
)
157+
158+
# TODO: Support negative start/stop/step, ideally this would be added
159+
# upstream in pyarrow.
160+
start, stop, step = key.start, key.stop, key.step
161+
if start is None:
162+
# TODO: When adding negative step support
163+
# this should be set to the last element of the array
164+
# when step is negative.
165+
start = 0
166+
if step is None:
167+
step = 1
168+
sliced = pc.list_slice(self._pa_array, start, stop, step)
169+
return Series(sliced, dtype=ArrowDtype(sliced.type))
170+
else:
171+
raise ValueError(f"key must be an int or slice, got {type(key).__name__}")
172+
173+
def __iter__(self) -> Iterator:
174+
raise TypeError(f"'{type(self).__name__}' object is not iterable")
175+
176+
def flatten(self) -> Series:
177+
"""
178+
Flatten list values.
179+
180+
Returns
181+
-------
182+
pandas.Series
183+
The data from all lists in the series flattened.
184+
185+
Examples
186+
--------
187+
>>> import pyarrow as pa
188+
>>> s = pd.Series(
189+
... [
190+
... [1, 2, 3],
191+
... [3],
192+
... ],
193+
... dtype=pd.ArrowDtype(pa.list_(
194+
... pa.int64()
195+
... ))
196+
... )
197+
>>> s.list.flatten()
198+
0 1
199+
1 2
200+
2 3
201+
3 3
202+
dtype: int64[pyarrow]
203+
"""
204+
from pandas import Series
205+
206+
flattened = pc.list_flatten(self._pa_array)
207+
return Series(flattened, dtype=ArrowDtype(flattened.type))
208+
209+
210+
class StructAccessor(ArrowAccessor):
23211
"""
24212
Accessor object for structured data properties of the Series values.
25213
@@ -29,23 +217,17 @@ class StructAccessor:
29217
Series containing Arrow struct data.
30218
"""
31219

32-
_validation_msg = (
33-
"Can only use the '.struct' accessor with 'struct[pyarrow]' dtype, not {dtype}."
34-
)
35-
36220
def __init__(self, data=None) -> None:
37-
self._parent = data
38-
self._validate(data)
39-
40-
def _validate(self, data):
41-
dtype = data.dtype
42-
if not isinstance(dtype, ArrowDtype):
43-
# Raise AttributeError so that inspect can handle non-struct Series.
44-
raise AttributeError(self._validation_msg.format(dtype=dtype))
221+
super().__init__(
222+
data,
223+
validation_msg=(
224+
"Can only use the '.struct' accessor with 'struct[pyarrow]' "
225+
"dtype, not {dtype}."
226+
),
227+
)
45228

46-
if not pa.types.is_struct(dtype.pyarrow_dtype):
47-
# Raise AttributeError so that inspect can handle non-struct Series.
48-
raise AttributeError(self._validation_msg.format(dtype=dtype))
229+
def _is_valid_pyarrow_dtype(self, pyarrow_dtype) -> bool:
230+
return pa.types.is_struct(pyarrow_dtype)
49231

50232
@property
51233
def dtypes(self) -> Series:
@@ -80,7 +262,7 @@ def dtypes(self) -> Series:
80262
Series,
81263
)
82264

83-
pa_type = self._parent.dtype.pyarrow_dtype
265+
pa_type = self._data.dtype.pyarrow_dtype
84266
types = [ArrowDtype(struct.type) for struct in pa_type]
85267
names = [struct.name for struct in pa_type]
86268
return Series(types, index=Index(names))
@@ -135,7 +317,7 @@ def field(self, name_or_index: str | int) -> Series:
135317
"""
136318
from pandas import Series
137319

138-
pa_arr = self._parent.array._pa_array
320+
pa_arr = self._data.array._pa_array
139321
if isinstance(name_or_index, int):
140322
index = name_or_index
141323
elif isinstance(name_or_index, str):
@@ -151,7 +333,7 @@ def field(self, name_or_index: str | int) -> Series:
151333
return Series(
152334
field_arr,
153335
dtype=ArrowDtype(field_arr.type),
154-
index=self._parent.index,
336+
index=self._data.index,
155337
name=pa_field.name,
156338
)
157339

@@ -190,7 +372,7 @@ def explode(self) -> DataFrame:
190372
"""
191373
from pandas import concat
192374

193-
pa_type = self._parent.dtype.pyarrow_dtype
375+
pa_type = self._pa_array.type
194376
return concat(
195377
[self.field(i) for i in range(pa_type.num_fields)], axis="columns"
196378
)

pandas/core/series.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,10 @@
100100
from pandas.core.accessor import CachedAccessor
101101
from pandas.core.apply import SeriesApply
102102
from pandas.core.arrays import ExtensionArray
103-
from pandas.core.arrays.arrow import StructAccessor
103+
from pandas.core.arrays.arrow import (
104+
ListAccessor,
105+
StructAccessor,
106+
)
104107
from pandas.core.arrays.categorical import CategoricalAccessor
105108
from pandas.core.arrays.sparse import SparseAccessor
106109
from pandas.core.arrays.string_ import StringDtype
@@ -5891,6 +5894,7 @@ def to_period(self, freq: str | None = None, copy: bool | None = None) -> Series
58915894
plot = CachedAccessor("plot", pandas.plotting.PlotAccessor)
58925895
sparse = CachedAccessor("sparse", SparseAccessor)
58935896
struct = CachedAccessor("struct", StructAccessor)
5897+
list = CachedAccessor("list", ListAccessor)
58945898

58955899
# ----------------------------------------------------------------------
58965900
# Add plotting methods to Series

0 commit comments

Comments
 (0)