Skip to content

API: implement __array_function__ for ExtensionArray #35032

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
7fe754e
wip
simonjayhawkins Jun 26, 2020
d41d136
Merge remote-tracking branch 'upstream/master' into __array_function__
simonjayhawkins Jun 26, 2020
cc57254
Merge remote-tracking branch 'upstream/master' into __array_function__
simonjayhawkins Jun 26, 2020
0c515e4
add tests for np.tile and SparseArray specific implementation
simonjayhawkins Jun 26, 2020
aef535c
add IS_NEP18_ACTIVE and skip_if_no_nep18 test decorator
simonjayhawkins Jun 26, 2020
ed2962d
fix failing tests (tests/arrays and tests/extension)
simonjayhawkins Jun 27, 2020
cd161ad
mypy fixup
simonjayhawkins Jun 27, 2020
34afe27
Merge remote-tracking branch 'upstream/master' into __array_function__
simonjayhawkins Jun 28, 2020
5ca65e1
add np.vstack (dispatches back to NumPy)
simonjayhawkins Jun 28, 2020
f347243
remove np.tile implementation for now to reduce diff
simonjayhawkins Jun 28, 2020
009ac61
and remove _tile_1d
simonjayhawkins Jun 28, 2020
584ccf2
don't dispatch to EA.unique from np.unique
simonjayhawkins Jun 28, 2020
72e4477
lint fixup
simonjayhawkins Jun 28, 2020
807aae8
fix test_fillna_null for IntervalIndex
simonjayhawkins Jun 28, 2020
9621c25
Merge remote-tracking branch 'upstream/master' into __array_function__
simonjayhawkins Jun 29, 2020
ffab382
test from 35038 to get ci to green
simonjayhawkins Jun 29, 2020
0255671
add ArrayFunctionMixin
simonjayhawkins Jun 29, 2020
7b379f4
remove ArrayFunctionMixin unintentionally added to BaseMaskedDtype
simonjayhawkins Jun 29, 2020
96db577
revert changes to pandas/core/arrays/numpy_.py
simonjayhawkins Jun 29, 2020
d8b1a8b
Revert "revert changes to pandas/core/arrays/numpy_.py"
simonjayhawkins Jun 29, 2020
61c53f5
add ArrayFunctionMixin to NDArrayBackedExtensionArray and reinstate P…
simonjayhawkins Jun 29, 2020
39ae5f9
Merge remote-tracking branch 'upstream/master' into __array_function__
simonjayhawkins Jul 1, 2020
0e27746
Merge remote-tracking branch 'upstream/master' into __array_function__
simonjayhawkins Jul 4, 2020
d7f4550
reduce diff
simonjayhawkins Jul 4, 2020
577010f
refactor condition
simonjayhawkins Jul 4, 2020
bc3f4e1
Merge remote-tracking branch 'upstream/master' into __array_function__
simonjayhawkins Jul 17, 2020
d8af2fc
Merge remote-tracking branch 'upstream/master' into __array_function__
simonjayhawkins Sep 13, 2020
3843481
amin, amax to NDArrayBackedExtensionArray
simonjayhawkins Sep 13, 2020
20c7328
Merge remote-tracking branch 'upstream/master' into __array_function__
simonjayhawkins Sep 14, 2020
7b9fab5
concat_compat raises numpy.AxisError with axis=1
simonjayhawkins Sep 14, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions pandas/compat/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,21 @@ def np_array_datetime64_compat(arr, *args, **kwargs):
return np.array(arr, *args, **kwargs)


# taken from dask array
# https://github.com/dask/dask/blob/master/dask/array/utils.py#L352-L363
def _is_nep18_active():
class A:
def __array_function__(self, *args, **kwargs):
return True

try:
return np.concatenate([A()])
except ValueError:
return False


# Result of the NEP 18 probe, evaluated once when this module is imported.
IS_NEP18_ACTIVE = _is_nep18_active()

__all__ = [
"np",
"_np_version",
Expand Down
2 changes: 2 additions & 0 deletions pandas/core/arrays/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pandas.core.arrays.base import (
ArrayFunctionMixin,
ExtensionArray,
ExtensionOpsMixin,
ExtensionScalarOpsMixin,
Expand All @@ -15,6 +16,7 @@
from pandas.core.arrays.timedeltas import TimedeltaArray

__all__ = [
"ArrayFunctionMixin",
"ExtensionArray",
"ExtensionOpsMixin",
"ExtensionScalarOpsMixin",
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/arrays/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from pandas.util._decorators import cache_readonly

from pandas.core.algorithms import take, unique
from pandas.core.arrays.base import ExtensionArray
from pandas.core.arrays.base import ArrayFunctionMixin, ExtensionArray

_T = TypeVar("_T", bound="NDArrayBackedExtensionArray")


class NDArrayBackedExtensionArray(ExtensionArray):
class NDArrayBackedExtensionArray(ExtensionArray, ArrayFunctionMixin):
"""
ExtensionArray that is backed by a single NumPy ndarray.
"""
Expand Down
37 changes: 36 additions & 1 deletion pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import numpy as np

from pandas._libs import lib
from pandas._typing import ArrayLike
from pandas._typing import ArrayLike, F
from pandas.compat import set_function_name
from pandas.compat.numpy import function as nv
from pandas.errors import AbstractMethodError
Expand All @@ -32,6 +32,41 @@

_extension_array_shared_docs: Dict[str, str] = dict()

# Registry mapping a NumPy function to the implementation that
# ArrayFunctionMixin.__array_function__ calls for it; filled in by the
# ``implements`` decorator.
_HANDLED_FUNCTIONS: Dict[Callable, Callable] = {}


def implements(numpy_function) -> Callable[[F], F]:
    """
    Build a decorator that registers an ``__array_function__`` implementation
    for ExtensionArray objects.

    Parameters
    ----------
    numpy_function : callable
        The NumPy function used as the key in ``_HANDLED_FUNCTIONS``.

    Returns
    -------
    callable
        A decorator that stores the decorated function in
        ``_HANDLED_FUNCTIONS`` under ``numpy_function`` and hands the
        function back unchanged.
    """

    def _register(implementation):
        # Later registrations for the same NumPy function overwrite earlier ones.
        _HANDLED_FUNCTIONS[numpy_function] = implementation
        return implementation

    return _register


class ArrayFunctionMixin:
    """
    Mixin supplying ``__array_function__`` so that NumPy functions called on
    an array can be routed to pandas implementations.
    """

    def __array_function__(self, func, types, args, kwargs):
        """
        Resolve and call an implementation for the NumPy function ``func``.

        Resolution order:

        1. the implementation registered for ``func`` in
           ``_HANDLED_FUNCTIONS``;
        2. a callable attribute of ``type(self)`` named ``func.__name__``,
           unless that attribute's ``__name__`` is excluded (``"unique"``);
        3. ``func.__wrapped__``.

        Whichever is chosen is called with ``*args, **kwargs``. ``types`` is
        not consulted.
        """
        # Explicit registrations (including aliases) always win.
        if func in _HANDLED_FUNCTIONS:
            return _HANDLED_FUNCTIONS[func](*args, **kwargs)

        # Otherwise try a same-named attribute. It is looked up on the class,
        # not the instance, so it receives exactly the arguments NumPy got.
        skipped_names = {"unique"}
        candidate = getattr(type(self), func.__name__, None)
        usable = (
            candidate is not None
            and callable(candidate)
            and candidate.__name__ not in skipped_names
        )
        if usable:
            return candidate(*args, **kwargs)

        # Either an alias or a function that is not part of the ExtensionArray
        # API: fall back to the function wrapped by ``func``.
        return func.__wrapped__(*args, **kwargs)


class ExtensionArray:
"""
Expand Down
4 changes: 4 additions & 0 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2118,6 +2118,8 @@ def min(self, skipna=True, **kwargs):
pointer = self._codes.min()
return self.categories[pointer]

amin = min

@deprecate_kwarg(old_arg_name="numeric_only", new_arg_name="skipna")
def max(self, skipna=True, **kwargs):
"""
Expand Down Expand Up @@ -2154,6 +2156,8 @@ def max(self, skipna=True, **kwargs):
pointer = self._codes.max()
return self.categories[pointer]

amax = max

def mode(self, dropna=True):
"""
Returns the mode(s) of the Categorical.
Expand Down
4 changes: 4 additions & 0 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -1579,6 +1579,8 @@ def min(self, axis=None, skipna=True, *args, **kwargs):
return NaT
return self._box_func(result)

amin = min

def max(self, axis=None, skipna=True, *args, **kwargs):
"""
Return the maximum value of the Array or maximum along
Expand Down Expand Up @@ -1611,6 +1613,8 @@ def max(self, axis=None, skipna=True, *args, **kwargs):
# Don't have to worry about NA `result`, since no NA went in.
return self._box_func(result)

amax = max

def mean(self, skipna=True):
"""
Return the mean value of the Array.
Expand Down
8 changes: 6 additions & 2 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@
from pandas.core.dtypes.missing import isna, notna

from pandas.core.algorithms import take, value_counts
from pandas.core.arrays.base import ExtensionArray, _extension_array_shared_docs
from pandas.core.arrays.base import (
ArrayFunctionMixin,
ExtensionArray,
_extension_array_shared_docs,
)
from pandas.core.arrays.categorical import Categorical
import pandas.core.common as com
from pandas.core.construction import array
Expand Down Expand Up @@ -143,7 +147,7 @@
),
)
)
class IntervalArray(IntervalMixin, ExtensionArray):
class IntervalArray(IntervalMixin, ExtensionArray, ArrayFunctionMixin):
ndim = 1
can_hold_na = True
_na_value = _fill_value = np.nan
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/arrays/masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from pandas.core import nanops
from pandas.core.algorithms import _factorize_array, take
from pandas.core.array_algos import masked_reductions
from pandas.core.arrays import ExtensionArray, ExtensionOpsMixin
from pandas.core.arrays import ArrayFunctionMixin, ExtensionArray, ExtensionOpsMixin
from pandas.core.indexers import check_array_indexer

if TYPE_CHECKING:
Expand Down Expand Up @@ -52,7 +52,7 @@ def construct_array_type(cls) -> Type["BaseMaskedArray"]:
raise NotImplementedError


class BaseMaskedArray(ExtensionArray, ExtensionOpsMixin):
class BaseMaskedArray(ExtensionArray, ExtensionOpsMixin, ArrayFunctionMixin):
"""
Base class for masked arrays (which use _data and _mask to store the data).

Expand Down
5 changes: 5 additions & 0 deletions pandas/core/arrays/numpy_.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def _from_factorized(cls, values, original) -> "PandasArray":

@classmethod
def _concat_same_type(cls, to_concat) -> "PandasArray":
    """
    Concatenate several arrays into a single new instance of ``cls``.

    Each element of ``to_concat`` is converted with ``to_numpy()`` before
    the results are passed to ``np.concatenate``.
    """
    ndarrays = [item.to_numpy() for item in to_concat]
    combined = np.concatenate(ndarrays)
    return cls(combined)

def _from_backing_data(self, arr: np.ndarray) -> "PandasArray":
Expand Down Expand Up @@ -347,13 +348,17 @@ def min(self, skipna: bool = True, **kwargs) -> Scalar:
)
return result

amin = min

def max(self, skipna: bool = True, **kwargs) -> Scalar:
nv.validate_max((), kwargs)
result = masked_reductions.max(
values=self.to_numpy(), mask=self.isna(), skipna=skipna
)
return result

amax = max

def sum(self, axis=None, skipna=True, min_count=0, **kwargs) -> Scalar:
nv.validate_sum((), kwargs)
return nanops.nansum(
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/arrays/sparse/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from pandas.core.dtypes.missing import isna, na_value_for_dtype, notna

import pandas.core.algorithms as algos
from pandas.core.arrays import ExtensionArray, ExtensionOpsMixin
from pandas.core.arrays import ArrayFunctionMixin, ExtensionArray, ExtensionOpsMixin
from pandas.core.arrays.sparse.dtype import SparseDtype
from pandas.core.base import PandasObject
import pandas.core.common as com
Expand Down Expand Up @@ -195,7 +195,7 @@ def _wrap_result(name, data, sparse_index, fill_value, dtype=None):
)


class SparseArray(PandasObject, ExtensionArray, ExtensionOpsMixin):
class SparseArray(PandasObject, ExtensionArray, ExtensionOpsMixin, ArrayFunctionMixin):
"""
An ExtensionArray for storing sparse data.

Expand Down
6 changes: 6 additions & 0 deletions pandas/core/dtypes/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from pandas.core.dtypes.generic import ABCCategoricalIndex, ABCRangeIndex, ABCSeries

from pandas.core.arrays import ExtensionArray
from pandas.core.arrays.base import implements
from pandas.core.arrays.sparse import SparseArray
from pandas.core.construction import array

Expand Down Expand Up @@ -107,6 +108,7 @@ def _cast_to_common_type(arr: ArrayLike, dtype: DtypeObj) -> ArrayLike:
return arr.astype(dtype, copy=False)


@implements(np.concatenate)
def concat_compat(to_concat, axis: int = 0):
"""
provide concatenation of an array of arrays each of which is a single
Expand Down Expand Up @@ -156,6 +158,10 @@ def is_nonempty(x) -> bool:
cls = type(to_concat[0])
return cls._concat_same_type(to_concat)
else:
to_concat = [
arr.to_numpy() if isinstance(arr, ExtensionArray) else arr
for arr in to_concat
]
return np.concatenate(to_concat, axis=axis)

elif _contains_datetime or "timedelta" in typs:
Expand Down
1 change: 1 addition & 0 deletions pandas/tests/extension/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class TestMyDtype(BaseDtypeTests):
from .io import BaseParsingTests # noqa
from .methods import BaseMethodsTests # noqa
from .missing import BaseMissingTests # noqa
from .numpy_array_functions import BaseNumpyArrayFunctionTests # noqa
from .ops import ( # noqa
BaseArithmeticOpsTests,
BaseComparisonOpsTests,
Expand Down
4 changes: 4 additions & 0 deletions pandas/tests/extension/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,7 @@ def assert_frame_equal(cls, left, right, *args, **kwargs):
@classmethod
def assert_extension_array_equal(cls, left, right, *args, **kwargs):
    """
    Thin wrapper around ``tm.assert_extension_array_equal``.

    All positional and keyword arguments are forwarded unchanged and the
    helper's return value is passed back to the caller.
    """
    outcome = tm.assert_extension_array_equal(left, right, *args, **kwargs)
    return outcome

@classmethod
def assert_numpy_array_equal(cls, left, right, *args, **kwargs):
    """
    Thin wrapper around ``tm.assert_numpy_array_equal``.

    All positional and keyword arguments are forwarded unchanged and the
    helper's return value is passed back to the caller.
    """
    outcome = tm.assert_numpy_array_equal(left, right, *args, **kwargs)
    return outcome
32 changes: 32 additions & 0 deletions pandas/tests/extension/base/numpy_array_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import numpy as np
import pytest

import pandas.util._test_decorators as td

import pandas as pd

from .base import BaseExtensionTests


@td.skip_if_no_nep18
class BaseNumpyArrayFunctionTests(BaseExtensionTests):
    """
    Tests that call NumPy functions directly on the ``data`` fixture.

    The whole class carries the ``td.skip_if_no_nep18`` marker.
    """

    def test_concatenate(self, data):
        # Omitting axis and passing axis=0 must both give the elements of
        # ``data`` repeated three times, with the dtype of ``data``.
        expected = pd.array(list(data) * 3, dtype=data.dtype)
        pieces = [data] * 3

        for axis_kwargs in ({}, {"axis": 0}):
            result = np.concatenate(pieces, **axis_kwargs)
            self.assert_extension_array_equal(result, expected)

        # Concatenating along axis=1 is expected to raise.
        msg = "axis 1 is out of bounds for array of dimension 1"
        with pytest.raises(np.AxisError, match=msg):
            np.concatenate(pieces, axis=1)

    def test_ndim(self, data):
        dimensions = np.ndim(data)
        assert dimensions == 1

    def test_vstack(self, data):
        # The result is compared as a plain ndarray with two identical rows.
        row = data.to_numpy()
        expected = np.array([row, row])
        result = np.vstack([data, data])
        self.assert_numpy_array_equal(result, expected)
4 changes: 4 additions & 0 deletions pandas/tests/extension/test_boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,3 +387,7 @@ class TestUnaryOps(base.BaseUnaryOpsTests):

class TestParsing(base.BaseParsingTests):
pass


class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests):
    # No overrides: every test is inherited from base.BaseNumpyArrayFunctionTests.
    pass
4 changes: 4 additions & 0 deletions pandas/tests/extension/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,7 @@ def test_not_equal_with_na(self, categories):

class TestParsing(base.BaseParsingTests):
pass


class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests):
    # No overrides: every test is inherited from base.BaseNumpyArrayFunctionTests.
    pass
4 changes: 4 additions & 0 deletions pandas/tests/extension/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,7 @@ class TestGroupby(BaseDatetimeTests, base.BaseGroupbyTests):

class TestPrinting(BaseDatetimeTests, base.BasePrintingTests):
pass


class TestNumpyArrayFunctions(BaseDatetimeTests, base.BaseNumpyArrayFunctionTests):
    # Mixes in BaseDatetimeTests first, like the other test classes in this
    # module; every test is inherited from base.BaseNumpyArrayFunctionTests.
    pass
4 changes: 4 additions & 0 deletions pandas/tests/extension/test_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,7 @@ class TestPrinting(base.BasePrintingTests):

class TestParsing(base.BaseParsingTests):
pass


class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests):
    # No overrides: every test is inherited from base.BaseNumpyArrayFunctionTests.
    pass
4 changes: 4 additions & 0 deletions pandas/tests/extension/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,7 @@ def test_EA_types(self, engine, data):
expected_msg = r".*must implement _from_sequence_of_strings.*"
with pytest.raises(NotImplementedError, match=expected_msg):
super().test_EA_types(engine, data)


class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests):
    # No overrides: every test is inherited from base.BaseNumpyArrayFunctionTests.
    pass
5 changes: 5 additions & 0 deletions pandas/tests/extension/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,3 +502,8 @@ def test_setitem_loc_iloc_slice(self, data):
@skip_nested
class TestParsing(BaseNumPyTests, base.BaseParsingTests):
pass


@skip_nested
class TestNumpyArrayFunctions(BaseNumPyTests, base.BaseNumpyArrayFunctionTests):
    # Mixes in BaseNumPyTests first, like the other test classes in this
    # module; every test is inherited from base.BaseNumpyArrayFunctionTests.
    pass
4 changes: 4 additions & 0 deletions pandas/tests/extension/test_period.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,7 @@ class TestParsing(BasePeriodTests, base.BaseParsingTests):
@pytest.mark.parametrize("engine", ["c", "python"])
def test_EA_types(self, engine, data):
super().test_EA_types(engine, data)


class TestNumpyArrayFunctions(BasePeriodTests, base.BaseNumpyArrayFunctionTests):
    # Mixes in BasePeriodTests first, like the other test classes in this
    # module; every test is inherited from base.BaseNumpyArrayFunctionTests.
    pass
4 changes: 4 additions & 0 deletions pandas/tests/extension/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,7 @@ def test_EA_types(self, engine, data):
expected_msg = r".*must implement _from_sequence_of_strings.*"
with pytest.raises(NotImplementedError, match=expected_msg):
super().test_EA_types(engine, data)


class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests):
    # No overrides: every test is inherited from base.BaseNumpyArrayFunctionTests.
    pass
4 changes: 4 additions & 0 deletions pandas/tests/extension/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,7 @@ class TestPrinting(base.BasePrintingTests):

class TestGroupBy(base.BaseGroupbyTests):
pass


class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests):
    # No overrides: every test is inherited from base.BaseNumpyArrayFunctionTests.
    pass
5 changes: 4 additions & 1 deletion pandas/util/_test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_foo():

from pandas.compat import is_platform_32bit, is_platform_windows
from pandas.compat._optional import import_optional_dependency
from pandas.compat.numpy import _np_version
from pandas.compat.numpy import IS_NEP18_ACTIVE, _np_version

from pandas.core.computation.expressions import _NUMEXPR_INSTALLED, _USE_NUMEXPR

Expand Down Expand Up @@ -194,6 +194,9 @@ def skip_if_no(package: str, min_version: Optional[str] = None) -> Callable:
not _USE_NUMEXPR,
reason=f"numexpr enabled->{_USE_NUMEXPR}, installed->{_NUMEXPR_INSTALLED}",
)
# Skip marker for tests that need NumPy's ``__array_function__`` protocol
# (NEP 18); the skip applies when IS_NEP18_ACTIVE is falsy.
skip_if_no_nep18 = pytest.mark.skipif(
    not IS_NEP18_ACTIVE, reason="NEP-18 support is not available in NumPy"
)


def skip_if_np_lt(
Expand Down