From 84bd0f2da12504f229c29980c162fd9031fefa69 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Thu, 21 Jul 2022 20:56:08 -0700 Subject: [PATCH 1/5] ENH: Add ArrowDtype and .array.ArrowExtensionArray to top level --- pandas/__init__.py | 2 + pandas/core/api.py | 2 + pandas/core/arrays/__init__.py | 2 + pandas/core/arrays/arrow/__init__.py | 3 +- pandas/core/arrays/arrow/_arrow_utils.py | 111 ---------------- pandas/core/arrays/arrow/extension_types.py | 118 ++++++++++++++++++ pandas/core/arrays/interval.py | 2 +- pandas/core/arrays/period.py | 2 +- pandas/io/parquet.py | 2 +- pandas/tests/arrays/interval/test_interval.py | 10 +- .../tests/arrays/period/test_arrow_compat.py | 10 +- 11 files changed, 139 insertions(+), 125 deletions(-) create mode 100644 pandas/core/arrays/arrow/extension_types.py diff --git a/pandas/__init__.py b/pandas/__init__.py index eb5ce71141f46..5016bde000c3b 100644 --- a/pandas/__init__.py +++ b/pandas/__init__.py @@ -47,6 +47,7 @@ from pandas.core.api import ( # dtype + ArrowDtype, Int8Dtype, Int16Dtype, Int32Dtype, @@ -308,6 +309,7 @@ def __getattr__(name): # Pandas is not (yet) a py.typed library: the public API is determined # based on the documentation. 
__all__ = [ + "ArrowDtype", "BooleanDtype", "Categorical", "CategoricalDtype", diff --git a/pandas/core/api.py b/pandas/core/api.py index c2bedb032d479..3d2547fcea230 100644 --- a/pandas/core/api.py +++ b/pandas/core/api.py @@ -25,6 +25,7 @@ value_counts, ) from pandas.core.arrays import Categorical +from pandas.core.arrays.arrow import ArrowDtype from pandas.core.arrays.boolean import BooleanDtype from pandas.core.arrays.floating import ( Float32Dtype, @@ -85,6 +86,7 @@ __all__ = [ "array", + "ArrowDtype", "bdate_range", "BooleanDtype", "Categorical", diff --git a/pandas/core/arrays/__init__.py b/pandas/core/arrays/__init__.py index e301e82a0ee75..79be8760db931 100644 --- a/pandas/core/arrays/__init__.py +++ b/pandas/core/arrays/__init__.py @@ -1,3 +1,4 @@ +from pandas.core.arrays.arrow import ArrowExtensionArray from pandas.core.arrays.base import ( ExtensionArray, ExtensionOpsMixin, @@ -21,6 +22,7 @@ from pandas.core.arrays.timedeltas import TimedeltaArray __all__ = [ + "ArrowExtensionArray", "ExtensionArray", "ExtensionOpsMixin", "ExtensionScalarOpsMixin", diff --git a/pandas/core/arrays/arrow/__init__.py b/pandas/core/arrays/arrow/__init__.py index 58b268cbdd221..e7fa6fae0a5a1 100644 --- a/pandas/core/arrays/arrow/__init__.py +++ b/pandas/core/arrays/arrow/__init__.py @@ -1,3 +1,4 @@ from pandas.core.arrays.arrow.array import ArrowExtensionArray +from pandas.core.arrays.arrow.dtype import ArrowDtype -__all__ = ["ArrowExtensionArray"] +__all__ = ["ArrowDtype", "ArrowExtensionArray"] diff --git a/pandas/core/arrays/arrow/_arrow_utils.py b/pandas/core/arrays/arrow/_arrow_utils.py index c9666de9f892d..6e6ef6a2c20a8 100644 --- a/pandas/core/arrays/arrow/_arrow_utils.py +++ b/pandas/core/arrays/arrow/_arrow_utils.py @@ -1,18 +1,13 @@ from __future__ import annotations -import json import warnings import numpy as np import pyarrow -from pandas._typing import IntervalInclusiveType from pandas.errors import PerformanceWarning -from pandas.util._decorators import 
deprecate_kwarg from pandas.util._exceptions import find_stack_level -from pandas.core.arrays.interval import VALID_INCLUSIVE - def fallback_performancewarning(version: str | None = None) -> None: """ @@ -64,109 +59,3 @@ def pyarrow_array_to_numpy_and_mask( else: mask = np.ones(len(arr), dtype=bool) return data, mask - - -class ArrowPeriodType(pyarrow.ExtensionType): - def __init__(self, freq) -> None: - # attributes need to be set first before calling - # super init (as that calls serialize) - self._freq = freq - pyarrow.ExtensionType.__init__(self, pyarrow.int64(), "pandas.period") - - @property - def freq(self): - return self._freq - - def __arrow_ext_serialize__(self) -> bytes: - metadata = {"freq": self.freq} - return json.dumps(metadata).encode() - - @classmethod - def __arrow_ext_deserialize__(cls, storage_type, serialized) -> ArrowPeriodType: - metadata = json.loads(serialized.decode()) - return ArrowPeriodType(metadata["freq"]) - - def __eq__(self, other): - if isinstance(other, pyarrow.BaseExtensionType): - return type(self) == type(other) and self.freq == other.freq - else: - return NotImplemented - - def __hash__(self) -> int: - return hash((str(self), self.freq)) - - def to_pandas_dtype(self): - import pandas as pd - - return pd.PeriodDtype(freq=self.freq) - - -# register the type with a dummy instance -_period_type = ArrowPeriodType("D") -pyarrow.register_extension_type(_period_type) - - -class ArrowIntervalType(pyarrow.ExtensionType): - @deprecate_kwarg(old_arg_name="closed", new_arg_name="inclusive") - def __init__(self, subtype, inclusive: IntervalInclusiveType) -> None: - # attributes need to be set first before calling - # super init (as that calls serialize) - assert inclusive in VALID_INCLUSIVE - self._inclusive: IntervalInclusiveType = inclusive - if not isinstance(subtype, pyarrow.DataType): - subtype = pyarrow.type_for_alias(str(subtype)) - self._subtype = subtype - - storage_type = pyarrow.struct([("left", subtype), ("right", subtype)]) - 
pyarrow.ExtensionType.__init__(self, storage_type, "pandas.interval") - - @property - def subtype(self): - return self._subtype - - @property - def inclusive(self) -> IntervalInclusiveType: - return self._inclusive - - @property - def closed(self) -> IntervalInclusiveType: - warnings.warn( - "Attribute `closed` is deprecated in favor of `inclusive`.", - FutureWarning, - stacklevel=find_stack_level(), - ) - return self._inclusive - - def __arrow_ext_serialize__(self) -> bytes: - metadata = {"subtype": str(self.subtype), "inclusive": self.inclusive} - return json.dumps(metadata).encode() - - @classmethod - def __arrow_ext_deserialize__(cls, storage_type, serialized) -> ArrowIntervalType: - metadata = json.loads(serialized.decode()) - subtype = pyarrow.type_for_alias(metadata["subtype"]) - inclusive = metadata["inclusive"] - return ArrowIntervalType(subtype, inclusive) - - def __eq__(self, other): - if isinstance(other, pyarrow.BaseExtensionType): - return ( - type(self) == type(other) - and self.subtype == other.subtype - and self.inclusive == other.inclusive - ) - else: - return NotImplemented - - def __hash__(self) -> int: - return hash((str(self), str(self.subtype), self.inclusive)) - - def to_pandas_dtype(self): - import pandas as pd - - return pd.IntervalDtype(self.subtype.to_pandas_dtype(), self.inclusive) - - -# register the type with a dummy instance -_interval_type = ArrowIntervalType(pyarrow.int64(), "left") -pyarrow.register_extension_type(_interval_type) diff --git a/pandas/core/arrays/arrow/extension_types.py b/pandas/core/arrays/arrow/extension_types.py new file mode 100644 index 0000000000000..a2b3c6d4da080 --- /dev/null +++ b/pandas/core/arrays/arrow/extension_types.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import json +import warnings + +import pyarrow + +from pandas._typing import IntervalInclusiveType +from pandas.util._decorators import deprecate_kwarg +from pandas.util._exceptions import find_stack_level + +from 
pandas.core.arrays.interval import VALID_INCLUSIVE + + +class ArrowPeriodType(pyarrow.ExtensionType): + def __init__(self, freq) -> None: + # attributes need to be set first before calling + # super init (as that calls serialize) + self._freq = freq + pyarrow.ExtensionType.__init__(self, pyarrow.int64(), "pandas.period") + + @property + def freq(self): + return self._freq + + def __arrow_ext_serialize__(self) -> bytes: + metadata = {"freq": self.freq} + return json.dumps(metadata).encode() + + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized) -> ArrowPeriodType: + metadata = json.loads(serialized.decode()) + return ArrowPeriodType(metadata["freq"]) + + def __eq__(self, other): + if isinstance(other, pyarrow.BaseExtensionType): + return type(self) == type(other) and self.freq == other.freq + else: + return NotImplemented + + def __hash__(self) -> int: + return hash((str(self), self.freq)) + + def to_pandas_dtype(self): + import pandas as pd + + return pd.PeriodDtype(freq=self.freq) + + +# register the type with a dummy instance +_period_type = ArrowPeriodType("D") +pyarrow.register_extension_type(_period_type) + + +class ArrowIntervalType(pyarrow.ExtensionType): + @deprecate_kwarg(old_arg_name="closed", new_arg_name="inclusive") + def __init__(self, subtype, inclusive: IntervalInclusiveType) -> None: + # attributes need to be set first before calling + # super init (as that calls serialize) + assert inclusive in VALID_INCLUSIVE + self._inclusive: IntervalInclusiveType = inclusive + if not isinstance(subtype, pyarrow.DataType): + subtype = pyarrow.type_for_alias(str(subtype)) + self._subtype = subtype + + storage_type = pyarrow.struct([("left", subtype), ("right", subtype)]) + pyarrow.ExtensionType.__init__(self, storage_type, "pandas.interval") + + @property + def subtype(self): + return self._subtype + + @property + def inclusive(self) -> IntervalInclusiveType: + return self._inclusive + + @property + def closed(self) -> 
IntervalInclusiveType: + warnings.warn( + "Attribute `closed` is deprecated in favor of `inclusive`.", + FutureWarning, + stacklevel=find_stack_level(), + ) + return self._inclusive + + def __arrow_ext_serialize__(self) -> bytes: + metadata = {"subtype": str(self.subtype), "inclusive": self.inclusive} + return json.dumps(metadata).encode() + + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized) -> ArrowIntervalType: + metadata = json.loads(serialized.decode()) + subtype = pyarrow.type_for_alias(metadata["subtype"]) + inclusive = metadata["inclusive"] + return ArrowIntervalType(subtype, inclusive) + + def __eq__(self, other): + if isinstance(other, pyarrow.BaseExtensionType): + return ( + type(self) == type(other) + and self.subtype == other.subtype + and self.inclusive == other.inclusive + ) + else: + return NotImplemented + + def __hash__(self) -> int: + return hash((str(self), str(self.subtype), self.inclusive)) + + def to_pandas_dtype(self): + import pandas as pd + + return pd.IntervalDtype(self.subtype.to_pandas_dtype(), self.inclusive) + + +# register the type with a dummy instance +_interval_type = ArrowIntervalType(pyarrow.int64(), "left") +pyarrow.register_extension_type(_interval_type) diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index 6469dccf6e2d5..3420a3ad5ca43 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -1555,7 +1555,7 @@ def __arrow_array__(self, type=None): """ import pyarrow - from pandas.core.arrays.arrow._arrow_utils import ArrowIntervalType + from pandas.core.arrays.arrow.extension_types import ArrowIntervalType try: subtype = pyarrow.from_numpy_dtype(self.dtype.subtype) diff --git a/pandas/core/arrays/period.py b/pandas/core/arrays/period.py index 2d676f94c6a64..c0ac748c24067 100644 --- a/pandas/core/arrays/period.py +++ b/pandas/core/arrays/period.py @@ -375,7 +375,7 @@ def __arrow_array__(self, type=None): """ import pyarrow - from 
pandas.core.arrays.arrow._arrow_utils import ArrowPeriodType + from pandas.core.arrays.arrow.extension_types import ArrowPeriodType if type is not None: if pyarrow.types.is_integer(type): diff --git a/pandas/io/parquet.py b/pandas/io/parquet.py index ed0e0a99ec43b..e335ab4470a50 100644 --- a/pandas/io/parquet.py +++ b/pandas/io/parquet.py @@ -151,7 +151,7 @@ def __init__(self) -> None: import pyarrow.parquet # import utils to register the pyarrow extension types - import pandas.core.arrays.arrow._arrow_utils # pyright: ignore # noqa:F401 + import pandas.core.arrays.arrow.extension_types # pyright: ignore # noqa:F401 self.api = pyarrow diff --git a/pandas/tests/arrays/interval/test_interval.py b/pandas/tests/arrays/interval/test_interval.py index 073e6b6119b14..48f5c676b66e6 100644 --- a/pandas/tests/arrays/interval/test_interval.py +++ b/pandas/tests/arrays/interval/test_interval.py @@ -248,7 +248,7 @@ def test_min_max(self, left_right_dtypes, index_or_series_or_array): def test_arrow_extension_type(): import pyarrow as pa - from pandas.core.arrays.arrow._arrow_utils import ArrowIntervalType + from pandas.core.arrays.arrow.extension_types import ArrowIntervalType p1 = ArrowIntervalType(pa.int64(), "left") p2 = ArrowIntervalType(pa.int64(), "left") @@ -265,7 +265,7 @@ def test_arrow_extension_type(): def test_arrow_array(): import pyarrow as pa - from pandas.core.arrays.arrow._arrow_utils import ArrowIntervalType + from pandas.core.arrays.arrow.extension_types import ArrowIntervalType intervals = pd.interval_range(1, 5, freq=1).array @@ -295,7 +295,7 @@ def test_arrow_array(): def test_arrow_array_missing(): import pyarrow as pa - from pandas.core.arrays.arrow._arrow_utils import ArrowIntervalType + from pandas.core.arrays.arrow.extension_types import ArrowIntervalType arr = IntervalArray.from_breaks([0.0, 1.0, 2.0, 3.0]) arr[1] = None @@ -330,7 +330,7 @@ def test_arrow_array_missing(): def test_arrow_table_roundtrip(breaks): import pyarrow as pa - from 
pandas.core.arrays.arrow._arrow_utils import ArrowIntervalType + from pandas.core.arrays.arrow.extension_types import ArrowIntervalType arr = IntervalArray.from_breaks(breaks) arr[1] = None @@ -431,7 +431,7 @@ def test_arrow_interval_type_error_and_warning(): # GH 40245 import pyarrow as pa - from pandas.core.arrays.arrow._arrow_utils import ArrowIntervalType + from pandas.core.arrays.arrow.extension_types import ArrowIntervalType msg = "Can only specify 'closed' or 'inclusive', not both." with pytest.raises(TypeError, match=msg): diff --git a/pandas/tests/arrays/period/test_arrow_compat.py b/pandas/tests/arrays/period/test_arrow_compat.py index 7d2d2daed3497..03fd146572405 100644 --- a/pandas/tests/arrays/period/test_arrow_compat.py +++ b/pandas/tests/arrays/period/test_arrow_compat.py @@ -13,7 +13,7 @@ def test_arrow_extension_type(): - from pandas.core.arrays.arrow._arrow_utils import ArrowPeriodType + from pandas.core.arrays.arrow.extension_types import ArrowPeriodType p1 = ArrowPeriodType("D") p2 = ArrowPeriodType("D") @@ -34,7 +34,7 @@ def test_arrow_extension_type(): ], ) def test_arrow_array(data, freq): - from pandas.core.arrays.arrow._arrow_utils import ArrowPeriodType + from pandas.core.arrays.arrow.extension_types import ArrowPeriodType periods = period_array(data, freq=freq) result = pa.array(periods) @@ -57,7 +57,7 @@ def test_arrow_array(data, freq): def test_arrow_array_missing(): - from pandas.core.arrays.arrow._arrow_utils import ArrowPeriodType + from pandas.core.arrays.arrow.extension_types import ArrowPeriodType arr = PeriodArray([1, 2, 3], freq="D") arr[1] = pd.NaT @@ -70,7 +70,7 @@ def test_arrow_array_missing(): def test_arrow_table_roundtrip(): - from pandas.core.arrays.arrow._arrow_utils import ArrowPeriodType + from pandas.core.arrays.arrow.extension_types import ArrowPeriodType arr = PeriodArray([1, 2, 3], freq="D") arr[1] = pd.NaT @@ -91,7 +91,7 @@ def test_arrow_table_roundtrip(): def test_arrow_load_from_zero_chunks(): # GH-41040 - 
from pandas.core.arrays.arrow._arrow_utils import ArrowPeriodType + from pandas.core.arrays.arrow.extension_types import ArrowPeriodType arr = PeriodArray([], freq="D") df = pd.DataFrame({"a": arr}) From daf56c68e6ae66297820c88194c6d0fe744554a3 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Fri, 22 Jul 2022 10:16:08 -0700 Subject: [PATCH 2/5] ensure string[pyarrow] dispatches to StringDtype for now --- pandas/core/arrays/arrow/dtype.py | 3 +++ pandas/tests/api/test_api.py | 1 + 2 files changed, 4 insertions(+) diff --git a/pandas/core/arrays/arrow/dtype.py b/pandas/core/arrays/arrow/dtype.py index 4a32663a68ed2..197494b6c039a 100644 --- a/pandas/core/arrays/arrow/dtype.py +++ b/pandas/core/arrays/arrow/dtype.py @@ -93,6 +93,9 @@ def construct_from_string(cls, string: str) -> ArrowDtype: ) if not string.endswith("[pyarrow]"): raise TypeError(f"'{string}' must end with '[pyarrow]'") + if string == "string[pyarrow]": + # Ensure Registry.find skips ArrowDtype to use StringDtype instead + raise TypeError("string[pyarrow] should be constructed by StringDtype") base_type = string.split("[pyarrow]")[0] try: pa_dtype = pa.type_for_alias(base_type) diff --git a/pandas/tests/api/test_api.py b/pandas/tests/api/test_api.py index 6350f402ac0e5..c64403b8691d5 100644 --- a/pandas/tests/api/test_api.py +++ b/pandas/tests/api/test_api.py @@ -53,6 +53,7 @@ class TestPDApi(Base): # top-level classes classes = [ + "ArrowDtype", "Categorical", "CategoricalIndex", "DataFrame", From 8cec855c18f91c88004e6e5e885d8dc9857a02b4 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Fri, 22 Jul 2022 11:12:08 -0700 Subject: [PATCH 3/5] type ignores --- pandas/core/arrays/string_arrow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index bb2fefabd6ae5..67a4d7ad6c286 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -111,7 +111,7 @@ class 
ArrowStringArray(ArrowExtensionArray, BaseStringArray, ObjectStringArrayMi def __init__(self, values) -> None: super().__init__(values) # TODO: Migrate to ArrowDtype instead - self._dtype = StringDtype(storage="pyarrow") + self._dtype = StringDtype(storage="pyarrow") # type: ignore[assignment] if not pa.types.is_string(self._data.type): raise ValueError( @@ -151,7 +151,7 @@ def dtype(self) -> StringDtype: # type: ignore[override] """ An instance of 'string[pyarrow]'. """ - return self._dtype + return self._dtype # type: ignore[return-value] def __array__(self, dtype: NpDtype | None = None) -> np.ndarray: """Correctly construct numpy arrays when passed to `np.asarray()`.""" From 65be5ca3ab3033abc92b89059ff75a08b3e4ab90 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Fri, 22 Jul 2022 11:16:41 -0700 Subject: [PATCH 4/5] Address availability of Pyarrow --- pandas/core/arrays/arrow/dtype.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pandas/core/arrays/arrow/dtype.py b/pandas/core/arrays/arrow/dtype.py index 197494b6c039a..523e031c220e4 100644 --- a/pandas/core/arrays/arrow/dtype.py +++ b/pandas/core/arrays/arrow/dtype.py @@ -3,9 +3,9 @@ import re import numpy as np -import pyarrow as pa from pandas._typing import DtypeObj +from pandas.compat import pa_version_under1p01 from pandas.util._decorators import cache_readonly from pandas.core.dtypes.base import ( @@ -13,6 +13,9 @@ register_extension_dtype, ) +if not pa_version_under1p01: + import pyarrow as pa + @register_extension_dtype class ArrowDtype(StorageExtensionDtype): @@ -25,6 +28,8 @@ class ArrowDtype(StorageExtensionDtype): def __init__(self, pyarrow_dtype: pa.DataType) -> None: super().__init__("pyarrow") + if pa_version_under1p01: + raise ImportError("pyarrow>=1.0.1 is required for ArrowDtype") if not isinstance(pyarrow_dtype, pa.DataType): raise ValueError( f"pyarrow_dtype ({pyarrow_dtype}) must be an instance " From 3e5b4b2f57b4ca85b4721343535bc8a8add44820 Mon Sep 17 00:00:00 
2001 From: Matthew Roeschke Date: Mon, 1 Aug 2022 15:15:40 -0700 Subject: [PATCH 5/5] Address typing --- pandas/core/arrays/string_arrow.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index af048a26655af..01f29efbfba11 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -114,8 +114,7 @@ class ArrowStringArray(ArrowExtensionArray, BaseStringArray, ObjectStringArrayMi def __init__(self, values) -> None: super().__init__(values) - # TODO: Migrate to ArrowDtype instead - self._dtype = StringDtype(storage="pyarrow") # type: ignore[assignment] + self._dtype = StringDtype(storage="pyarrow") if not pa.types.is_string(self._data.type): raise ValueError( @@ -155,7 +154,7 @@ def dtype(self) -> StringDtype: # type: ignore[override] """ An instance of 'string[pyarrow]'. """ - return self._dtype # type: ignore[return-value] + return self._dtype def __array__(self, dtype: NpDtype | None = None) -> np.ndarray: """Correctly construct numpy arrays when passed to `np.asarray()`."""