Skip to content

Commit 4fbd41c

Browse files
TomAugspurgermeeseeksmachine
authored andcommitted
Backport PR pandas-dev#56167: [ENH]: Expand types allowed in Series.struct.field
1 parent 0d0c792 commit 4fbd41c

File tree

3 files changed

+165
-16
lines changed

3 files changed

+165
-16
lines changed

doc/source/whatsnew/v2.2.0.rst

+8
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,14 @@ DataFrame. (:issue:`54938`)
251251
)
252252
series.struct.explode()
253253
254+
Use :meth:`Series.struct.field` to index into a (possible nested)
255+
struct field.
256+
257+
258+
.. ipython:: python
259+
260+
series.struct.field("project")
261+
254262
.. _whatsnew_220.enhancements.list_accessor:
255263

256264
Series.list accessor for PyArrow list data

pandas/core/arrays/arrow/accessors.py

+110-15
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,18 @@
66
ABCMeta,
77
abstractmethod,
88
)
9-
from typing import TYPE_CHECKING
9+
from typing import (
10+
TYPE_CHECKING,
11+
cast,
12+
)
1013

1114
from pandas.compat import (
1215
pa_version_under10p1,
1316
pa_version_under11p0,
1417
)
1518

19+
from pandas.core.dtypes.common import is_list_like
20+
1621
if not pa_version_under10p1:
1722
import pyarrow as pa
1823
import pyarrow.compute as pc
@@ -267,15 +272,27 @@ def dtypes(self) -> Series:
267272
names = [struct.name for struct in pa_type]
268273
return Series(types, index=Index(names))
269274

270-
def field(self, name_or_index: str | int) -> Series:
275+
def field(
276+
self,
277+
name_or_index: list[str]
278+
| list[bytes]
279+
| list[int]
280+
| pc.Expression
281+
| bytes
282+
| str
283+
| int,
284+
) -> Series:
271285
"""
272286
Extract a child field of a struct as a Series.
273287
274288
Parameters
275289
----------
276-
name_or_index : str | int
290+
name_or_index : str | bytes | int | expression | list
277291
Name or index of the child field to extract.
278292
293+
For list-like inputs, this will index into a nested
294+
struct.
295+
279296
Returns
280297
-------
281298
pandas.Series
@@ -285,6 +302,19 @@ def field(self, name_or_index: str | int) -> Series:
285302
--------
286303
Series.struct.explode : Return all child fields as a DataFrame.
287304
305+
Notes
306+
-----
307+
The name of the resulting Series will be set using the following
308+
rules:
309+
310+
- For string, bytes, or integer `name_or_index` (or a list of these, for
311+
a nested selection), the Series name is set to the selected
312+
field's name.
313+
- For a :class:`pyarrow.compute.Expression`, this is set to
314+
the string form of the expression.
315+
- For list-like `name_or_index`, the name will be set to the
316+
name of the final field selected.
317+
288318
Examples
289319
--------
290320
>>> import pyarrow as pa
@@ -314,27 +344,92 @@ def field(self, name_or_index: str | int) -> Series:
314344
1 2
315345
2 1
316346
Name: version, dtype: int64[pyarrow]
347+
348+
Or an expression
349+
350+
>>> import pyarrow.compute as pc
351+
>>> s.struct.field(pc.field("project"))
352+
0 pandas
353+
1 pandas
354+
2 numpy
355+
Name: project, dtype: string[pyarrow]
356+
357+
For nested struct types, you can pass a list of values to index
358+
multiple levels:
359+
360+
>>> version_type = pa.struct([
361+
... ("major", pa.int64()),
362+
... ("minor", pa.int64()),
363+
... ])
364+
>>> s = pd.Series(
365+
... [
366+
... {"version": {"major": 1, "minor": 5}, "project": "pandas"},
367+
... {"version": {"major": 2, "minor": 1}, "project": "pandas"},
368+
... {"version": {"major": 1, "minor": 26}, "project": "numpy"},
369+
... ],
370+
... dtype=pd.ArrowDtype(pa.struct(
371+
... [("version", version_type), ("project", pa.string())]
372+
... ))
373+
... )
374+
>>> s.struct.field(["version", "minor"])
375+
0 5
376+
1 1
377+
2 26
378+
Name: minor, dtype: int64[pyarrow]
379+
>>> s.struct.field([0, 0])
380+
0 1
381+
1 2
382+
2 1
383+
Name: major, dtype: int64[pyarrow]
317384
"""
318385
from pandas import Series
319386

387+
def get_name(
388+
level_name_or_index: list[str]
389+
| list[bytes]
390+
| list[int]
391+
| pc.Expression
392+
| bytes
393+
| str
394+
| int,
395+
data: pa.ChunkedArray,
396+
):
397+
if isinstance(level_name_or_index, int):
398+
name = data.type.field(level_name_or_index).name
399+
elif isinstance(level_name_or_index, (str, bytes)):
400+
name = level_name_or_index
401+
elif isinstance(level_name_or_index, pc.Expression):
402+
name = str(level_name_or_index)
403+
elif is_list_like(level_name_or_index):
404+
# For nested input like [2, 1, 2]
405+
# iteratively get the struct and field name. The last
406+
# one is used for the name of the index.
407+
level_name_or_index = list(reversed(level_name_or_index))
408+
selected = data
409+
while level_name_or_index:
410+
# we need the cast, otherwise mypy complains about
411+
# getting ints, bytes, or str here, which isn't possible.
412+
level_name_or_index = cast(list, level_name_or_index)
413+
name_or_index = level_name_or_index.pop()
414+
name = get_name(name_or_index, selected)
415+
selected = selected.type.field(selected.type.get_field_index(name))
416+
name = selected.name
417+
else:
418+
raise ValueError(
419+
"name_or_index must be an int, str, bytes, "
420+
"pyarrow.compute.Expression, or list of those"
421+
)
422+
return name
423+
320424
pa_arr = self._data.array._pa_array
321-
if isinstance(name_or_index, int):
322-
index = name_or_index
323-
elif isinstance(name_or_index, str):
324-
index = pa_arr.type.get_field_index(name_or_index)
325-
else:
326-
raise ValueError(
327-
"name_or_index must be an int or str, "
328-
f"got {type(name_or_index).__name__}"
329-
)
425+
name = get_name(name_or_index, pa_arr)
426+
field_arr = pc.struct_field(pa_arr, name_or_index)
330427

331-
pa_field = pa_arr.type[index]
332-
field_arr = pc.struct_field(pa_arr, [index])
333428
return Series(
334429
field_arr,
335430
dtype=ArrowDtype(field_arr.type),
336431
index=self._data.index,
337-
name=pa_field.name,
432+
name=name,
338433
)
339434

340435
def explode(self) -> DataFrame:

pandas/tests/series/accessors/test_struct_accessor.py

+47-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22

33
import pytest
44

5+
from pandas.compat.pyarrow import (
6+
pa_version_under11p0,
7+
pa_version_under13p0,
8+
)
9+
510
from pandas import (
611
ArrowDtype,
712
DataFrame,
@@ -11,6 +16,7 @@
1116
import pandas._testing as tm
1217

1318
pa = pytest.importorskip("pyarrow")
19+
pc = pytest.importorskip("pyarrow.compute")
1420

1521

1622
def test_struct_accessor_dtypes():
@@ -53,6 +59,7 @@ def test_struct_accessor_dtypes():
5359
tm.assert_series_equal(actual, expected)
5460

5561

62+
@pytest.mark.skipif(pa_version_under13p0, reason="pyarrow>=13.0.0 required")
5663
def test_struct_accessor_field():
5764
index = Index([-100, 42, 123])
5865
ser = Series(
@@ -94,10 +101,11 @@ def test_struct_accessor_field():
94101
def test_struct_accessor_field_with_invalid_name_or_index():
95102
ser = Series([], dtype=ArrowDtype(pa.struct([("field", pa.int64())])))
96103

97-
with pytest.raises(ValueError, match="name_or_index must be an int or str"):
104+
with pytest.raises(ValueError, match="name_or_index must be an int, str,"):
98105
ser.struct.field(1.1)
99106

100107

108+
@pytest.mark.skipif(pa_version_under11p0, reason="pyarrow>=11.0.0 required")
101109
def test_struct_accessor_explode():
102110
index = Index([-100, 42, 123])
103111
ser = Series(
@@ -148,3 +156,41 @@ def test_struct_accessor_api_for_invalid(invalid):
148156
),
149157
):
150158
invalid.struct
159+
160+
161+
@pytest.mark.parametrize(
162+
["indices", "name"],
163+
[
164+
(0, "int_col"),
165+
([1, 2], "str_col"),
166+
(pc.field("int_col"), "int_col"),
167+
("int_col", "int_col"),
168+
(b"string_col", b"string_col"),
169+
([b"string_col"], "string_col"),
170+
],
171+
)
172+
@pytest.mark.skipif(pa_version_under13p0, reason="pyarrow>=13.0.0 required")
173+
def test_struct_accessor_field_expanded(indices, name):
174+
arrow_type = pa.struct(
175+
[
176+
("int_col", pa.int64()),
177+
(
178+
"struct_col",
179+
pa.struct(
180+
[
181+
("int_col", pa.int64()),
182+
("float_col", pa.float64()),
183+
("str_col", pa.string()),
184+
]
185+
),
186+
),
187+
(b"string_col", pa.string()),
188+
]
189+
)
190+
191+
data = pa.array([], type=arrow_type)
192+
ser = Series(data, dtype=ArrowDtype(arrow_type))
193+
expected = pc.struct_field(data, indices)
194+
result = ser.struct.field(indices)
195+
tm.assert_equal(result.array._pa_array.combine_chunks(), expected)
196+
assert result.name == name

0 commit comments

Comments
 (0)