-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
ENH: add and register Arrow extension types for Period and Interval #28371
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 18 commits
e3ab110
6c1300f
5eb8ad6
47c4755
e7e0674
85bf36c
f325ff1
82589dd
64bf38b
70e7023
b09f54d
913f310
76a6f46
6587bd2
5303bae
a97808c
206c609
e9a032d
16523af
1b6f21e
d39b8a3
4156718
92a1ede
e303749
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
from distutils.version import LooseVersion | ||
import json | ||
from operator import le, lt | ||
import textwrap | ||
|
||
|
@@ -39,6 +41,14 @@ | |
import pandas.core.common as com | ||
from pandas.core.indexes.base import ensure_index | ||
|
||
try: | ||
import pyarrow | ||
|
||
_PYARROW_INSTALLED = True | ||
except ImportError: | ||
_PYARROW_INSTALLED = False | ||
|
||
|
||
_VALID_CLOSED = {"left", "right", "both", "neither"} | ||
_interval_shared_docs = {} | ||
|
||
|
@@ -1026,6 +1036,58 @@ def __array__(self, dtype=None): | |
result[i] = Interval(left[i], right[i], closed) | ||
return result | ||
|
||
def __arrow_array__(self, type=None): | ||
""" | ||
Convert myself into a pyarrow Array. | ||
""" | ||
import pyarrow as pa | ||
|
||
try: | ||
subtype = pa.from_numpy_dtype(self.dtype.subtype) | ||
except TypeError: | ||
raise TypeError( | ||
"Conversion to arrow with subtype '{}' " | ||
"is not supported".format(self.dtype.subtype) | ||
) | ||
interval_type = ArrowIntervalType(subtype, self.closed) | ||
storage_array = pa.StructArray.from_arrays( | ||
[ | ||
pa.array(self.left, type=subtype, from_pandas=True), | ||
pa.array(self.right, type=subtype, from_pandas=True), | ||
], | ||
names=["left", "right"], | ||
) | ||
jreback marked this conversation as resolved.
Show resolved
Hide resolved
|
||
mask = self.isna() | ||
if mask.any(): | ||
# if there are missing values, set validity bitmap also on the array level | ||
null_bitmap = pa.array(~mask).buffers()[1] | ||
storage_array = pa.StructArray.from_buffers( | ||
storage_array.type, | ||
len(storage_array), | ||
[null_bitmap], | ||
children=[storage_array.field(0), storage_array.field(1)], | ||
) | ||
|
||
if type is not None: | ||
if type.equals(interval_type.storage_type): | ||
return storage_array | ||
elif isinstance(type, ArrowIntervalType): | ||
# ensure we have the same subtype and closed attributes | ||
if not type.equals(interval_type): | ||
raise TypeError( | ||
"Not supported to convert IntervalArray to type with " | ||
"different 'subtype' ({0} vs {1}) and 'closed' ({2} vs {3}) " | ||
"attributes".format( | ||
self.dtype.subtype, type.subtype, self.closed, type.closed | ||
) | ||
) | ||
else: | ||
raise TypeError( | ||
"Not supported to convert IntervalArray to '{0}' type".format(type) | ||
) | ||
|
||
return pa.ExtensionArray.from_storage(interval_type, storage_array) | ||
|
||
_interval_shared_docs[ | ||
"to_tuples" | ||
] = """ | ||
|
@@ -1217,3 +1279,55 @@ def maybe_convert_platform_interval(values): | |
values = np.asarray(values) | ||
|
||
return maybe_convert_platform(values) | ||
|
||
|
||
if _PYARROW_INSTALLED and LooseVersion(pyarrow.__version__) >= LooseVersion("0.15"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. __PYARROW_INSTALLED needs to incorporate the version check There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I moved the version check into the separate file (and made it a variable), but kept it separate from the import check as different functionalities might need a different pyarrow version |
||
|
||
class ArrowIntervalType(pyarrow.ExtensionType): | ||
def __init__(self, subtype, closed): | ||
# attributes need to be set first before calling | ||
# super init (as that calls serialize) | ||
assert closed in _VALID_CLOSED | ||
self._closed = closed | ||
if not isinstance(subtype, pyarrow.DataType): | ||
subtype = pyarrow.type_for_alias(str(subtype)) | ||
self._subtype = subtype | ||
|
||
storage_type = pyarrow.struct([("left", subtype), ("right", subtype)]) | ||
pyarrow.ExtensionType.__init__(self, storage_type, "pandas.interval") | ||
|
||
@property | ||
def subtype(self): | ||
return self._subtype | ||
|
||
@property | ||
def closed(self): | ||
return self._closed | ||
|
||
def __arrow_ext_serialize__(self): | ||
metadata = {"subtype": str(self.subtype), "closed": self.closed} | ||
return json.dumps(metadata).encode() | ||
|
||
@classmethod | ||
def __arrow_ext_deserialize__(cls, storage_type, serialized): | ||
metadata = json.loads(serialized.decode()) | ||
subtype = pyarrow.type_for_alias(metadata["subtype"]) | ||
closed = metadata["closed"] | ||
return ArrowIntervalType(subtype, closed) | ||
|
||
def __eq__(self, other): | ||
if isinstance(other, pyarrow.BaseExtensionType): | ||
return ( | ||
type(self) == type(other) | ||
and self.subtype == other.subtype | ||
and self.closed == other.closed | ||
) | ||
else: | ||
return NotImplemented | ||
|
||
def __hash__(self): | ||
return hash((str(self), str(self.subtype), self.closed)) | ||
|
||
# register the type with a dummy instance | ||
_interval_type = ArrowIntervalType(pyarrow.int64(), "left") | ||
pyarrow.register_extension_type(_interval_type) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
from datetime import timedelta | ||
from distutils.version import LooseVersion | ||
import json | ||
import operator | ||
from typing import Any, Callable, List, Optional, Sequence, Union | ||
|
||
|
@@ -49,6 +51,13 @@ | |
from pandas.tseries import frequencies | ||
from pandas.tseries.offsets import DateOffset, Tick, _delta_to_tick | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same as above |
||
try: | ||
import pyarrow | ||
|
||
_PYARROW_INSTALLED = True | ||
except ImportError: | ||
_PYARROW_INSTALLED = False | ||
|
||
|
||
def _field_accessor(name, alias, docstring=None): | ||
def f(self): | ||
|
@@ -332,6 +341,31 @@ def __array__(self, dtype=None): | |
# overriding DatetimelikeArray | ||
return np.array(list(self), dtype=object) | ||
|
||
def __arrow_array__(self, type=None): | ||
""" | ||
Convert myself into a pyarrow Array. | ||
""" | ||
import pyarrow as pa | ||
|
||
if type is not None: | ||
if pa.types.is_integer(type): | ||
return pa.array(self._data, mask=self.isna(), type=type) | ||
elif isinstance(type, ArrowPeriodType): | ||
# ensure we have the same freq | ||
if self.freqstr != type.freq: | ||
raise TypeError( | ||
"Not supported to convert PeriodArray to array with different" | ||
" 'freq' ({0} vs {1})".format(self.freqstr, type.freq) | ||
) | ||
else: | ||
raise TypeError( | ||
"Not supported to convert PeriodArray to '{0}' type".format(type) | ||
) | ||
|
||
period_type = ArrowPeriodType(self.freqstr) | ||
storage_array = pa.array(self._data, mask=self.isna(), type="int64") | ||
return pa.ExtensionArray.from_storage(period_type, storage_array) | ||
|
||
# -------------------------------------------------------------------- | ||
# Vectorized analogues of Period properties | ||
|
||
|
@@ -1074,3 +1108,39 @@ def _make_field_arrays(*fields): | |
] | ||
|
||
return arrays | ||
|
||
|
||
if _PYARROW_INSTALLED and LooseVersion(pyarrow.__version__) >= LooseVersion("0.15"): | ||
|
||
class ArrowPeriodType(pyarrow.ExtensionType): | ||
def __init__(self, freq): | ||
# attributes need to be set first before calling | ||
# super init (as that calls serialize) | ||
self._freq = freq | ||
pyarrow.ExtensionType.__init__(self, pyarrow.int64(), "pandas.period") | ||
|
||
@property | ||
def freq(self): | ||
return self._freq | ||
|
||
def __arrow_ext_serialize__(self): | ||
metadata = {"freq": self.freq} | ||
return json.dumps(metadata).encode() | ||
|
||
@classmethod | ||
def __arrow_ext_deserialize__(cls, storage_type, serialized): | ||
metadata = json.loads(serialized.decode()) | ||
return ArrowPeriodType(metadata["freq"]) | ||
|
||
def __eq__(self, other): | ||
jorisvandenbossche marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if isinstance(other, pyarrow.BaseExtensionType): | ||
return type(self) == type(other) and self.freq == other.freq | ||
else: | ||
return NotImplemented | ||
|
||
def __hash__(self): | ||
return hash((str(self), self.freq)) | ||
|
||
# register the type with a dummy instance | ||
_period_type = ArrowPeriodType("D") | ||
pyarrow.register_extension_type(_period_type) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can u make this into a function and put in common location
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I moved this for now into an
_arrow_utils.py
file in thearrays
directory (open for other names), we can then put some common functions in that file as well