From 7fe754ebae5f28afa4f6c629388e21203894d059 Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Fri, 26 Jun 2020 10:19:59 +0100 Subject: [PATCH 01/21] wip --- pandas/core/arrays/base.py | 44 ++++++++++++++++++++++++++++++++++++- pandas/core/reshape/melt.py | 10 +++------ 2 files changed, 46 insertions(+), 8 deletions(-) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 5565b85f8d59a..d0ec98afe60c7 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -12,7 +12,7 @@ import numpy as np from pandas._libs import lib -from pandas._typing import ArrayLike +from pandas._typing import AnyArrayLike, ArrayLike, F from pandas.compat import set_function_name from pandas.compat.numpy import function as nv from pandas.errors import AbstractMethodError @@ -32,6 +32,20 @@ _extension_array_shared_docs: Dict[str, str] = dict() +HANDLED_FUNCTIONS = {} + + +def implements(numpy_function) -> Callable[[F], F]: + """ + Register an __array_function__ implementation for ExtensionArray objects. + """ + + def decorator(func): + HANDLED_FUNCTIONS[numpy_function] = func + return func + + return decorator + class ExtensionArray: """ @@ -165,6 +179,15 @@ class ExtensionArray: # Don't override this. _typ = "extension" + def __array_function__(self, func, types, args, kwargs): + if func not in HANDLED_FUNCTIONS: + return NotImplemented + # Note: this allows subclasses that don't override + # __array_function__ to handle ExtensionArray objects + if not all(issubclass(t, ExtensionArray) for t in types): + return NotImplemented + return HANDLED_FUNCTIONS[func](*args, **kwargs) + # ------------------------------------------------------------------------ # Constructors # ------------------------------------------------------------------------ @@ -863,6 +886,13 @@ def repeat(self, repeats, axis=None): ind = np.arange(len(self)).repeat(repeats) return self.take(ind) + def _tile(self, rep: int) -> "ExtensionArray": + """ + non-public method for compatibility with np.tile + """ + ind = np.tile(np.arange(len(self)), rep) + return self.take(ind) + # ------------------------------------------------------------------------ # Indexing methods # ------------------------------------------------------------------------ @@ -1279,3 +1309,15 @@ def _create_arithmetic_method(cls, op): @classmethod def _create_comparison_method(cls, op): return cls._create_method(op, coerce_to_dtype=False, result_dtype=bool) + + +@implements(np.tile) +def tile(array: ExtensionArray, reps: Union[int, AnyArrayLike]): + try: + tup = tuple(reps) + except TypeError: + tup = (reps,) + d = len(tup) + if d != 1: + raise ValueError("can only tile extension arrays along first axis") + return array._tile(reps) diff --git a/pandas/core/reshape/melt.py b/pandas/core/reshape/melt.py index cd0619738677d..a17079f56c868 100644 --- a/pandas/core/reshape/melt.py +++ b/pandas/core/reshape/melt.py @@ -1,18 +1,17 @@ import re -from typing import TYPE_CHECKING, List, cast +from typing import TYPE_CHECKING, List import numpy as np from pandas.util._decorators import Appender, deprecate_kwarg -from pandas.core.dtypes.common import is_extension_array_dtype, is_list_like +from pandas.core.dtypes.common import is_list_like from pandas.core.dtypes.concat import concat_compat from pandas.core.dtypes.missing import notna from pandas.core.arrays import Categorical import pandas.core.common as com from pandas.core.indexes.api import Index, MultiIndex -from pandas.core.reshape.concat import concat from pandas.core.shared_docs import _shared_docs from pandas.core.tools.numeric import to_numeric @@ -108,10 +107,7 @@ def melt( mdata = {} for col in id_vars: id_data = frame.pop(col) - if is_extension_array_dtype(id_data): - id_data = cast("Series", concat([id_data] * K, ignore_index=True)) - else: - id_data = np.tile(id_data._values, K) + id_data = np.tile(id_data._values, K) mdata[col] = id_data mcolumns = id_vars + var_name + [value_name] From 0c515e4bc5253b99bdf4b022fee075d7d811325d Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Fri, 26 Jun 2020 15:52:34 +0100 Subject: [PATCH 02/21] add tests for np.tile and SparseArray specific implementation --- pandas/core/arrays/base.py | 4 ++++ pandas/core/arrays/sparse/array.py | 15 +++++++++++++++ pandas/tests/extension/base/__init__.py | 1 + .../tests/extension/base/numpy_array_functions.py | 12 ++++++++++++ pandas/tests/extension/test_boolean.py | 4 ++++ pandas/tests/extension/test_categorical.py | 4 ++++ pandas/tests/extension/test_datetime.py | 4 ++++ pandas/tests/extension/test_integer.py | 4 ++++ pandas/tests/extension/test_interval.py | 4 ++++ pandas/tests/extension/test_numpy.py | 7 +++++++ pandas/tests/extension/test_period.py | 4 ++++ pandas/tests/extension/test_sparse.py | 4 ++++ pandas/tests/extension/test_string.py | 4 ++++ 13 files changed, 71 insertions(+) create mode 100644 pandas/tests/extension/base/numpy_array_functions.py diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index d0ec98afe60c7..666df67dfda19 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -886,6 +886,10 @@ def repeat(self, repeats, axis=None): ind = np.arange(len(self)).repeat(repeats) return self.take(ind) + # ------------------------------------------------------------------------ + # Numpy Array Functions + # ------------------------------------------------------------------------ + def _tile(self, rep: int) -> "ExtensionArray": """ non-public method for compatibility with np.tile diff --git a/pandas/core/arrays/sparse/array.py b/pandas/core/arrays/sparse/array.py index 4996a10002c63..a25bdea133d02 100644 --- a/pandas/core/arrays/sparse/array.py +++ b/pandas/core/arrays/sparse/array.py @@ -1379,6 +1379,21 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): def __abs__(self): return np.abs(self) + # ------------------------------------------------------------------------ + # Numpy Array Functions + # ------------------------------------------------------------------------ + + def _tile(self, rep: int) -> "ExtensionArray": + """ + non-public method for compatibility with np.tile + """ + sp_values = np.tile(self.sp_values, rep) + index = self.sp_index.indices + for i in range(1, rep): + index = np.append(index, self.sp_index.indices + (i * len(self))) + sp_index = IntIndex(len(self) * rep, index) + return self._simple_new(sp_values, sp_index, self.dtype) + # ------------------------------------------------------------------------ # Ops # ------------------------------------------------------------------------ diff --git a/pandas/tests/extension/base/__init__.py b/pandas/tests/extension/base/__init__.py index 323cb843b2d74..7171014cf0693 100644 --- a/pandas/tests/extension/base/__init__.py +++ b/pandas/tests/extension/base/__init__.py @@ -64,3 +64,4 @@ class TestMyDtype(BaseDtypeTests): ) from .reshaping import BaseReshapingTests # noqa from .setitem import BaseSetitemTests # noqa +from .numpy_array_functions import BaseNumpyArrayFunctionTests # noqa diff --git a/pandas/tests/extension/base/numpy_array_functions.py b/pandas/tests/extension/base/numpy_array_functions.py new file mode 100644 index 0000000000000..cd135c769e13f --- /dev/null +++ b/pandas/tests/extension/base/numpy_array_functions.py @@ -0,0 +1,12 @@ +import numpy as np + +import pandas as pd + +from .base import BaseExtensionTests + + +class BaseNumpyArrayFunctionTests(BaseExtensionTests): + def test_tile(self, data): + expected = pd.array(list(data) * 3, dtype=data.dtype) + result = np.tile(data, 3) + self.assert_extension_array_equal(result, expected) diff --git a/pandas/tests/extension/test_boolean.py b/pandas/tests/extension/test_boolean.py index 725067951eeef..a5d33ec368ee6 100644 --- a/pandas/tests/extension/test_boolean.py +++ b/pandas/tests/extension/test_boolean.py @@ -370,3 +370,7 @@ class TestUnaryOps(base.BaseUnaryOpsTests): class TestParsing(base.BaseParsingTests): pass + + +class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests): + pass diff --git a/pandas/tests/extension/test_categorical.py b/pandas/tests/extension/test_categorical.py index d1211e477fe3e..4ee40fc104e3e 100644 --- a/pandas/tests/extension/test_categorical.py +++ b/pandas/tests/extension/test_categorical.py @@ -267,3 +267,7 @@ def test_not_equal_with_na(self, categories): class TestParsing(base.BaseParsingTests): pass + + +class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests): + pass diff --git a/pandas/tests/extension/test_datetime.py b/pandas/tests/extension/test_datetime.py index e026809f7e611..413adba040048 100644 --- a/pandas/tests/extension/test_datetime.py +++ b/pandas/tests/extension/test_datetime.py @@ -221,3 +221,7 @@ class TestGroupby(BaseDatetimeTests, base.BaseGroupbyTests): class TestPrinting(BaseDatetimeTests, base.BasePrintingTests): pass + + +class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests): + pass diff --git a/pandas/tests/extension/test_integer.py b/pandas/tests/extension/test_integer.py index 725533765ca2c..1c67b573c8fa4 100644 --- a/pandas/tests/extension/test_integer.py +++ b/pandas/tests/extension/test_integer.py @@ -255,3 +255,7 @@ class TestPrinting(base.BasePrintingTests): class TestParsing(base.BaseParsingTests): pass + + +class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests): + pass diff --git a/pandas/tests/extension/test_interval.py b/pandas/tests/extension/test_interval.py index 2411f6cfbd936..db63177d5b56f 100644 --- a/pandas/tests/extension/test_interval.py +++ b/pandas/tests/extension/test_interval.py @@ -164,3 +164,7 @@ def test_EA_types(self, engine, data): expected_msg = r".*must implement _from_sequence_of_strings.*" with pytest.raises(NotImplementedError, match=expected_msg): super().test_EA_types(engine, data) + + +class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests): + pass diff --git a/pandas/tests/extension/test_numpy.py b/pandas/tests/extension/test_numpy.py index 78000c0252375..3c230742a1eae 100644 --- a/pandas/tests/extension/test_numpy.py +++ b/pandas/tests/extension/test_numpy.py @@ -502,3 +502,10 @@ def test_setitem_loc_iloc_slice(self, data): @skip_nested class TestParsing(BaseNumPyTests, base.BaseParsingTests): pass + + +class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests): + def test_tile(self, data, dtype): + # TODO: maybe rethink the validation + # raises ValueError: PandasArray must be 1-dimensional. + pass diff --git a/pandas/tests/extension/test_period.py b/pandas/tests/extension/test_period.py index b1eb276bfc227..ae98c2679aa52 100644 --- a/pandas/tests/extension/test_period.py +++ b/pandas/tests/extension/test_period.py @@ -168,3 +168,7 @@ class TestParsing(BasePeriodTests, base.BaseParsingTests): @pytest.mark.parametrize("engine", ["c", "python"]) def test_EA_types(self, engine, data): super().test_EA_types(engine, data) + + +class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests): + pass diff --git a/pandas/tests/extension/test_sparse.py b/pandas/tests/extension/test_sparse.py index f318934ef5e52..2d1bbdafb0ddd 100644 --- a/pandas/tests/extension/test_sparse.py +++ b/pandas/tests/extension/test_sparse.py @@ -424,3 +424,7 @@ def test_EA_types(self, engine, data): expected_msg = r".*must implement _from_sequence_of_strings.*" with pytest.raises(NotImplementedError, match=expected_msg): super().test_EA_types(engine, data) + + +class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests): + pass diff --git a/pandas/tests/extension/test_string.py b/pandas/tests/extension/test_string.py index 27a157d2127f6..11788f71482e1 100644 --- a/pandas/tests/extension/test_string.py +++ b/pandas/tests/extension/test_string.py @@ -121,3 +121,7 @@ class TestPrinting(base.BasePrintingTests): class TestGroupBy(base.BaseGroupbyTests): pass + + +class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests): + pass From aef535c4a76237fbc3623f08a397a08ba8b8f9da Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Fri, 26 Jun 2020 20:41:07 +0100 Subject: [PATCH 03/21] add IS_NEP18_ACTIVE and skip_if_no_nep18 test decorator --- pandas/compat/numpy/__init__.py | 15 +++++++++++++++ pandas/core/reshape/melt.py | 11 ++++++++--- .../tests/extension/base/numpy_array_functions.py | 3 +++ pandas/util/_test_decorators.py | 5 ++++- 4 files changed, 30 insertions(+), 4 deletions(-) diff --git a/pandas/compat/numpy/__init__.py b/pandas/compat/numpy/__init__.py index 789a4668b6fee..9086baf43f2b9 100644 --- a/pandas/compat/numpy/__init__.py +++ b/pandas/compat/numpy/__init__.py @@ -62,6 +62,21 @@ def np_array_datetime64_compat(arr, *args, **kwargs): return np.array(arr, *args, **kwargs) +# taken from dask array +# https://github.com/dask/dask/blob/master/dask/array/utils.py#L352-L363 +def _is_nep18_active(): + class A: + def __array_function__(self, *args, **kwargs): + return True + + try: + return np.concatenate([A()]) + except ValueError: + return False + + +IS_NEP18_ACTIVE = _is_nep18_active() + __all__ = [ "np", "_np_version", diff --git a/pandas/core/reshape/melt.py b/pandas/core/reshape/melt.py index a17079f56c868..3fe1bbaf53148 100644 --- a/pandas/core/reshape/melt.py +++ b/pandas/core/reshape/melt.py @@ -1,17 +1,19 @@ import re -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, cast import numpy as np +from pandas.compat.numpy import IS_NEP18_ACTIVE from pandas.util._decorators import Appender, deprecate_kwarg -from pandas.core.dtypes.common import is_list_like +from pandas.core.dtypes.common import is_extension_array_dtype, is_list_like from pandas.core.dtypes.concat import concat_compat from pandas.core.dtypes.missing import notna from pandas.core.arrays import Categorical import pandas.core.common as com from pandas.core.indexes.api import Index, MultiIndex +from pandas.core.reshape.concat import concat from pandas.core.shared_docs import _shared_docs from pandas.core.tools.numeric import to_numeric @@ -107,7 +109,10 @@ def melt( mdata = {} for col in id_vars: id_data = frame.pop(col) - id_data = np.tile(id_data._values, K) + if not IS_NEP18_ACTIVE and is_extension_array_dtype(id_data): + id_data = cast("Series", concat([id_data] * K, ignore_index=True)) + else: + id_data = np.tile(id_data._values, K) mdata[col] = id_data mcolumns = id_vars + var_name + [value_name] diff --git a/pandas/tests/extension/base/numpy_array_functions.py b/pandas/tests/extension/base/numpy_array_functions.py index cd135c769e13f..4b852a3ac1099 100644 --- a/pandas/tests/extension/base/numpy_array_functions.py +++ b/pandas/tests/extension/base/numpy_array_functions.py @@ -1,11 +1,14 @@ import numpy as np +import pandas.util._test_decorators as td + import pandas as pd from .base import BaseExtensionTests class BaseNumpyArrayFunctionTests(BaseExtensionTests): + @td.skip_if_no_nep18 def test_tile(self, data): expected = pd.array(list(data) * 3, dtype=data.dtype) result = np.tile(data, 3) diff --git a/pandas/util/_test_decorators.py b/pandas/util/_test_decorators.py index 25394dc6775d8..a32dd28476bc7 100644 --- a/pandas/util/_test_decorators.py +++ b/pandas/util/_test_decorators.py @@ -33,7 +33,7 @@ def test_foo(): from pandas.compat import is_platform_32bit, is_platform_windows from pandas.compat._optional import import_optional_dependency -from pandas.compat.numpy import _np_version +from pandas.compat.numpy import IS_NEP18_ACTIVE, _np_version from pandas.core.computation.expressions import _NUMEXPR_INSTALLED, _USE_NUMEXPR @@ -194,6 +194,9 @@ def skip_if_no(package: str, min_version: Optional[str] = None) -> Callable: not _USE_NUMEXPR, reason=f"numexpr enabled->{_USE_NUMEXPR}, installed->{_NUMEXPR_INSTALLED}", ) +skip_if_no_nep18 = pytest.mark.skipif( + not IS_NEP18_ACTIVE, reason="NEP-18 support is not available in NumPy" +) def skip_if_np_lt( From ed2962d1cbc96d8c41995e6c2c876e709da041f4 Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Sat, 27 Jun 2020 20:11:31 +0100 Subject: [PATCH 04/21] fix failing tests (tests/arrays and tests/extension) --- pandas/core/arrays/base.py | 56 ++++++++++++------- pandas/core/arrays/categorical.py | 4 ++ pandas/core/arrays/datetimelike.py | 4 ++ pandas/core/arrays/numpy_.py | 9 ++- pandas/core/arrays/sparse/array.py | 26 ++++----- pandas/core/dtypes/concat.py | 10 +++- pandas/core/internals/concat.py | 2 +- pandas/tests/extension/base/__init__.py | 2 +- pandas/tests/extension/base/base.py | 4 ++ .../extension/base/numpy_array_functions.py | 28 +++++++++- pandas/tests/extension/test_numpy.py | 6 +- 11 files changed, 106 insertions(+), 45 deletions(-) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 666df67dfda19..040a73ee65acf 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -32,7 +32,7 @@ _extension_array_shared_docs: Dict[str, str] = dict() -HANDLED_FUNCTIONS = {} +_HANDLED_FUNCTIONS = {} def implements(numpy_function) -> Callable[[F], F]: @@ -41,7 +41,7 @@ def implements(numpy_function) -> Callable[[F], F]: """ def decorator(func): - HANDLED_FUNCTIONS[numpy_function] = func + _HANDLED_FUNCTIONS[numpy_function] = func return func return decorator @@ -180,13 +180,25 @@ class ExtensionArray: _typ = "extension" def __array_function__(self, func, types, args, kwargs): - if func not in HANDLED_FUNCTIONS: - return NotImplemented - # Note: this allows subclasses that don't override - # __array_function__ to handle ExtensionArray objects - if not all(issubclass(t, ExtensionArray) for t in types): - return NotImplemented - return HANDLED_FUNCTIONS[func](*args, **kwargs) + if func not in _HANDLED_FUNCTIONS: + # try to find a matching method name. If that doesn't work, we may + # be dealing with an alias or a function that's simply not in the + # ExtensionArray API. Handle aliases via the _HANDLED_FUNCTIONS + # dict mapping. + if not hasattr(type(self), func.__name__): + # Need to convert EAs to numpy.ndarray so we can call the NumPy + # function again and it gets the chance to dispatch to the + # right implementation. + args = tuple( + arg.to_numpy() if isinstance(arg, ExtensionArray) else arg + for arg in args + ) + return func(*args, **kwargs) + + func = getattr(type(self), func.__name__) + return func(*args, **kwargs) + + return _HANDLED_FUNCTIONS[func](*args, **kwargs) # ------------------------------------------------------------------------ # Constructors @@ -886,11 +898,7 @@ def repeat(self, repeats, axis=None): ind = np.arange(len(self)).repeat(repeats) return self.take(ind) - # ------------------------------------------------------------------------ - # Numpy Array Functions - # ------------------------------------------------------------------------ - - def _tile(self, rep: int) -> "ExtensionArray": + def _tile_1d(self, rep: int) -> "ExtensionArray": """ non-public method for compatibility with np.tile """ @@ -1316,12 +1324,22 @@ def _create_comparison_method(cls, op): @implements(np.tile) -def tile(array: ExtensionArray, reps: Union[int, AnyArrayLike]): +def tile(arr: ExtensionArray, reps: Union[int, AnyArrayLike]) -> ArrayLike: + """ + Construct an array by repeating array the number of times given by reps. + """ try: tup = tuple(reps) except TypeError: tup = (reps,) - d = len(tup) - if d != 1: - raise ValueError("can only tile extension arrays along first axis") - return array._tile(reps) + if len(tup) == 1: + return arr._tile_1d(tup[0]) + return np.tile(arr.to_numpy(), tup) + + +@implements(np.ndim) +def ndim(array: ExtensionArray) -> int: + """ + Return the number of dimensions of an array. + """ + return array.ndim diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index 1fedfa70cc469..5f72c5b7f2678 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -2118,6 +2118,8 @@ def min(self, skipna=True, **kwargs): pointer = self._codes.min() return self.categories[pointer] + amin = min + @deprecate_kwarg(old_arg_name="numeric_only", new_arg_name="skipna") def max(self, skipna=True, **kwargs): """ @@ -2154,6 +2156,8 @@ def max(self, skipna=True, **kwargs): pointer = self._codes.max() return self.categories[pointer] + amax = max + def mode(self, dropna=True): """ Returns the mode(s) of the Categorical. diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index a306268cd8ede..3c648d884fea3 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -1579,6 +1579,8 @@ def min(self, axis=None, skipna=True, *args, **kwargs): return NaT return self._box_func(result) + amin = min + def max(self, axis=None, skipna=True, *args, **kwargs): """ Return the maximum value of the Array or maximum along @@ -1611,6 +1613,8 @@ def max(self, axis=None, skipna=True, *args, **kwargs): # Don't have to worry about NA `result`, since no NA went in. return self._box_func(result) + amax = max + def mean(self, skipna=True): """ Return the mean value of the Array. diff --git a/pandas/core/arrays/numpy_.py b/pandas/core/arrays/numpy_.py index f6dfb1f0f1e62..76b9397f0ff5a 100644 --- a/pandas/core/arrays/numpy_.py +++ b/pandas/core/arrays/numpy_.py @@ -1,5 +1,5 @@ import numbers -from typing import Optional, Tuple, Type, Union +from typing import Optional, Sequence, Tuple, Type, Union import numpy as np from numpy.lib.mixins import NDArrayOperatorsMixin @@ -191,7 +191,8 @@ def _from_factorized(cls, values, original) -> "PandasArray": return cls(values) @classmethod - def _concat_same_type(cls, to_concat) -> "PandasArray": + def _concat_same_type(cls, to_concat: Sequence["PandasArray"]) -> "PandasArray": + to_concat = [arr.to_numpy() for arr in to_concat] return cls(np.concatenate(to_concat)) def _from_backing_data(self, arr: np.ndarray) -> "PandasArray": @@ -347,6 +348,8 @@ def min(self, skipna: bool = True, **kwargs) -> Scalar: ) return result + amin = min + def max(self, skipna: bool = True, **kwargs) -> Scalar: nv.validate_max((), kwargs) result = masked_reductions.max( @@ -354,6 +357,8 @@ def max(self, skipna: bool = True, **kwargs) -> Scalar: ) return result + amax = max + def sum(self, axis=None, skipna=True, min_count=0, **kwargs) -> Scalar: nv.validate_sum((), kwargs) return nanops.nansum( diff --git a/pandas/core/arrays/sparse/array.py b/pandas/core/arrays/sparse/array.py index a25bdea133d02..9320b4a3d8209 100644 --- a/pandas/core/arrays/sparse/array.py +++ b/pandas/core/arrays/sparse/array.py @@ -757,6 +757,17 @@ def value_counts(self, dropna=True): result = Series(counts, index=keys) return result + def _tile_1d(self, rep: int) -> "ExtensionArray": + """ + non-public method for compatibility with np.tile + """ + sp_values = np.tile(self.sp_values, rep) + index = self.sp_index.indices + for i in range(1, rep): + index = np.append(index, self.sp_index.indices + (i * len(self))) + sp_index = IntIndex(len(self) * rep, index) + return self._simple_new(sp_values, sp_index, self.dtype) + # -------- # Indexing # -------- @@ -1379,21 +1390,6 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): def __abs__(self): return np.abs(self) - # ------------------------------------------------------------------------ - # Numpy Array Functions - # ------------------------------------------------------------------------ - - def _tile(self, rep: int) -> "ExtensionArray": - """ - non-public method for compatibility with np.tile - """ - sp_values = np.tile(self.sp_values, rep) - index = self.sp_index.indices - for i in range(1, rep): - index = np.append(index, self.sp_index.indices + (i * len(self))) - sp_index = IntIndex(len(self) * rep, index) - return self._simple_new(sp_values, sp_index, self.dtype) - # ------------------------------------------------------------------------ # Ops # ------------------------------------------------------------------------ diff --git a/pandas/core/dtypes/concat.py b/pandas/core/dtypes/concat.py index 4b7c818f487ac..38701c377302b 100644 --- a/pandas/core/dtypes/concat.py +++ b/pandas/core/dtypes/concat.py @@ -22,6 +22,7 @@ from pandas.core.dtypes.generic import ABCCategoricalIndex, ABCRangeIndex, ABCSeries from pandas.core.arrays import ExtensionArray +from pandas.core.arrays.base import implements from pandas.core.arrays.sparse import SparseArray from pandas.core.construction import array @@ -107,6 +108,7 @@ def _cast_to_common_type(arr: ArrayLike, dtype: DtypeObj) -> ArrayLike: return arr.astype(dtype, copy=False) +@implements(np.concatenate) def concat_compat(to_concat, axis: int = 0): """ provide concatenation of an array of arrays each of which is a single @@ -152,11 +154,15 @@ def is_nonempty(x) -> bool: target_dtype = find_common_type([x.dtype for x in to_concat]) to_concat = [_cast_to_common_type(arr, target_dtype) for arr in to_concat] - if isinstance(to_concat[0], ExtensionArray): + if isinstance(to_concat[0], ExtensionArray) and axis == 0: cls = type(to_concat[0]) return cls._concat_same_type(to_concat) else: - return np.concatenate(to_concat) + to_concat = [ + arr.to_numpy() if isinstance(arr, ExtensionArray) else arr + for arr in to_concat + ] + return np.concatenate(to_concat, axis=axis) elif _contains_datetime or "timedelta" in typs: return concat_datetime(to_concat, axis=axis, typs=typs) diff --git a/pandas/core/internals/concat.py b/pandas/core/internals/concat.py index fd8c5f5e27c02..1d48516906084 100644 --- a/pandas/core/internals/concat.py +++ b/pandas/core/internals/concat.py @@ -323,7 +323,7 @@ def _concatenate_join_units(join_units, concat_axis, copy): # concatting with at least one EA means we are concatting a single column # the non-EA values are 2D arrays with shape (1, n) to_concat = [t if isinstance(t, ExtensionArray) else t[0, :] for t in to_concat] - concat_values = concat_compat(to_concat, axis=concat_axis) + concat_values = concat_compat(to_concat, axis=0) if not isinstance(concat_values, ExtensionArray): # if the result of concat is not an EA but an ndarray, reshape to # 2D to put it a non-EA Block diff --git a/pandas/tests/extension/base/__init__.py b/pandas/tests/extension/base/__init__.py index 7171014cf0693..97a2763f5ab7b 100644 --- a/pandas/tests/extension/base/__init__.py +++ b/pandas/tests/extension/base/__init__.py @@ -50,6 +50,7 @@ class TestMyDtype(BaseDtypeTests): from .io import BaseParsingTests # noqa from .methods import BaseMethodsTests # noqa from .missing import BaseMissingTests # noqa +from .numpy_array_functions import BaseNumpyArrayFunctionTests # noqa from .ops import ( # noqa BaseArithmeticOpsTests, BaseComparisonOpsTests, @@ -64,4 +65,3 @@ class TestMyDtype(BaseDtypeTests): ) from .reshaping import BaseReshapingTests # noqa from .setitem import BaseSetitemTests # noqa -from .numpy_array_functions import BaseNumpyArrayFunctionTests # noqa diff --git a/pandas/tests/extension/base/base.py b/pandas/tests/extension/base/base.py index 97d8e7c66dbdb..67e7c1809da7a 100644 --- a/pandas/tests/extension/base/base.py +++ b/pandas/tests/extension/base/base.py @@ -19,3 +19,7 @@ def assert_frame_equal(cls, left, right, *args, **kwargs): @classmethod def assert_extension_array_equal(cls, left, right, *args, **kwargs): return tm.assert_extension_array_equal(left, right, *args, **kwargs) + + @classmethod + def assert_numpy_array_equal(cls, left, right, *args, **kwargs): + return tm.assert_numpy_array_equal(left, right, *args, **kwargs) diff --git a/pandas/tests/extension/base/numpy_array_functions.py b/pandas/tests/extension/base/numpy_array_functions.py index 4b852a3ac1099..5625f71b0f82e 100644 --- a/pandas/tests/extension/base/numpy_array_functions.py +++ b/pandas/tests/extension/base/numpy_array_functions.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import pandas.util._test_decorators as td @@ -7,9 +8,34 @@ from .base import BaseExtensionTests +@td.skip_if_no_nep18 class BaseNumpyArrayFunctionTests(BaseExtensionTests): - @td.skip_if_no_nep18 def test_tile(self, data): expected = pd.array(list(data) * 3, dtype=data.dtype) + result = np.tile(data, 3) self.assert_extension_array_equal(result, expected) + + result = np.tile(data, (3,)) + self.assert_extension_array_equal(result, expected) + + expected = np.array([data.to_numpy()] * 3) + + result = np.tile(data, (3, 1)) + self.assert_numpy_array_equal(result, expected) + + def test_concatenate(self, data): + expected = pd.array(list(data) * 3, dtype=data.dtype) + + result = np.concatenate([data] * 3) + self.assert_extension_array_equal(result, expected) + + result = np.concatenate([data] * 3, axis=0) + self.assert_extension_array_equal(result, expected) + + msg = "axis 1 is out of bounds for array of dimension 1" + with pytest.raises(np.AxisError, match=msg): + np.concatenate([data] * 3, axis=1) + + def test_ndim(self, data): + assert np.ndim(data) == 1 diff --git a/pandas/tests/extension/test_numpy.py b/pandas/tests/extension/test_numpy.py index 3c230742a1eae..bf9be949c5cd6 100644 --- a/pandas/tests/extension/test_numpy.py +++ b/pandas/tests/extension/test_numpy.py @@ -504,8 +504,6 @@ class TestParsing(BaseNumPyTests, base.BaseParsingTests): pass +@skip_nested class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests): - def test_tile(self, data, dtype): - # TODO: maybe rethink the validation - # raises ValueError: PandasArray must be 1-dimensional. - pass + pass From cd161add01f4ff9b8f54c6454a90deb9af86d82a Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Sat, 27 Jun 2020 20:29:15 +0100 Subject: [PATCH 05/21] mypy fixup --- pandas/core/arrays/base.py | 12 +++++------- pandas/core/arrays/numpy_.py | 4 ++-- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 040a73ee65acf..99c2b43d36ae1 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -27,6 +27,7 @@ from pandas.core import ops from pandas.core.algorithms import _factorize_array, unique +import pandas.core.common as com from pandas.core.missing import backfill_1d, pad_1d from pandas.core.sorting import nargsort @@ -1328,13 +1329,10 @@ def tile(arr: ExtensionArray, reps: Union[int, AnyArrayLike]) -> ArrayLike: """ Construct an array by repeating array the number of times given by reps. """ - try: - tup = tuple(reps) - except TypeError: - tup = (reps,) - if len(tup) == 1: - return arr._tile_1d(tup[0]) - return np.tile(arr.to_numpy(), tup) + reps_list = com.convert_to_list_like(reps) + if len(reps_list) == 1: + return arr._tile_1d(reps_list[0]) + return np.tile(arr.to_numpy(), reps) @implements(np.ndim) diff --git a/pandas/core/arrays/numpy_.py b/pandas/core/arrays/numpy_.py index 76b9397f0ff5a..c935dd37ecddf 100644 --- a/pandas/core/arrays/numpy_.py +++ b/pandas/core/arrays/numpy_.py @@ -1,5 +1,5 @@ import numbers -from typing import Optional, Sequence, Tuple, Type, Union +from typing import Optional, Tuple, Type, Union import numpy as np from numpy.lib.mixins import NDArrayOperatorsMixin @@ -191,7 +191,7 @@ def _from_factorized(cls, values, original) -> "PandasArray": return cls(values) @classmethod - def _concat_same_type(cls, to_concat: Sequence["PandasArray"]) -> "PandasArray": + def _concat_same_type(cls, to_concat) -> "PandasArray": to_concat = [arr.to_numpy() for arr in to_concat] return cls(np.concatenate(to_concat)) From 5ca65e1f6081ff098898b24178ecfed00998533c Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Sun, 28 Jun 2020 13:04:28 +0100 Subject: [PATCH 06/21] add np.vstack (dispatches back to NumPy) --- pandas/core/arrays/base.py | 12 ++++++++++++ pandas/tests/extension/base/numpy_array_functions.py | 5 +++++ 2 files changed, 17 insertions(+) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 99c2b43d36ae1..c8766de8ab1f3 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -1341,3 +1341,15 @@ def ndim(array: ExtensionArray) -> int: Return the number of dimensions of an array. """ return array.ndim + + +@implements(np.vstack) +def vstack(to_stack: Sequence[ArrayLike]) -> np.ndarray: + """ + Stack arrays in sequence vertically (row wise). + """ + to_stack = tuple( + arr.to_numpy() if isinstance(arr, ExtensionArray) else arr + for arr in to_stack + ) + return np.vstack(to_stack) diff --git a/pandas/tests/extension/base/numpy_array_functions.py b/pandas/tests/extension/base/numpy_array_functions.py index 5625f71b0f82e..82d40ec9fce31 100644 --- a/pandas/tests/extension/base/numpy_array_functions.py +++ b/pandas/tests/extension/base/numpy_array_functions.py @@ -39,3 +39,8 @@ def test_concatenate(self, data): def test_ndim(self, data): assert np.ndim(data) == 1 + + def test_vstack(self, data): + expected = np.array([data.to_numpy()] * 2) + result = np.vstack([data, data]) + self.assert_numpy_array_equal(result, expected) From f34724378362fa777bd49de1ca13879624dfde9e Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Sun, 28 Jun 2020 13:12:49 +0100 Subject: [PATCH 07/21] remove np.tile implementation for now to reduce diff --- pandas/core/arrays/base.py | 14 +------------- pandas/core/arrays/sparse/array.py | 11 ----------- pandas/core/reshape/melt.py | 3 +-- .../tests/extension/base/numpy_array_functions.py | 14 -------------- 4 files changed, 2 insertions(+), 40 deletions(-) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index c8766de8ab1f3..a27177f964aae 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -12,7 +12,7 @@ import numpy as np from pandas._libs import lib -from pandas._typing import AnyArrayLike, ArrayLike, F +from pandas._typing import ArrayLike, F from pandas.compat import set_function_name from pandas.compat.numpy import function as nv from pandas.errors import AbstractMethodError @@ -27,7 +27,6 @@ from pandas.core import ops from pandas.core.algorithms import _factorize_array, unique -import pandas.core.common as com from pandas.core.missing import backfill_1d, pad_1d from pandas.core.sorting import nargsort @@ -1324,17 +1323,6 @@ def _create_comparison_method(cls, op): return cls._create_method(op, coerce_to_dtype=False, result_dtype=bool) -@implements(np.tile) -def tile(arr: ExtensionArray, reps: Union[int, AnyArrayLike]) -> ArrayLike: - """ - Construct an array by repeating array the number of times given by reps. - """ - reps_list = com.convert_to_list_like(reps) - if len(reps_list) == 1: - return arr._tile_1d(reps_list[0]) - return np.tile(arr.to_numpy(), reps) - - @implements(np.ndim) def ndim(array: ExtensionArray) -> int: """ diff --git a/pandas/core/arrays/sparse/array.py b/pandas/core/arrays/sparse/array.py index 9320b4a3d8209..4996a10002c63 100644 --- a/pandas/core/arrays/sparse/array.py +++ b/pandas/core/arrays/sparse/array.py @@ -757,17 +757,6 @@ def value_counts(self, dropna=True): result = Series(counts, index=keys) return result - def _tile_1d(self, rep: int) -> "ExtensionArray": - """ - non-public method for compatibility with np.tile - """ - sp_values = np.tile(self.sp_values, rep) - index = self.sp_index.indices - for i in range(1, rep): - index = np.append(index, self.sp_index.indices + (i * len(self))) - sp_index = IntIndex(len(self) * rep, index) - return self._simple_new(sp_values, sp_index, self.dtype) - # -------- # Indexing # -------- diff --git a/pandas/core/reshape/melt.py b/pandas/core/reshape/melt.py index 3fe1bbaf53148..cd0619738677d 100644 --- a/pandas/core/reshape/melt.py +++ b/pandas/core/reshape/melt.py @@ -3,7 +3,6 @@ import numpy as np -from pandas.compat.numpy import IS_NEP18_ACTIVE from pandas.util._decorators import Appender, deprecate_kwarg from pandas.core.dtypes.common import is_extension_array_dtype, is_list_like @@ -109,7 +108,7 @@ def melt( mdata = {} for col in id_vars: id_data = frame.pop(col) - if not IS_NEP18_ACTIVE and is_extension_array_dtype(id_data): + if is_extension_array_dtype(id_data): id_data = cast("Series", concat([id_data] * K, ignore_index=True)) else: id_data = np.tile(id_data._values, K) diff --git a/pandas/tests/extension/base/numpy_array_functions.py b/pandas/tests/extension/base/numpy_array_functions.py index 82d40ec9fce31..f1d38b246bf95 100644 --- a/pandas/tests/extension/base/numpy_array_functions.py +++ b/pandas/tests/extension/base/numpy_array_functions.py @@ -10,20 +10,6 @@ @td.skip_if_no_nep18 class BaseNumpyArrayFunctionTests(BaseExtensionTests): - def test_tile(self, data): - expected = pd.array(list(data) * 3, dtype=data.dtype) - - result = np.tile(data, 3) - self.assert_extension_array_equal(result, expected) - - result = np.tile(data, (3,)) - self.assert_extension_array_equal(result, expected) - - expected = np.array([data.to_numpy()] * 3) - - result = np.tile(data, (3, 1)) - self.assert_numpy_array_equal(result, expected) - def test_concatenate(self, data): expected = pd.array(list(data) * 3, dtype=data.dtype) From 009ac61c6c0a15ba1010c00ac5dea7f84b65530f Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Sun, 28 Jun 2020 13:18:01 +0100 Subject: [PATCH 08/21] and remove _tile_1d --- pandas/core/arrays/base.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index a27177f964aae..1b70457491bad 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -898,13 +898,6 @@ def repeat(self, repeats, axis=None): ind = np.arange(len(self)).repeat(repeats) return self.take(ind) - def _tile_1d(self, rep: int) -> "ExtensionArray": - """ - non-public method for compatibility with np.tile - """ - ind = np.tile(np.arange(len(self)), rep) - return self.take(ind) - # ------------------------------------------------------------------------ # Indexing methods # ------------------------------------------------------------------------ From 584ccf27e8336a2190002913c653ced3cb655976 Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Sun, 28 Jun 2020 14:28:32 +0100 Subject: [PATCH 09/21] don't dispatch to EA.unique from np.unique --- pandas/core/arrays/base.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 1b70457491bad..b47058f2af07a 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -185,7 +185,9 @@ def __array_function__(self, func, types, args, kwargs): # be dealing with an alias or a function that's simply not in the # ExtensionArray API. Handle aliases via the _HANDLED_FUNCTIONS # dict mapping. - if not hasattr(type(self), func.__name__): + exclude_list = {"unique"} + ea_func = getattr(type(self), func.__name__, None) + if ea_func is None or ea_func.__name__ in exclude_list: # Need to convert EAs to numpy.ndarray so we can call the NumPy # function again and it gets the chance to dispatch to the # right implementation. @@ -195,8 +197,7 @@ def __array_function__(self, func, types, args, kwargs): ) return func(*args, **kwargs) - func = getattr(type(self), func.__name__) - return func(*args, **kwargs) + return ea_func(*args, **kwargs) return _HANDLED_FUNCTIONS[func](*args, **kwargs) From 72e447718868269606fb57e65fcbfbf2e0204570 Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Sun, 28 Jun 2020 14:29:48 +0100 Subject: [PATCH 10/21] lint fixup --- pandas/core/arrays/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index b47058f2af07a..b6e0828ac81f5 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -1331,7 +1331,6 @@ def vstack(to_stack: Sequence[ArrayLike]) -> np.ndarray: Stack arrays in sequence vertically (row wise). """ to_stack = tuple( - arr.to_numpy() if isinstance(arr, ExtensionArray) else arr - for arr in to_stack + arr.to_numpy() if isinstance(arr, ExtensionArray) else arr for arr in to_stack ) return np.vstack(to_stack) From 807aae82b0e96299c167ceffe0ec6a3e922c7c03 Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Sun, 28 Jun 2020 17:41:16 +0100 Subject: [PATCH 11/21] fix test_fillna_null for IntervalIndex --- pandas/core/arrays/base.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index b6e0828ac81f5..05d90574aac7c 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -1334,3 +1334,18 @@ def vstack(to_stack: Sequence[ArrayLike]) -> np.ndarray: arr.to_numpy() if isinstance(arr, ExtensionArray) else arr for arr in to_stack ) return np.vstack(to_stack) + + +@implements(np.putmask) +def putmask(a: ArrayLike, mask: ArrayLike, values: ArrayLike) -> None: + """ + Changes elements of an array based on conditional and input values. + """ + # TODO: refactor Index.putmask to not rely on this behaviour for IntervalArray + if isinstance(a, ExtensionArray): + raise TypeError( + f"putmask() argument 1 must be numpy.ndarray, not {type(a).__name__}" + ) + mask = mask.to_numpy() if isinstance(mask, ExtensionArray) else mask + values = values.to_numpy() if isinstance(values, ExtensionArray) else values + return np.putmask(a, mask, values) From ffab3827ef12370eb50a301f87cb29271816d34c Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Mon, 29 Jun 2020 09:40:48 +0100 Subject: [PATCH 12/21] test from 35038 to get ci to green --- pandas/tests/reshape/test_concat.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/pandas/tests/reshape/test_concat.py b/pandas/tests/reshape/test_concat.py index ffeb5ff0f8aaa..be96a45c04299 100644 --- a/pandas/tests/reshape/test_concat.py +++ b/pandas/tests/reshape/test_concat.py @@ -1088,19 +1088,14 @@ def test_append_empty_frame_to_series_with_dateutil_tz(self): s = Series({"date": date, "a": 1.0, "b": 2.0}) df = DataFrame(columns=["c", "d"]) result = df.append(s, ignore_index=True) - # n.b. it's not clear to me that expected is correct here. - # It's possible that the `date` column should have - # datetime64[ns, tz] dtype for both result and expected. - # that would be more consistent with new columns having - # their own dtype (float for a and b, datetime64ns, tz for date). expected = DataFrame( [[np.nan, np.nan, 1.0, 2.0, date]], columns=["c", "d", "a", "b", "date"], dtype=object, ) - # These columns get cast to object after append expected["a"] = expected["a"].astype(float) expected["b"] = expected["b"].astype(float) + expected["date"] = pd.to_datetime(expected["date"]) tm.assert_frame_equal(result, expected) From 025567134c84c3707f55774080bf51a0d1407eb8 Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Mon, 29 Jun 2020 13:44:01 +0100 Subject: [PATCH 13/21] add ArrayFunctionMixin --- pandas/core/arrays/__init__.py | 2 ++ pandas/core/arrays/base.py | 46 +++++++++++++++------------- pandas/core/arrays/categorical.py | 3 +- pandas/core/arrays/datetimelike.py | 8 +++-- pandas/core/arrays/interval.py | 8 +++-- pandas/core/arrays/masked.py | 6 ++-- pandas/core/arrays/sparse/array.py | 4 +-- pandas/core/arrays/string_.py | 4 +-- pandas/tests/extension/test_numpy.py | 5 --- 9 files changed, 47 insertions(+), 39 deletions(-) diff --git a/pandas/core/arrays/__init__.py b/pandas/core/arrays/__init__.py index 1d538824e6d82..136345b8345c5 100644 --- a/pandas/core/arrays/__init__.py +++ b/pandas/core/arrays/__init__.py @@ -1,4 +1,5 @@ from pandas.core.arrays.base import ( + ArrayFunctionMixin, ExtensionArray, ExtensionOpsMixin, ExtensionScalarOpsMixin, @@ -15,6 +16,7 @@ from pandas.core.arrays.timedeltas import TimedeltaArray __all__ = [ + "ArrayFunctionMixin", "ExtensionArray", "ExtensionOpsMixin", "ExtensionScalarOpsMixin", diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 05d90574aac7c..13d681af74c06 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -47,6 +47,30 @@ def decorator(func): return decorator +class ArrayFunctionMixin: + def __array_function__(self, func, types, args, kwargs): + if func not in _HANDLED_FUNCTIONS: + # try to find a matching method name. If that doesn't work, we may + # be dealing with an alias or a function that's simply not in the + # ExtensionArray API. Handle aliases via the _HANDLED_FUNCTIONS + # dict mapping. + exclude_list = {"unique"} + ea_func = getattr(type(self), func.__name__, None) + if ea_func is None or ea_func.__name__ in exclude_list: + # Need to convert EAs to numpy.ndarray so we can call the NumPy + # function again and it gets the chance to dispatch to the + # right implementation. + args = tuple( + arg.to_numpy() if isinstance(arg, ExtensionArray) else arg + for arg in args + ) + return func(*args, **kwargs) + + return ea_func(*args, **kwargs) + + return _HANDLED_FUNCTIONS[func](*args, **kwargs) + + class ExtensionArray: """ Abstract base class for custom 1-D array types. @@ -179,28 +203,6 @@ class ExtensionArray: # Don't override this. _typ = "extension" - def __array_function__(self, func, types, args, kwargs): - if func not in _HANDLED_FUNCTIONS: - # try to find a matching method name. If that doesn't work, we may - # be dealing with an alias or a function that's simply not in the - # ExtensionArray API. Handle aliases via the _HANDLED_FUNCTIONS - # dict mapping. - exclude_list = {"unique"} - ea_func = getattr(type(self), func.__name__, None) - if ea_func is None or ea_func.__name__ in exclude_list: - # Need to convert EAs to numpy.ndarray so we can call the NumPy - # function again and it gets the chance to dispatch to the - # right implementation. - args = tuple( - arg.to_numpy() if isinstance(arg, ExtensionArray) else arg - for arg in args - ) - return func(*args, **kwargs) - - return ea_func(*args, **kwargs) - - return _HANDLED_FUNCTIONS[func](*args, **kwargs) - # ------------------------------------------------------------------------ # Constructors # ------------------------------------------------------------------------ diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index 5f72c5b7f2678..28aaf83dec97e 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -45,6 +45,7 @@ import pandas.core.algorithms as algorithms from pandas.core.algorithms import _get_data_algo, factorize, take_1d, unique1d from pandas.core.array_algos.transforms import shift +from pandas.core.arrays import ArrayFunctionMixin from pandas.core.arrays._mixins import _T, NDArrayBackedExtensionArray from pandas.core.base import ( ExtensionArray, @@ -201,7 +202,7 @@ def contains(cat, key, container): return any(loc_ in container for loc_ in loc) -class Categorical(NDArrayBackedExtensionArray, PandasObject): +class Categorical(NDArrayBackedExtensionArray, PandasObject, ArrayFunctionMixin): """ Represent a categorical variable in classic R / S-plus fashion. diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 3c648d884fea3..5626c3656b0a6 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -56,7 +56,11 @@ from pandas.core.algorithms import checked_add_with_arr, unique1d, value_counts from pandas.core.array_algos.transforms import shift from pandas.core.arrays._mixins import _T, NDArrayBackedExtensionArray -from pandas.core.arrays.base import ExtensionArray, ExtensionOpsMixin +from pandas.core.arrays.base import ( + ArrayFunctionMixin, + ExtensionArray, + ExtensionOpsMixin, +) import pandas.core.common as com from pandas.core.construction import array, extract_array from pandas.core.indexers import check_array_indexer @@ -440,7 +444,7 @@ def _with_freq(self, freq): class DatetimeLikeArrayMixin( - ExtensionOpsMixin, AttributesMixin, NDArrayBackedExtensionArray + ExtensionOpsMixin, AttributesMixin, NDArrayBackedExtensionArray, ArrayFunctionMixin ): """ Shared Base/Mixin class for DatetimeArray, TimedeltaArray, PeriodArray diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index c861d25afd13f..b48b5bf07b883 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -35,7 +35,11 @@ from pandas.core.dtypes.missing import isna, notna from pandas.core.algorithms import take, value_counts -from pandas.core.arrays.base import ExtensionArray, _extension_array_shared_docs +from pandas.core.arrays.base import ( + ArrayFunctionMixin, + ExtensionArray, + _extension_array_shared_docs, +) from pandas.core.arrays.categorical import Categorical import pandas.core.common as com from pandas.core.construction import array @@ -143,7 +147,7 @@ ), ) ) -class IntervalArray(IntervalMixin, ExtensionArray): +class IntervalArray(IntervalMixin, ExtensionArray, ArrayFunctionMixin): ndim = 1 can_hold_na = True _na_value = _fill_value = np.nan diff --git a/pandas/core/arrays/masked.py b/pandas/core/arrays/masked.py index 28add129825d1..8f260b4444f05 100644 --- a/pandas/core/arrays/masked.py +++ b/pandas/core/arrays/masked.py @@ -19,7 +19,7 @@ from pandas.core import nanops from pandas.core.algorithms import _factorize_array, take from pandas.core.array_algos import masked_reductions -from pandas.core.arrays import ExtensionArray, ExtensionOpsMixin +from pandas.core.arrays import ArrayFunctionMixin, ExtensionArray, ExtensionOpsMixin from pandas.core.indexers import check_array_indexer if TYPE_CHECKING: @@ -29,7 +29,7 @@ BaseMaskedArrayT = TypeVar("BaseMaskedArrayT", bound="BaseMaskedArray") -class BaseMaskedDtype(ExtensionDtype): +class BaseMaskedDtype(ExtensionDtype, ArrayFunctionMixin): """ Base class for dtypes for BasedMaskedArray subclasses. """ @@ -41,7 +41,7 @@ def numpy_dtype(self) -> np.dtype: raise AbstractMethodError -class BaseMaskedArray(ExtensionArray, ExtensionOpsMixin): +class BaseMaskedArray(ExtensionArray, ExtensionOpsMixin, ArrayFunctionMixin): """ Base class for masked arrays (which use _data and _mask to store the data). diff --git a/pandas/core/arrays/sparse/array.py b/pandas/core/arrays/sparse/array.py index 4996a10002c63..4779f5a00c6d4 100644 --- a/pandas/core/arrays/sparse/array.py +++ b/pandas/core/arrays/sparse/array.py @@ -40,7 +40,7 @@ from pandas.core.dtypes.missing import isna, na_value_for_dtype, notna import pandas.core.algorithms as algos -from pandas.core.arrays import ExtensionArray, ExtensionOpsMixin +from pandas.core.arrays import ArrayFunctionMixin, ExtensionArray, ExtensionOpsMixin from pandas.core.arrays.sparse.dtype import SparseDtype from pandas.core.base import PandasObject import pandas.core.common as com @@ -195,7 +195,7 @@ def _wrap_result(name, data, sparse_index, fill_value, dtype=None): ) -class SparseArray(PandasObject, ExtensionArray, ExtensionOpsMixin): +class SparseArray(PandasObject, ExtensionArray, ExtensionOpsMixin, ArrayFunctionMixin): """ An ExtensionArray for storing sparse data. diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index ac501a8afbe09..d808668da16b3 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -13,7 +13,7 @@ from pandas import compat from pandas.core import ops -from pandas.core.arrays import IntegerArray, PandasArray +from pandas.core.arrays import ArrayFunctionMixin, IntegerArray, PandasArray from pandas.core.arrays.integer import _IntegerDtype from pandas.core.construction import extract_array from pandas.core.indexers import check_array_indexer @@ -98,7 +98,7 @@ def __from_arrow__( return StringArray._concat_same_type(results) -class StringArray(PandasArray): +class StringArray(PandasArray, ArrayFunctionMixin): """ Extension array for string data. diff --git a/pandas/tests/extension/test_numpy.py b/pandas/tests/extension/test_numpy.py index bf9be949c5cd6..78000c0252375 100644 --- a/pandas/tests/extension/test_numpy.py +++ b/pandas/tests/extension/test_numpy.py @@ -502,8 +502,3 @@ def test_setitem_loc_iloc_slice(self, data): @skip_nested class TestParsing(BaseNumPyTests, base.BaseParsingTests): pass - - -@skip_nested -class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests): - pass From 7b379f40ab2440f499e4a17223353012a402e2f5 Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Mon, 29 Jun 2020 13:48:43 +0100 Subject: [PATCH 14/21] remove ArrayFunctionMixin unintentially added to BaseMaskedDtype --- pandas/core/arrays/masked.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/arrays/masked.py b/pandas/core/arrays/masked.py index 8f260b4444f05..4d2c4d2c44906 100644 --- a/pandas/core/arrays/masked.py +++ b/pandas/core/arrays/masked.py @@ -29,7 +29,7 @@ BaseMaskedArrayT = TypeVar("BaseMaskedArrayT", bound="BaseMaskedArray") -class BaseMaskedDtype(ExtensionDtype, ArrayFunctionMixin): +class BaseMaskedDtype(ExtensionDtype): """ Base class for dtypes for BasedMaskedArray subclasses. """ From 96db577b4f9d6f48d26f7429b836f1a70d43523e Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Mon, 29 Jun 2020 13:53:40 +0100 Subject: [PATCH 15/21] revert changes to pandas/core/arrays/numpy_.py --- pandas/core/arrays/numpy_.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pandas/core/arrays/numpy_.py b/pandas/core/arrays/numpy_.py index c935dd37ecddf..f6dfb1f0f1e62 100644 --- a/pandas/core/arrays/numpy_.py +++ b/pandas/core/arrays/numpy_.py @@ -192,7 +192,6 @@ def _from_factorized(cls, values, original) -> "PandasArray": @classmethod def _concat_same_type(cls, to_concat) -> "PandasArray": - to_concat = [arr.to_numpy() for arr in to_concat] return cls(np.concatenate(to_concat)) def _from_backing_data(self, arr: np.ndarray) -> "PandasArray": @@ -348,8 +347,6 @@ def min(self, skipna: bool = True, **kwargs) -> Scalar: ) return result - amin = min - def max(self, skipna: bool = True, **kwargs) -> Scalar: nv.validate_max((), kwargs) result = masked_reductions.max( @@ -357,8 +354,6 @@ def max(self, skipna: bool = True, **kwargs) -> Scalar: ) return result - amax = max - def sum(self, axis=None, skipna=True, min_count=0, **kwargs) -> Scalar: nv.validate_sum((), kwargs) return nanops.nansum( From d8b1a8ba9f46d897f303dc1f811a82e3414c70f3 Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Mon, 29 Jun 2020 14:57:24 +0100 Subject: [PATCH 16/21] Revert "revert changes to pandas/core/arrays/numpy_.py" This reverts commit 96db577b4f9d6f48d26f7429b836f1a70d43523e. --- pandas/core/arrays/numpy_.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pandas/core/arrays/numpy_.py b/pandas/core/arrays/numpy_.py index f6dfb1f0f1e62..c935dd37ecddf 100644 --- a/pandas/core/arrays/numpy_.py +++ b/pandas/core/arrays/numpy_.py @@ -192,6 +192,7 @@ def _from_factorized(cls, values, original) -> "PandasArray": @classmethod def _concat_same_type(cls, to_concat) -> "PandasArray": + to_concat = [arr.to_numpy() for arr in to_concat] return cls(np.concatenate(to_concat)) def _from_backing_data(self, arr: np.ndarray) -> "PandasArray": @@ -347,6 +348,8 @@ def min(self, skipna: bool = True, **kwargs) -> Scalar: ) return result + amin = min + def max(self, skipna: bool = True, **kwargs) -> Scalar: nv.validate_max((), kwargs) result = masked_reductions.max( @@ -354,6 +357,8 @@ def max(self, skipna: bool = True, **kwargs) -> Scalar: ) return result + amax = max + def sum(self, axis=None, skipna=True, min_count=0, **kwargs) -> Scalar: nv.validate_sum((), kwargs) return nanops.nansum( From 61c53f5c1f731013b22be3515d0f64b4a1c65188 Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Mon, 29 Jun 2020 16:01:38 +0100 Subject: [PATCH 17/21] add ArrayFunctionMixin to NDArrayBackedExtensionArray and reinstate PandasArray tests --- pandas/core/arrays/_mixins.py | 4 ++-- pandas/core/arrays/categorical.py | 3 +-- pandas/core/arrays/datetimelike.py | 8 ++------ pandas/core/arrays/string_.py | 4 ++-- pandas/tests/extension/test_numpy.py | 5 +++++ 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py index 832d09b062265..2d38ed278f622 100644 --- a/pandas/core/arrays/_mixins.py +++ b/pandas/core/arrays/_mixins.py @@ -7,12 +7,12 @@ from pandas.util._decorators import cache_readonly from pandas.core.algorithms import take, unique -from pandas.core.arrays.base import ExtensionArray +from pandas.core.arrays.base import ArrayFunctionMixin, ExtensionArray _T = TypeVar("_T", bound="NDArrayBackedExtensionArray") -class NDArrayBackedExtensionArray(ExtensionArray): +class NDArrayBackedExtensionArray(ExtensionArray, ArrayFunctionMixin): """ ExtensionArray that is backed by a single NumPy ndarray. """ diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index 28aaf83dec97e..5f72c5b7f2678 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -45,7 +45,6 @@ import pandas.core.algorithms as algorithms from pandas.core.algorithms import _get_data_algo, factorize, take_1d, unique1d from pandas.core.array_algos.transforms import shift -from pandas.core.arrays import ArrayFunctionMixin from pandas.core.arrays._mixins import _T, NDArrayBackedExtensionArray from pandas.core.base import ( ExtensionArray, @@ -202,7 +201,7 @@ def contains(cat, key, container): return any(loc_ in container for loc_ in loc) -class Categorical(NDArrayBackedExtensionArray, PandasObject, ArrayFunctionMixin): +class Categorical(NDArrayBackedExtensionArray, PandasObject): """ Represent a categorical variable in classic R / S-plus fashion. diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 5626c3656b0a6..3c648d884fea3 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -56,11 +56,7 @@ from pandas.core.algorithms import checked_add_with_arr, unique1d, value_counts from pandas.core.array_algos.transforms import shift from pandas.core.arrays._mixins import _T, NDArrayBackedExtensionArray -from pandas.core.arrays.base import ( - ArrayFunctionMixin, - ExtensionArray, - ExtensionOpsMixin, -) +from pandas.core.arrays.base import ExtensionArray, ExtensionOpsMixin import pandas.core.common as com from pandas.core.construction import array, extract_array from pandas.core.indexers import check_array_indexer @@ -444,7 +440,7 @@ def _with_freq(self, freq): class DatetimeLikeArrayMixin( - ExtensionOpsMixin, AttributesMixin, NDArrayBackedExtensionArray, ArrayFunctionMixin + ExtensionOpsMixin, AttributesMixin, NDArrayBackedExtensionArray ): """ Shared Base/Mixin class for DatetimeArray, TimedeltaArray, PeriodArray diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index d808668da16b3..ac501a8afbe09 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -13,7 +13,7 @@ from pandas import compat from pandas.core import ops -from pandas.core.arrays import ArrayFunctionMixin, IntegerArray, PandasArray +from pandas.core.arrays import IntegerArray, PandasArray from pandas.core.arrays.integer import _IntegerDtype from pandas.core.construction import extract_array from pandas.core.indexers import check_array_indexer @@ -98,7 +98,7 @@ def __from_arrow__( return StringArray._concat_same_type(results) -class StringArray(PandasArray, ArrayFunctionMixin): +class StringArray(PandasArray): """ Extension array for string data. diff --git a/pandas/tests/extension/test_numpy.py b/pandas/tests/extension/test_numpy.py index 78000c0252375..bf9be949c5cd6 100644 --- a/pandas/tests/extension/test_numpy.py +++ b/pandas/tests/extension/test_numpy.py @@ -502,3 +502,8 @@ def test_setitem_loc_iloc_slice(self, data): @skip_nested class TestParsing(BaseNumPyTests, base.BaseParsingTests): pass + + +@skip_nested +class TestNumpyArrayFunctions(base.BaseNumpyArrayFunctionTests): + pass From d7f4550829b52181f6d89e11b2960e51f24c9908 Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Sat, 4 Jul 2020 19:00:27 +0100 Subject: [PATCH 18/21] reduce diff --- pandas/core/arrays/base.py | 45 +++----------------------------------- 1 file changed, 3 insertions(+), 42 deletions(-) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 13d681af74c06..8230fbdd04870 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -56,15 +56,10 @@ def __array_function__(self, func, types, args, kwargs): # dict mapping. exclude_list = {"unique"} ea_func = getattr(type(self), func.__name__, None) + if not callable(ea_func): + ea_func = None if ea_func is None or ea_func.__name__ in exclude_list: - # Need to convert EAs to numpy.ndarray so we can call the NumPy - # function again and it gets the chance to dispatch to the - # right implementation. - args = tuple( - arg.to_numpy() if isinstance(arg, ExtensionArray) else arg - for arg in args - ) - return func(*args, **kwargs) + return func.__wrapped__(*args, **kwargs) return ea_func(*args, **kwargs) @@ -1317,37 +1312,3 @@ def _create_arithmetic_method(cls, op): @classmethod def _create_comparison_method(cls, op): return cls._create_method(op, coerce_to_dtype=False, result_dtype=bool) - - -@implements(np.ndim) -def ndim(array: ExtensionArray) -> int: - """ - Return the number of dimensions of an array. - """ - return array.ndim - - -@implements(np.vstack) -def vstack(to_stack: Sequence[ArrayLike]) -> np.ndarray: - """ - Stack arrays in sequence vertically (row wise). - """ - to_stack = tuple( - arr.to_numpy() if isinstance(arr, ExtensionArray) else arr for arr in to_stack - ) - return np.vstack(to_stack) - - -@implements(np.putmask) -def putmask(a: ArrayLike, mask: ArrayLike, values: ArrayLike) -> None: - """ - Changes elements of an array based on conditional and input values. - """ - # TODO: refactor Index.putmask to not rely on this behaviour for IntervalArray - if isinstance(a, ExtensionArray): - raise TypeError( - f"putmask() argument 1 must be numpy.ndarray, not {type(a).__name__}" - ) - mask = mask.to_numpy() if isinstance(mask, ExtensionArray) else mask - values = values.to_numpy() if isinstance(values, ExtensionArray) else values - return np.putmask(a, mask, values) From 577010fde177c0cb82961c55360f06669ad4de07 Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Sat, 4 Jul 2020 19:13:34 +0100 Subject: [PATCH 19/21] refactor condition --- pandas/core/arrays/base.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 8230fbdd04870..487e4ee626d5b 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -56,9 +56,11 @@ def __array_function__(self, func, types, args, kwargs): # dict mapping. exclude_list = {"unique"} ea_func = getattr(type(self), func.__name__, None) - if not callable(ea_func): - ea_func = None - if ea_func is None or ea_func.__name__ in exclude_list: + if ( + ea_func is None + or not callable(ea_func) + or ea_func.__name__ in exclude_list + ): return func.__wrapped__(*args, **kwargs) return ea_func(*args, **kwargs) From 3843481e103b8f8cb7be8ed5bacbd10686518dbd Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Sun, 13 Sep 2020 19:49:36 +0100 Subject: [PATCH 20/21] amin, amax to NDArrayBackedExtensionArray --- pandas/core/arrays/_mixins.py | 16 ++++++++++++++++ pandas/core/arrays/categorical.py | 4 ---- pandas/core/arrays/datetimelike.py | 4 ---- pandas/core/arrays/numpy_.py | 4 ---- 4 files changed, 16 insertions(+), 12 deletions(-) diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py index 65adcf6c5a0b4..209220eb528c6 100644 --- a/pandas/core/arrays/_mixins.py +++ b/pandas/core/arrays/_mixins.py @@ -156,3 +156,19 @@ def _validate_shift_value(self, fill_value): # TODO: after deprecation in datetimelikearraymixin is enforced, # we can remove this and ust validate_fill_value directly return self._validate_fill_value(fill_value) + + # ------------------------------------------------------------------------ + # __array_function__ compat + # ------------------------------------------------------------------------ + + def min(self, skipna: bool = True, **kwargs): + raise AbstractMethodError(self) + + def amin(self, skipna: bool = True, **kwargs): + return self.min(skipna=skipna, **kwargs) + + def max(self, skipna: bool = True, **kwargs): + raise AbstractMethodError(self) + + def amax(self, skipna: bool = True, **kwargs): + return self.max(skipna=skipna, **kwargs) diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index 4fe176d4a1967..4fa6b73932aa4 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -2040,8 +2040,6 @@ def min(self, skipna=True, **kwargs): pointer = self._codes.min() return self.categories[pointer] - amin = min - @deprecate_kwarg(old_arg_name="numeric_only", new_arg_name="skipna") def max(self, skipna=True, **kwargs): """ @@ -2078,8 +2076,6 @@ def max(self, skipna=True, **kwargs): pointer = self._codes.max() return self.categories[pointer] - amax = max - def mode(self, dropna=True): """ Returns the mode(s) of the Categorical. diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index d45ade74baa21..6f0e2a6a598fc 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -1550,8 +1550,6 @@ def min(self, axis=None, skipna=True, *args, **kwargs): return NaT return self._box_func(result) - amin = min - def max(self, axis=None, skipna=True, *args, **kwargs): """ Return the maximum value of the Array or maximum along @@ -1584,8 +1582,6 @@ def max(self, axis=None, skipna=True, *args, **kwargs): # Don't have to worry about NA `result`, since no NA went in. return self._box_func(result) - amax = max - def mean(self, skipna=True): """ Return the mean value of the Array. diff --git a/pandas/core/arrays/numpy_.py b/pandas/core/arrays/numpy_.py index 95f09e92305b3..d3fa87d5ea7ff 100644 --- a/pandas/core/arrays/numpy_.py +++ b/pandas/core/arrays/numpy_.py @@ -341,8 +341,6 @@ def min(self, skipna: bool = True, **kwargs) -> Scalar: ) return result - amin = min - def max(self, skipna: bool = True, **kwargs) -> Scalar: nv.validate_max((), kwargs) result = masked_reductions.max( @@ -350,8 +348,6 @@ def max(self, skipna: bool = True, **kwargs) -> Scalar: ) return result - amax = max - def sum(self, axis=None, skipna=True, min_count=0, **kwargs) -> Scalar: nv.validate_sum((), kwargs) return nanops.nansum( From 7b9fab5f698eb446e6d352df20eeb9c3433f2cd6 Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Mon, 14 Sep 2020 17:34:45 +0100 Subject: [PATCH 21/21] concat_compat raises numpy.AxisError with axis=1 --- pandas/core/dtypes/concat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/core/dtypes/concat.py b/pandas/core/dtypes/concat.py index 3dd2f06c18877..010058c9940aa 100644 --- a/pandas/core/dtypes/concat.py +++ b/pandas/core/dtypes/concat.py @@ -147,7 +147,7 @@ def is_nonempty(x) -> bool: target_dtype = find_common_type([x.dtype for x in to_concat]) to_concat = [_cast_to_common_type(arr, target_dtype) for arr in to_concat] - if isinstance(to_concat[0], ExtensionArray): + if isinstance(to_concat[0], ExtensionArray) and axis == 0: cls = type(to_concat[0]) return cls._concat_same_type(to_concat) else: @@ -155,7 +155,7 @@ def is_nonempty(x) -> bool: arr.to_numpy() if isinstance(arr, ExtensionArray) else arr for arr in to_concat ] - return np.concatenate(to_concat) + return np.concatenate(to_concat, axis=axis) elif _contains_datetime or "timedelta" in typs: return _concat_datetime(to_concat, axis=axis, typs=typs)