Skip to content

Commit 1f80a7d

Browse files
mroeschke authored and YYYasin19 committed
ENH: Add ArrowDtype and .array.ArrowExtensionArray to top level (pandas-dev#47818)
* ENH: Add ArrowDtype and .array.ArrowExtensionArray to top level
* ensure string[pyarrow] dispatches to StringDtype for now
* type ignores
* Address availability of pyarrow
* Address typing
1 parent 9297713 commit 1f80a7d

14 files changed

+149
-127
lines changed

pandas/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747

4848
from pandas.core.api import (
4949
# dtype
50+
ArrowDtype,
5051
Int8Dtype,
5152
Int16Dtype,
5253
Int32Dtype,
@@ -308,6 +309,7 @@ def __getattr__(name):
308309
# Pandas is not (yet) a py.typed library: the public API is determined
309310
# based on the documentation.
310311
__all__ = [
312+
"ArrowDtype",
311313
"BooleanDtype",
312314
"Categorical",
313315
"CategoricalDtype",

pandas/core/api.py

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
value_counts,
2626
)
2727
from pandas.core.arrays import Categorical
28+
from pandas.core.arrays.arrow import ArrowDtype
2829
from pandas.core.arrays.boolean import BooleanDtype
2930
from pandas.core.arrays.floating import (
3031
Float32Dtype,
@@ -85,6 +86,7 @@
8586

8687
__all__ = [
8788
"array",
89+
"ArrowDtype",
8890
"bdate_range",
8991
"BooleanDtype",
9092
"Categorical",

pandas/core/arrays/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from pandas.core.arrays.arrow import ArrowExtensionArray
12
from pandas.core.arrays.base import (
23
ExtensionArray,
34
ExtensionOpsMixin,
@@ -21,6 +22,7 @@
2122
from pandas.core.arrays.timedeltas import TimedeltaArray
2223

2324
__all__ = [
25+
"ArrowExtensionArray",
2426
"ExtensionArray",
2527
"ExtensionOpsMixin",
2628
"ExtensionScalarOpsMixin",

pandas/core/arrays/arrow/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from pandas.core.arrays.arrow.array import ArrowExtensionArray
2+
from pandas.core.arrays.arrow.dtype import ArrowDtype
23

3-
__all__ = ["ArrowExtensionArray"]
4+
__all__ = ["ArrowDtype", "ArrowExtensionArray"]
pandas/core/arrays/arrow/_arrow_utils.py

-111
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,14 @@
11
from __future__ import annotations
22

33
import inspect
4-
import json
54
import warnings
65

76
import numpy as np
87
import pyarrow
98

10-
from pandas._typing import IntervalInclusiveType
119
from pandas.errors import PerformanceWarning
12-
from pandas.util._decorators import deprecate_kwarg
1310
from pandas.util._exceptions import find_stack_level
1411

15-
from pandas.core.arrays.interval import VALID_INCLUSIVE
16-
1712

1813
def fallback_performancewarning(version: str | None = None) -> None:
1914
"""
@@ -67,109 +62,3 @@ def pyarrow_array_to_numpy_and_mask(
6762
else:
6863
mask = np.ones(len(arr), dtype=bool)
6964
return data, mask
70-
71-
72-
class ArrowPeriodType(pyarrow.ExtensionType):
73-
def __init__(self, freq) -> None:
74-
# attributes need to be set first before calling
75-
# super init (as that calls serialize)
76-
self._freq = freq
77-
pyarrow.ExtensionType.__init__(self, pyarrow.int64(), "pandas.period")
78-
79-
@property
80-
def freq(self):
81-
return self._freq
82-
83-
def __arrow_ext_serialize__(self) -> bytes:
84-
metadata = {"freq": self.freq}
85-
return json.dumps(metadata).encode()
86-
87-
@classmethod
88-
def __arrow_ext_deserialize__(cls, storage_type, serialized) -> ArrowPeriodType:
89-
metadata = json.loads(serialized.decode())
90-
return ArrowPeriodType(metadata["freq"])
91-
92-
def __eq__(self, other):
93-
if isinstance(other, pyarrow.BaseExtensionType):
94-
return type(self) == type(other) and self.freq == other.freq
95-
else:
96-
return NotImplemented
97-
98-
def __hash__(self) -> int:
99-
return hash((str(self), self.freq))
100-
101-
def to_pandas_dtype(self):
102-
import pandas as pd
103-
104-
return pd.PeriodDtype(freq=self.freq)
105-
106-
107-
# register the type with a dummy instance
108-
_period_type = ArrowPeriodType("D")
109-
pyarrow.register_extension_type(_period_type)
110-
111-
112-
class ArrowIntervalType(pyarrow.ExtensionType):
113-
@deprecate_kwarg(old_arg_name="closed", new_arg_name="inclusive")
114-
def __init__(self, subtype, inclusive: IntervalInclusiveType) -> None:
115-
# attributes need to be set first before calling
116-
# super init (as that calls serialize)
117-
assert inclusive in VALID_INCLUSIVE
118-
self._inclusive: IntervalInclusiveType = inclusive
119-
if not isinstance(subtype, pyarrow.DataType):
120-
subtype = pyarrow.type_for_alias(str(subtype))
121-
self._subtype = subtype
122-
123-
storage_type = pyarrow.struct([("left", subtype), ("right", subtype)])
124-
pyarrow.ExtensionType.__init__(self, storage_type, "pandas.interval")
125-
126-
@property
127-
def subtype(self):
128-
return self._subtype
129-
130-
@property
131-
def inclusive(self) -> IntervalInclusiveType:
132-
return self._inclusive
133-
134-
@property
135-
def closed(self) -> IntervalInclusiveType:
136-
warnings.warn(
137-
"Attribute `closed` is deprecated in favor of `inclusive`.",
138-
FutureWarning,
139-
stacklevel=find_stack_level(inspect.currentframe()),
140-
)
141-
return self._inclusive
142-
143-
def __arrow_ext_serialize__(self) -> bytes:
144-
metadata = {"subtype": str(self.subtype), "inclusive": self.inclusive}
145-
return json.dumps(metadata).encode()
146-
147-
@classmethod
148-
def __arrow_ext_deserialize__(cls, storage_type, serialized) -> ArrowIntervalType:
149-
metadata = json.loads(serialized.decode())
150-
subtype = pyarrow.type_for_alias(metadata["subtype"])
151-
inclusive = metadata["inclusive"]
152-
return ArrowIntervalType(subtype, inclusive)
153-
154-
def __eq__(self, other):
155-
if isinstance(other, pyarrow.BaseExtensionType):
156-
return (
157-
type(self) == type(other)
158-
and self.subtype == other.subtype
159-
and self.inclusive == other.inclusive
160-
)
161-
else:
162-
return NotImplemented
163-
164-
def __hash__(self) -> int:
165-
return hash((str(self), str(self.subtype), self.inclusive))
166-
167-
def to_pandas_dtype(self):
168-
import pandas as pd
169-
170-
return pd.IntervalDtype(self.subtype.to_pandas_dtype(), self.inclusive)
171-
172-
173-
# register the type with a dummy instance
174-
_interval_type = ArrowIntervalType(pyarrow.int64(), "left")
175-
pyarrow.register_extension_type(_interval_type)

pandas/core/arrays/arrow/dtype.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,19 @@
33
import re
44

55
import numpy as np
6-
import pyarrow as pa
76

87
from pandas._typing import DtypeObj
8+
from pandas.compat import pa_version_under1p01
99
from pandas.util._decorators import cache_readonly
1010

1111
from pandas.core.dtypes.base import (
1212
StorageExtensionDtype,
1313
register_extension_dtype,
1414
)
1515

16+
if not pa_version_under1p01:
17+
import pyarrow as pa
18+
1619

1720
@register_extension_dtype
1821
class ArrowDtype(StorageExtensionDtype):
@@ -25,6 +28,8 @@ class ArrowDtype(StorageExtensionDtype):
2528

2629
def __init__(self, pyarrow_dtype: pa.DataType) -> None:
2730
super().__init__("pyarrow")
31+
if pa_version_under1p01:
32+
raise ImportError("pyarrow>=1.0.1 is required for ArrowDtype")
2833
if not isinstance(pyarrow_dtype, pa.DataType):
2934
raise ValueError(
3035
f"pyarrow_dtype ({pyarrow_dtype}) must be an instance "
@@ -93,6 +98,9 @@ def construct_from_string(cls, string: str) -> ArrowDtype:
9398
)
9499
if not string.endswith("[pyarrow]"):
95100
raise TypeError(f"'{string}' must end with '[pyarrow]'")
101+
if string == "string[pyarrow]":
102+
# Ensure Registry.find skips ArrowDtype to use StringDtype instead
103+
raise TypeError("string[pyarrow] should be constructed by StringDtype")
96104
base_type = string.split("[pyarrow]")[0]
97105
try:
98106
pa_dtype = pa.type_for_alias(base_type)
pandas/core/arrays/arrow/extension_types.py

+118
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
from __future__ import annotations
2+
3+
import json
4+
import warnings
5+
6+
import pyarrow
7+
8+
from pandas._typing import IntervalInclusiveType
9+
from pandas.util._decorators import deprecate_kwarg
10+
from pandas.util._exceptions import find_stack_level
11+
12+
from pandas.core.arrays.interval import VALID_INCLUSIVE
13+
14+
15+
class ArrowPeriodType(pyarrow.ExtensionType):
16+
def __init__(self, freq) -> None:
17+
# attributes need to be set first before calling
18+
# super init (as that calls serialize)
19+
self._freq = freq
20+
pyarrow.ExtensionType.__init__(self, pyarrow.int64(), "pandas.period")
21+
22+
@property
23+
def freq(self):
24+
return self._freq
25+
26+
def __arrow_ext_serialize__(self) -> bytes:
27+
metadata = {"freq": self.freq}
28+
return json.dumps(metadata).encode()
29+
30+
@classmethod
31+
def __arrow_ext_deserialize__(cls, storage_type, serialized) -> ArrowPeriodType:
32+
metadata = json.loads(serialized.decode())
33+
return ArrowPeriodType(metadata["freq"])
34+
35+
def __eq__(self, other):
36+
if isinstance(other, pyarrow.BaseExtensionType):
37+
return type(self) == type(other) and self.freq == other.freq
38+
else:
39+
return NotImplemented
40+
41+
def __hash__(self) -> int:
42+
return hash((str(self), self.freq))
43+
44+
def to_pandas_dtype(self):
45+
import pandas as pd
46+
47+
return pd.PeriodDtype(freq=self.freq)
48+
49+
50+
# register the type with a dummy instance
51+
_period_type = ArrowPeriodType("D")
52+
pyarrow.register_extension_type(_period_type)
53+
54+
55+
class ArrowIntervalType(pyarrow.ExtensionType):
56+
@deprecate_kwarg(old_arg_name="closed", new_arg_name="inclusive")
57+
def __init__(self, subtype, inclusive: IntervalInclusiveType) -> None:
58+
# attributes need to be set first before calling
59+
# super init (as that calls serialize)
60+
assert inclusive in VALID_INCLUSIVE
61+
self._inclusive: IntervalInclusiveType = inclusive
62+
if not isinstance(subtype, pyarrow.DataType):
63+
subtype = pyarrow.type_for_alias(str(subtype))
64+
self._subtype = subtype
65+
66+
storage_type = pyarrow.struct([("left", subtype), ("right", subtype)])
67+
pyarrow.ExtensionType.__init__(self, storage_type, "pandas.interval")
68+
69+
@property
70+
def subtype(self):
71+
return self._subtype
72+
73+
@property
74+
def inclusive(self) -> IntervalInclusiveType:
75+
return self._inclusive
76+
77+
@property
78+
def closed(self) -> IntervalInclusiveType:
79+
warnings.warn(
80+
"Attribute `closed` is deprecated in favor of `inclusive`.",
81+
FutureWarning,
82+
stacklevel=find_stack_level(),
83+
)
84+
return self._inclusive
85+
86+
def __arrow_ext_serialize__(self) -> bytes:
87+
metadata = {"subtype": str(self.subtype), "inclusive": self.inclusive}
88+
return json.dumps(metadata).encode()
89+
90+
@classmethod
91+
def __arrow_ext_deserialize__(cls, storage_type, serialized) -> ArrowIntervalType:
92+
metadata = json.loads(serialized.decode())
93+
subtype = pyarrow.type_for_alias(metadata["subtype"])
94+
inclusive = metadata["inclusive"]
95+
return ArrowIntervalType(subtype, inclusive)
96+
97+
def __eq__(self, other):
98+
if isinstance(other, pyarrow.BaseExtensionType):
99+
return (
100+
type(self) == type(other)
101+
and self.subtype == other.subtype
102+
and self.inclusive == other.inclusive
103+
)
104+
else:
105+
return NotImplemented
106+
107+
def __hash__(self) -> int:
108+
return hash((str(self), str(self.subtype), self.inclusive))
109+
110+
def to_pandas_dtype(self):
111+
import pandas as pd
112+
113+
return pd.IntervalDtype(self.subtype.to_pandas_dtype(), self.inclusive)
114+
115+
116+
# register the type with a dummy instance
117+
_interval_type = ArrowIntervalType(pyarrow.int64(), "left")
118+
pyarrow.register_extension_type(_interval_type)

pandas/core/arrays/interval.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1554,7 +1554,7 @@ def __arrow_array__(self, type=None):
15541554
"""
15551555
import pyarrow
15561556

1557-
from pandas.core.arrays.arrow._arrow_utils import ArrowIntervalType
1557+
from pandas.core.arrays.arrow.extension_types import ArrowIntervalType
15581558

15591559
try:
15601560
subtype = pyarrow.from_numpy_dtype(self.dtype.subtype)

pandas/core/arrays/period.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ def __arrow_array__(self, type=None):
377377
"""
378378
import pyarrow
379379

380-
from pandas.core.arrays.arrow._arrow_utils import ArrowPeriodType
380+
from pandas.core.arrays.arrow.extension_types import ArrowPeriodType
381381

382382
if type is not None:
383383
if pyarrow.types.is_integer(type):

pandas/core/arrays/string_arrow.py

-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ class ArrowStringArray(ArrowExtensionArray, BaseStringArray, ObjectStringArrayMi
114114

115115
def __init__(self, values) -> None:
116116
super().__init__(values)
117-
# TODO: Migrate to ArrowDtype instead
118117
self._dtype = StringDtype(storage="pyarrow")
119118

120119
if not pa.types.is_string(self._data.type):

pandas/io/parquet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def __init__(self) -> None:
151151
import pyarrow.parquet
152152

153153
# import utils to register the pyarrow extension types
154-
import pandas.core.arrays.arrow._arrow_utils # pyright: ignore # noqa:F401
154+
import pandas.core.arrays.arrow.extension_types # pyright: ignore # noqa:F401
155155

156156
self.api = pyarrow
157157

pandas/tests/api/test_api.py

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class TestPDApi(Base):
5353

5454
# top-level classes
5555
classes = [
56+
"ArrowDtype",
5657
"Categorical",
5758
"CategoricalIndex",
5859
"DataFrame",

0 commit comments

Comments
 (0)