Skip to content

Commit 66ff669

Browse files
committed
ENH: add Series.struct accessor for ArrowDtype[struct]
Features:
* `Series.struct.dtypes` -- see the dtypes and names of the struct's fields
* `Series.struct.field(name_or_index)` -- extract a single field as a Series
* `Series.struct.to_frame()` -- convert all fields into a DataFrame
1 parent 53243e8 commit 66ff669

File tree

3 files changed

+70
-1
lines changed

3 files changed

+70
-1
lines changed

pandas/core/arrays/arrow/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from pandas.core.arrays.arrow.accessors import StructAccessor
12
from pandas.core.arrays.arrow.array import ArrowExtensionArray
23

3-
__all__ = ["ArrowExtensionArray"]
4+
__all__ = ["ArrowExtensionArray", "StructAccessor"]

pandas/core/arrays/arrow/accessors.py

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
from pandas.compat import pa_version_under7p0
6+
7+
if not pa_version_under7p0:
8+
import pyarrow as pa
9+
import pyarrow.compute as pc
10+
11+
from pandas.core.dtypes.dtypes import ArrowDtype
12+
13+
if TYPE_CHECKING:
14+
from pandas import (
15+
DataFrame,
16+
Series,
17+
)
18+
19+
20+
class StructAccessor:
    """
    Accessor exposing the fields of a Series holding pyarrow struct data.

    Parameters
    ----------
    data : Series
        Series whose dtype is an ``ArrowDtype`` wrapping a pyarrow struct type.

    Raises
    ------
    AttributeError
        If ``data`` is not backed by an ``ArrowDtype`` struct type.
    """

    _validation_msg = "Can only use the '.struct' accessor with 'struct[pyarrow]' data."

    def __init__(self, data=None) -> None:
        self._parent = data
        self._validate(data)

    def _validate(self, data) -> None:
        """Raise ``AttributeError`` unless ``data`` holds pyarrow struct data."""
        # ``data`` defaults to None in ``__init__``; use getattr so that case
        # also reports the validation message instead of a NoneType error.
        dtype = getattr(data, "dtype", None)
        if not isinstance(dtype, ArrowDtype):
            raise AttributeError(self._validation_msg)

        if not pa.types.is_struct(dtype.pyarrow_dtype):
            raise AttributeError(self._validation_msg)

    @property
    def dtypes(self) -> Series:
        """
        Return the dtype of each field of the struct.

        Returns
        -------
        Series
            One ``ArrowDtype`` per struct field, indexed by field name.
        """
        from pandas import (
            Index,
            Series,
        )

        pa_type = self._parent.dtype.pyarrow_dtype
        # Look every field up once and derive both names and types from it.
        fields = [pa_type[i] for i in range(pa_type.num_fields)]
        types = [ArrowDtype(struct_field.type) for struct_field in fields]
        names = [struct_field.name for struct_field in fields]
        return Series(types, index=Index(names))

    def field(self, name_or_index: str | int) -> Series:
        """
        Extract a single field of the struct as a Series.

        Parameters
        ----------
        name_or_index : str or int
            Name or positional index of the field to extract.

        Returns
        -------
        Series
            The field's values, named after the field and sharing the index
            of the parent Series.

        Raises
        ------
        KeyError
            If ``name_or_index`` is a name that does not identify exactly one
            field of the struct.
        """
        from pandas import Series

        pa_arr = self._parent.array._pa_array
        if isinstance(name_or_index, int):
            index = name_or_index
        else:
            index = pa_arr.type.get_field_index(name_or_index)
            # pyarrow reports a missing (or ambiguous) name as -1; fail here
            # rather than passing -1 on as if it were a valid position.
            if index == -1:
                raise KeyError(name_or_index)

        pa_field = pa_arr.type[index]
        field_arr = pc.struct_field(pa_arr, [index])
        # Carry over the parent's index so the extracted values stay aligned
        # with the rows they came from.
        return Series(
            field_arr,
            dtype=ArrowDtype(field_arr.type),
            index=self._parent.index,
            name=pa_field.name,
        )

    def to_frame(self) -> DataFrame:
        """
        Expand all fields of the struct into a DataFrame.

        Returns
        -------
        DataFrame
            One column per struct field, in field order, built by
            concatenating the result of ``field`` for every field.
        """
        from pandas import concat

        pa_type = self._parent.dtype.pyarrow_dtype
        return concat(
            [self.field(i) for i in range(pa_type.num_fields)], axis="columns"
        )

pandas/core/series.py

+2
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
from pandas.core.accessor import CachedAccessor
102102
from pandas.core.apply import SeriesApply
103103
from pandas.core.arrays import ExtensionArray
104+
from pandas.core.arrays.arrow import StructAccessor
104105
from pandas.core.arrays.categorical import CategoricalAccessor
105106
from pandas.core.arrays.sparse import SparseAccessor
106107
from pandas.core.construction import (
@@ -5787,6 +5788,7 @@ def to_period(self, freq: str | None = None, copy: bool | None = None) -> Series
57875788
cat = CachedAccessor("cat", CategoricalAccessor)
57885789
plot = CachedAccessor("plot", pandas.plotting.PlotAccessor)
57895790
sparse = CachedAccessor("sparse", SparseAccessor)
5791+
struct = CachedAccessor("struct", StructAccessor)
57905792

57915793
# ----------------------------------------------------------------------
57925794
# Add plotting methods to Series

0 commit comments

Comments
 (0)