Skip to content

Commit 49d66ec

Browse files
committed
[ENH]: Expand types allowed in Series.struct.field
This expands the set of types allowed by Series.struct.field to allow those allowed by pyarrow. Closes pandas-dev#56065
1 parent 3530b3d commit 49d66ec

File tree

2 files changed

+127
-14
lines changed

2 files changed

+127
-14
lines changed

pandas/core/arrays/arrow/accessors.py

+88-13
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
pa_version_under11p0,
1414
)
1515

16+
from pandas.core.dtypes.common import is_list_like
17+
1618
if not pa_version_under10p1:
1719
import pyarrow as pa
1820
import pyarrow.compute as pc
@@ -267,7 +269,16 @@ def dtypes(self) -> Series:
267269
names = [struct.name for struct in pa_type]
268270
return Series(types, index=Index(names))
269271

270-
def field(self, name_or_index: str | int) -> Series:
272+
def field(
273+
self,
274+
name_or_index: list[str]
275+
| list[bytes]
276+
| list[int]
277+
| pc.Expression
278+
| bytes
279+
| str
280+
| int,
281+
) -> Series:
271282
"""
272283
Extract a child field of a struct as a Series.
273284
@@ -281,6 +292,17 @@ def field(self, name_or_index: str | int) -> Series:
281292
pandas.Series
282293
The data corresponding to the selected child field.
283294
295+
Notes
296+
-----
297+
The name of the resulting Series will be set using the following
298+
rules:
299+
300+
- For string, bytes, or integer `name_or_index` (or a list of these, for
301+
a nested selection), the Series name is set to the selected
302+
field's name.
303+
- For a :class:`pyarrow.compute.Expression`, this is set to
304+
the string form of the expression.
305+
284306
See Also
285307
--------
286308
Series.struct.explode : Return all child fields as a DataFrame.
@@ -314,27 +336,80 @@ def field(self, name_or_index: str | int) -> Series:
314336
1 2
315337
2 1
316338
Name: version, dtype: int64[pyarrow]
339+
340+
Or an expression
341+
342+
>>> import pyarrow.compute as pc
343+
>>> s.struct.field(pc.field("project"))
344+
0 pandas
345+
1 pandas
346+
2 numpy
347+
Name: project, dtype: string[pyarrow]
348+
349+
For nested struct types, you can pass a list of values:
350+
351+
>>> version_type = pa.struct([
352+
... ("major", pa.int64()),
353+
... ("minor", pa.int64()),
354+
... ])
355+
>>> s = pd.Series(
356+
... [
357+
... {"version": {"major": 1, "minor": 5}, "project": "pandas"},
358+
... {"version": {"major": 2, "minor": 1}, "project": "pandas"},
359+
... {"version": {"major": 1, "minor": 26}, "project": "numpy"},
360+
... ],
361+
... dtype=pd.ArrowDtype(pa.struct(
362+
... [("version", version_type), ("project", pa.string())]
363+
... ))
364+
... )
365+
>>> s.struct.field(["version", "minor"])
366+
0 5
367+
1 1
368+
2 26
369+
Name: minor, dtype: int64[pyarrow]
370+
>>> s.struct.field([0, 0])
371+
0 1
372+
1 2
373+
2 1
374+
Name: major, dtype: int64[pyarrow]
317375
"""
318376
from pandas import Series
319377

378+
def get_name(level_name_or_index, data):
379+
if isinstance(level_name_or_index, int):
380+
index = data.type.field(level_name_or_index).name
381+
elif isinstance(level_name_or_index, (str, bytes)):
382+
index = level_name_or_index
383+
elif isinstance(level_name_or_index, pc.Expression):
384+
index = str(level_name_or_index)
385+
elif is_list_like(level_name_or_index):
386+
# For nested input like [2, 1, 2]
387+
# iteratively get the struct and field name. The last
388+
# one is used for the name of the index.
389+
level_name_or_index = list(reversed(level_name_or_index))
390+
selected = data
391+
while level_name_or_index:
392+
name_or_index = level_name_or_index.pop()
393+
name = get_name(name_or_index, selected)
394+
selected = selected.type.field(selected.type.get_field_index(name))
395+
index = selected.name
396+
return index
397+
else:
398+
raise ValueError(
399+
"name_or_index must be an int, str, bytes, "
400+
"pyarrow.compute.Expression, or list of those"
401+
)
402+
return index
403+
320404
pa_arr = self._data.array._pa_array
321-
if isinstance(name_or_index, int):
322-
index = name_or_index
323-
elif isinstance(name_or_index, str):
324-
index = pa_arr.type.get_field_index(name_or_index)
325-
else:
326-
raise ValueError(
327-
"name_or_index must be an int or str, "
328-
f"got {type(name_or_index).__name__}"
329-
)
405+
name = get_name(name_or_index, pa_arr)
406+
field_arr = pc.struct_field(pa_arr, name_or_index)
330407

331-
pa_field = pa_arr.type[index]
332-
field_arr = pc.struct_field(pa_arr, [index])
333408
return Series(
334409
field_arr,
335410
dtype=ArrowDtype(field_arr.type),
336411
index=self._data.index,
337-
name=pa_field.name,
412+
name=name,
338413
)
339414

340415
def explode(self) -> DataFrame:

pandas/tests/series/accessors/test_struct_accessor.py

+39-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import re
22

3+
import pyarrow.compute as pc
34
import pytest
45

56
from pandas import (
@@ -94,7 +95,7 @@ def test_struct_accessor_field():
9495
def test_struct_accessor_field_with_invalid_name_or_index():
9596
ser = Series([], dtype=ArrowDtype(pa.struct([("field", pa.int64())])))
9697

97-
with pytest.raises(ValueError, match="name_or_index must be an int or str"):
98+
with pytest.raises(ValueError, match="name_or_index must be an int, str,"):
9899
ser.struct.field(1.1)
99100

100101

@@ -148,3 +149,40 @@ def test_struct_accessor_api_for_invalid(invalid):
148149
),
149150
):
150151
invalid.struct
152+
153+
154+
@pytest.mark.parametrize(
155+
["indices", "name"],
156+
[
157+
(0, "int_col"),
158+
([1, 2], "str_col"),
159+
(pc.field("int_col"), "int_col"),
160+
("int_col", "int_col"),
161+
(b"string_col", b"string_col"),
162+
([b"string_col"], "string_col"),
163+
],
164+
)
165+
def test_struct_accessor_field_expanded(indices, name):
166+
arrow_type = pa.struct(
167+
[
168+
("int_col", pa.int64()),
169+
(
170+
"struct_col",
171+
pa.struct(
172+
[
173+
("int_col", pa.int64()),
174+
("float_col", pa.float64()),
175+
("str_col", pa.string()),
176+
]
177+
),
178+
),
179+
(b"string_col", pa.string()),
180+
]
181+
)
182+
183+
data = pa.array([], type=arrow_type)
184+
ser = Series(data, dtype=ArrowDtype(arrow_type))
185+
expected = pc.struct_field(data, indices)
186+
result = ser.struct.field(indices)
187+
tm.assert_equal(result.array._pa_array.combine_chunks(), expected)
188+
assert result.name == name

0 commit comments

Comments
 (0)