Skip to content

ENH: Add ArrowDype and .array.ArrowExtensionArray to top level #47818

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Aug 16, 2022
2 changes: 2 additions & 0 deletions pandas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@

from pandas.core.api import (
# dtype
ArrowDtype,
Int8Dtype,
Int16Dtype,
Int32Dtype,
Expand Down Expand Up @@ -308,6 +309,7 @@ def __getattr__(name):
# Pandas is not (yet) a py.typed library: the public API is determined
# based on the documentation.
__all__ = [
"ArrowDtype",
"BooleanDtype",
"Categorical",
"CategoricalDtype",
Expand Down
2 changes: 2 additions & 0 deletions pandas/core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
value_counts,
)
from pandas.core.arrays import Categorical
from pandas.core.arrays.arrow import ArrowDtype
from pandas.core.arrays.boolean import BooleanDtype
from pandas.core.arrays.floating import (
Float32Dtype,
Expand Down Expand Up @@ -85,6 +86,7 @@

__all__ = [
"array",
"ArrowDtype",
"bdate_range",
"BooleanDtype",
"Categorical",
Expand Down
2 changes: 2 additions & 0 deletions pandas/core/arrays/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pandas.core.arrays.arrow import ArrowExtensionArray
from pandas.core.arrays.base import (
ExtensionArray,
ExtensionOpsMixin,
Expand All @@ -21,6 +22,7 @@
from pandas.core.arrays.timedeltas import TimedeltaArray

__all__ = [
"ArrowExtensionArray",
"ExtensionArray",
"ExtensionOpsMixin",
"ExtensionScalarOpsMixin",
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/arrays/arrow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pandas.core.arrays.arrow.array import ArrowExtensionArray
from pandas.core.arrays.arrow.dtype import ArrowDtype

__all__ = ["ArrowExtensionArray"]
__all__ = ["ArrowDtype", "ArrowExtensionArray"]
111 changes: 0 additions & 111 deletions pandas/core/arrays/arrow/_arrow_utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
from __future__ import annotations

import json
import warnings

import numpy as np
import pyarrow

from pandas._typing import IntervalInclusiveType
from pandas.errors import PerformanceWarning
from pandas.util._decorators import deprecate_kwarg
from pandas.util._exceptions import find_stack_level

from pandas.core.arrays.interval import VALID_INCLUSIVE


def fallback_performancewarning(version: str | None = None) -> None:
"""
Expand Down Expand Up @@ -64,109 +59,3 @@ def pyarrow_array_to_numpy_and_mask(
else:
mask = np.ones(len(arr), dtype=bool)
return data, mask


class ArrowPeriodType(pyarrow.ExtensionType):
def __init__(self, freq) -> None:
# attributes need to be set first before calling
# super init (as that calls serialize)
self._freq = freq
pyarrow.ExtensionType.__init__(self, pyarrow.int64(), "pandas.period")

@property
def freq(self):
return self._freq

def __arrow_ext_serialize__(self) -> bytes:
metadata = {"freq": self.freq}
return json.dumps(metadata).encode()

@classmethod
def __arrow_ext_deserialize__(cls, storage_type, serialized) -> ArrowPeriodType:
metadata = json.loads(serialized.decode())
return ArrowPeriodType(metadata["freq"])

def __eq__(self, other):
if isinstance(other, pyarrow.BaseExtensionType):
return type(self) == type(other) and self.freq == other.freq
else:
return NotImplemented

def __hash__(self) -> int:
return hash((str(self), self.freq))

def to_pandas_dtype(self):
import pandas as pd

return pd.PeriodDtype(freq=self.freq)


# register the type with a dummy instance
_period_type = ArrowPeriodType("D")
pyarrow.register_extension_type(_period_type)


class ArrowIntervalType(pyarrow.ExtensionType):
@deprecate_kwarg(old_arg_name="closed", new_arg_name="inclusive")
def __init__(self, subtype, inclusive: IntervalInclusiveType) -> None:
# attributes need to be set first before calling
# super init (as that calls serialize)
assert inclusive in VALID_INCLUSIVE
self._inclusive: IntervalInclusiveType = inclusive
if not isinstance(subtype, pyarrow.DataType):
subtype = pyarrow.type_for_alias(str(subtype))
self._subtype = subtype

storage_type = pyarrow.struct([("left", subtype), ("right", subtype)])
pyarrow.ExtensionType.__init__(self, storage_type, "pandas.interval")

@property
def subtype(self):
return self._subtype

@property
def inclusive(self) -> IntervalInclusiveType:
return self._inclusive

@property
def closed(self) -> IntervalInclusiveType:
warnings.warn(
"Attribute `closed` is deprecated in favor of `inclusive`.",
FutureWarning,
stacklevel=find_stack_level(),
)
return self._inclusive

def __arrow_ext_serialize__(self) -> bytes:
metadata = {"subtype": str(self.subtype), "inclusive": self.inclusive}
return json.dumps(metadata).encode()

@classmethod
def __arrow_ext_deserialize__(cls, storage_type, serialized) -> ArrowIntervalType:
metadata = json.loads(serialized.decode())
subtype = pyarrow.type_for_alias(metadata["subtype"])
inclusive = metadata["inclusive"]
return ArrowIntervalType(subtype, inclusive)

def __eq__(self, other):
if isinstance(other, pyarrow.BaseExtensionType):
return (
type(self) == type(other)
and self.subtype == other.subtype
and self.inclusive == other.inclusive
)
else:
return NotImplemented

def __hash__(self) -> int:
return hash((str(self), str(self.subtype), self.inclusive))

def to_pandas_dtype(self):
import pandas as pd

return pd.IntervalDtype(self.subtype.to_pandas_dtype(), self.inclusive)


# register the type with a dummy instance
_interval_type = ArrowIntervalType(pyarrow.int64(), "left")
pyarrow.register_extension_type(_interval_type)
10 changes: 9 additions & 1 deletion pandas/core/arrays/arrow/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,19 @@
import re

import numpy as np
import pyarrow as pa

from pandas._typing import DtypeObj
from pandas.compat import pa_version_under1p01
from pandas.util._decorators import cache_readonly

from pandas.core.dtypes.base import (
StorageExtensionDtype,
register_extension_dtype,
)

if not pa_version_under1p01:
import pyarrow as pa


@register_extension_dtype
class ArrowDtype(StorageExtensionDtype):
Expand All @@ -25,6 +28,8 @@ class ArrowDtype(StorageExtensionDtype):

def __init__(self, pyarrow_dtype: pa.DataType) -> None:
super().__init__("pyarrow")
if pa_version_under1p01:
raise ImportError("pyarrow>=1.0.1 is required for ArrowDtype")
if not isinstance(pyarrow_dtype, pa.DataType):
raise ValueError(
f"pyarrow_dtype ({pyarrow_dtype}) must be an instance "
Expand Down Expand Up @@ -93,6 +98,9 @@ def construct_from_string(cls, string: str) -> ArrowDtype:
)
if not string.endswith("[pyarrow]"):
raise TypeError(f"'{string}' must end with '[pyarrow]'")
if string == "string[pyarrow]":
# Ensure Registry.find skips ArrowDtype to use StringDtype instead
raise TypeError("string[pyarrow] should be constructed by StringDtype")
base_type = string.split("[pyarrow]")[0]
try:
pa_dtype = pa.type_for_alias(base_type)
Expand Down
118 changes: 118 additions & 0 deletions pandas/core/arrays/arrow/extension_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from __future__ import annotations

import json
import warnings

import pyarrow

from pandas._typing import IntervalInclusiveType
from pandas.util._decorators import deprecate_kwarg
from pandas.util._exceptions import find_stack_level

from pandas.core.arrays.interval import VALID_INCLUSIVE


class ArrowPeriodType(pyarrow.ExtensionType):
def __init__(self, freq) -> None:
# attributes need to be set first before calling
# super init (as that calls serialize)
self._freq = freq
pyarrow.ExtensionType.__init__(self, pyarrow.int64(), "pandas.period")

@property
def freq(self):
return self._freq

def __arrow_ext_serialize__(self) -> bytes:
metadata = {"freq": self.freq}
return json.dumps(metadata).encode()

@classmethod
def __arrow_ext_deserialize__(cls, storage_type, serialized) -> ArrowPeriodType:
metadata = json.loads(serialized.decode())
return ArrowPeriodType(metadata["freq"])

def __eq__(self, other):
if isinstance(other, pyarrow.BaseExtensionType):
return type(self) == type(other) and self.freq == other.freq
else:
return NotImplemented

def __hash__(self) -> int:
return hash((str(self), self.freq))

def to_pandas_dtype(self):
import pandas as pd

return pd.PeriodDtype(freq=self.freq)


# register the type with a dummy instance
_period_type = ArrowPeriodType("D")
pyarrow.register_extension_type(_period_type)


class ArrowIntervalType(pyarrow.ExtensionType):
@deprecate_kwarg(old_arg_name="closed", new_arg_name="inclusive")
def __init__(self, subtype, inclusive: IntervalInclusiveType) -> None:
# attributes need to be set first before calling
# super init (as that calls serialize)
assert inclusive in VALID_INCLUSIVE
self._inclusive: IntervalInclusiveType = inclusive
if not isinstance(subtype, pyarrow.DataType):
subtype = pyarrow.type_for_alias(str(subtype))
self._subtype = subtype

storage_type = pyarrow.struct([("left", subtype), ("right", subtype)])
pyarrow.ExtensionType.__init__(self, storage_type, "pandas.interval")

@property
def subtype(self):
return self._subtype

@property
def inclusive(self) -> IntervalInclusiveType:
return self._inclusive

@property
def closed(self) -> IntervalInclusiveType:
warnings.warn(
"Attribute `closed` is deprecated in favor of `inclusive`.",
FutureWarning,
stacklevel=find_stack_level(),
)
return self._inclusive

def __arrow_ext_serialize__(self) -> bytes:
metadata = {"subtype": str(self.subtype), "inclusive": self.inclusive}
return json.dumps(metadata).encode()

@classmethod
def __arrow_ext_deserialize__(cls, storage_type, serialized) -> ArrowIntervalType:
metadata = json.loads(serialized.decode())
subtype = pyarrow.type_for_alias(metadata["subtype"])
inclusive = metadata["inclusive"]
return ArrowIntervalType(subtype, inclusive)

def __eq__(self, other):
if isinstance(other, pyarrow.BaseExtensionType):
return (
type(self) == type(other)
and self.subtype == other.subtype
and self.inclusive == other.inclusive
)
else:
return NotImplemented

def __hash__(self) -> int:
return hash((str(self), str(self.subtype), self.inclusive))

def to_pandas_dtype(self):
import pandas as pd

return pd.IntervalDtype(self.subtype.to_pandas_dtype(), self.inclusive)


# register the type with a dummy instance
_interval_type = ArrowIntervalType(pyarrow.int64(), "left")
pyarrow.register_extension_type(_interval_type)
2 changes: 1 addition & 1 deletion pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1555,7 +1555,7 @@ def __arrow_array__(self, type=None):
"""
import pyarrow

from pandas.core.arrays.arrow._arrow_utils import ArrowIntervalType
from pandas.core.arrays.arrow.extension_types import ArrowIntervalType

try:
subtype = pyarrow.from_numpy_dtype(self.dtype.subtype)
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/period.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def __arrow_array__(self, type=None):
"""
import pyarrow

from pandas.core.arrays.arrow._arrow_utils import ArrowPeriodType
from pandas.core.arrays.arrow.extension_types import ArrowPeriodType

if type is not None:
if pyarrow.types.is_integer(type):
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class ArrowStringArray(ArrowExtensionArray, BaseStringArray, ObjectStringArrayMi
def __init__(self, values) -> None:
super().__init__(values)
# TODO: Migrate to ArrowDtype instead
self._dtype = StringDtype(storage="pyarrow")
self._dtype = StringDtype(storage="pyarrow") # type: ignore[assignment]

if not pa.types.is_string(self._data.type):
raise ValueError(
Expand Down Expand Up @@ -151,7 +151,7 @@ def dtype(self) -> StringDtype: # type: ignore[override]
"""
An instance of 'string[pyarrow]'.
"""
return self._dtype
return self._dtype # type: ignore[return-value]

def __array__(self, dtype: NpDtype | None = None) -> np.ndarray:
"""Correctly construct numpy arrays when passed to `np.asarray()`."""
Expand Down
2 changes: 1 addition & 1 deletion pandas/io/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def __init__(self) -> None:
import pyarrow.parquet

# import utils to register the pyarrow extension types
import pandas.core.arrays.arrow._arrow_utils # pyright: ignore # noqa:F401
import pandas.core.arrays.arrow.extension_types # pyright: ignore # noqa:F401

self.api = pyarrow

Expand Down
1 change: 1 addition & 0 deletions pandas/tests/api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class TestPDApi(Base):

# top-level classes
classes = [
"ArrowDtype",
"Categorical",
"CategoricalIndex",
"DataFrame",
Expand Down
Loading