From 015e4b204e3c1900833134181e03e4a8edf6507e Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Fri, 28 Sep 2018 19:46:54 +0200 Subject: [PATCH 1/4] Add basic ExtensionIndex class --- pandas/core/dtypes/generic.py | 3 +- pandas/core/indexes/base.py | 6 + pandas/core/indexes/extension.py | 196 ++++++++++++++++++ pandas/tests/extension/base/__init__.py | 2 + pandas/tests/extension/base/index.py | 96 +++++++++ .../tests/extension/decimal/test_decimal.py | 4 + pandas/tests/extension/test_integer.py | 4 + 7 files changed, 310 insertions(+), 1 deletion(-) create mode 100644 pandas/core/indexes/extension.py create mode 100644 pandas/tests/extension/base/index.py diff --git a/pandas/core/dtypes/generic.py b/pandas/core/dtypes/generic.py index cb54c94d29205..0b031899aeb5d 100644 --- a/pandas/core/dtypes/generic.py +++ b/pandas/core/dtypes/generic.py @@ -39,7 +39,8 @@ def _check(cls, inst): "float64index", "uint64index", "multiindex", "datetimeindex", "timedeltaindex", "periodindex", - "categoricalindex", "intervalindex")) + "categoricalindex", "intervalindex", + "extensionindex")) ABCSeries = create_pandas_abc_type("ABCSeries", "_typ", ("series", )) ABCDataFrame = create_pandas_abc_type("ABCDataFrame", "_typ", ("dataframe", )) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index af04a846ed787..f648b03b2103a 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -303,6 +303,10 @@ def __new__(cls, data=None, dtype=None, copy=False, name=None, else: return result + elif is_extension_array_dtype(data): + from pandas.core.indexes.extension import ExtensionIndex + return ExtensionIndex(data, name=name) + # extension dtype elif is_extension_array_dtype(data) or is_extension_array_dtype(dtype): data = np.asarray(data) @@ -2400,6 +2404,8 @@ def to_native_types(self, slicer=None, **kwargs): values = values[slicer] return values._format_native_types(**kwargs) + # TODO(EA) potentially overwrite for better implementation + # or use _formatting_values def _format_native_types(self, na_rep='', quoting=None, **kwargs): """ actually format my specific types """ mask = isna(self) diff --git a/pandas/core/indexes/extension.py b/pandas/core/indexes/extension.py new file mode 100644 index 0000000000000..746196ff5cf02 --- /dev/null +++ b/pandas/core/indexes/extension.py @@ -0,0 +1,196 @@ +from datetime import datetime, timedelta +import warnings +import operator +from textwrap import dedent + +import numpy as np +from pandas._libs import (lib, index as libindex, tslibs, + algos as libalgos, join as libjoin, + Timedelta) +from pandas._libs.lib import is_datetime_array + +from pandas.compat import range, u, set_function_name +from pandas.compat.numpy import function as nv +from pandas import compat + +from pandas.core.accessor import CachedAccessor +from pandas.core.arrays import ExtensionArray +from pandas.core.dtypes.generic import ( + ABCSeries, ABCDataFrame, + ABCMultiIndex, + ABCPeriodIndex, ABCTimedeltaIndex, + ABCDateOffset) +from pandas.core.dtypes.missing import isna, array_equivalent +from pandas.core.dtypes.cast import maybe_cast_to_integer_array +from pandas.core.dtypes.common import ( + ensure_int64, + ensure_object, + ensure_categorical, + ensure_platform_int, + is_integer, + is_float, + is_dtype_equal, + is_dtype_union_equal, + is_object_dtype, + is_categorical, + is_categorical_dtype, + is_interval_dtype, + is_period_dtype, + is_bool, + is_bool_dtype, + is_signed_integer_dtype, + is_unsigned_integer_dtype, + is_integer_dtype, is_float_dtype, + is_datetime64_any_dtype, + is_datetime64tz_dtype, + is_timedelta64_dtype, + is_extension_array_dtype, + is_hashable, + is_iterator, is_list_like, + is_scalar) + +from pandas.core.base import PandasObject, IndexOpsMixin +import pandas.core.common as com +from pandas.core import ops +from pandas.util._decorators import ( + Appender, Substitution, cache_readonly) +from pandas.core.indexes.frozen import FrozenList +import pandas.core.indexes.base as ibase +import pandas.core.dtypes.concat as _concat +import pandas.core.missing as missing +import pandas.core.algorithms as algos +import pandas.core.sorting as sorting +from pandas.io.formats.printing import ( + pprint_thing, default_pprint, format_object_summary, format_object_attrs) +from pandas.core.ops import make_invalid_op +from pandas.core.strings import StringMethods + +from .base import Index + + +# _index_doc_kwargs = dict(ibase._index_doc_kwargs) +# _index_doc_kwargs.update( +# dict(klass='IntervalIndex', +# target_klass='IntervalIndex or list of Intervals', +# name=textwrap.dedent("""\ +# name : object, optional +# to be stored in the index. +# """), +# )) + + +class ExtensionIndex(Index): + """ + + """ + _typ = 'extensionindex' + _comparables = ['name'] + _attributes = ['name'] + + _can_hold_na = True + + @property + def _is_numeric_dtype(self): + return self.dtype._is_numeric + + # # would we like our indexing holder to defer to us + # _defer_to_indexing = False + + # # prioritize current class for _shallow_copy_with_infer, + # # used to infer integers as datetime-likes + # _infer_as_myclass = False + + def __new__(cls, *args, **kwargs): + return object.__new__(cls) + + def __init__(self, array, name=None, copy=False, **kwargs): + # needs to accept and ignore kwargs eg for freq passed in Index._shallow_copy_with_infer + + if isinstance(array, ExtensionIndex): + array = array._data + + if not isinstance(array, ExtensionArray): + raise TypeError() + if copy: + array = array.copy() + self._data = array + self.name = name + + def __len__(self): + """ + return the length of the Index + """ + return len(self._data) + + @property + def size(self): + # EA does not have .size + return len(self._data) + + def __array__(self, dtype=None): + """ the array interface, return my values """ + return np.array(self._values) + + @cache_readonly + def dtype(self): + """ return the dtype object of the underlying data """ + return self._values.dtype + + @cache_readonly + def dtype_str(self): + """ return the dtype str of the underlying data """ + return str(self.dtype) + + @property + def _values(self): + return self._data + + @property + def values(self): + """ return the underlying data as an ndarray """ + return self._values + + @cache_readonly + def _isnan(self): + """ return if each value is nan""" + return self._values.isna() + + @cache_readonly + def _engine_type(self): + values, na_value = self._values._values_for_factorize() + if is_integer_dtype(values): + return libindex.Int64Engine + elif is_float_dtype(values): + return libindex.Float64Engine + # TODO add more + else: + return libindex.ObjectEngine + + @cache_readonly + def _engine(self): + # property, for now, slow to look up + values, na_value = self._values._values_for_factorize() + return self._engine_type(lambda: values, len(self)) + + def _format_with_header(self, header, **kwargs): + return header + list(self._format_native_types(**kwargs)) + + @Appender(Index.take.__doc__) + def take(self, indices, axis=0, allow_fill=True, fill_value=None, + **kwargs): + if kwargs: + nv.validate_take(tuple(), kwargs) + indices = ensure_platform_int(indices) + + result = self._data.take(indices, allow_fill=allow_fill, + fill_value=fill_value) + attributes = self._get_attributes_dict() + return self._simple_new(result, **attributes) + + def __getitem__(self, value): + result = self._data[value] + if isinstance(result, self._data.__class__): + return self._shallow_copy(result) + else: + # scalar + return result diff --git a/pandas/tests/extension/base/__init__.py b/pandas/tests/extension/base/__init__.py index b6b81bb941a59..8a0e1829352f8 100644 --- a/pandas/tests/extension/base/__init__.py +++ b/pandas/tests/extension/base/__init__.py @@ -45,6 +45,8 @@ class TestMyDtype(BaseDtypeTests): from .dtype import BaseDtypeTests # noqa from .getitem import BaseGetitemTests # noqa from .groupby import BaseGroupbyTests # noqa +from .index import BaseIndexTests # noqa + from .interface import BaseInterfaceTests # noqa from .methods import BaseMethodsTests # noqa from .ops import BaseArithmeticOpsTests, BaseComparisonOpsTests, BaseOpsUtil # noqa diff --git a/pandas/tests/extension/base/index.py b/pandas/tests/extension/base/index.py new file mode 100644 index 0000000000000..28daa0f3c8eed --- /dev/null +++ b/pandas/tests/extension/base/index.py @@ -0,0 +1,96 @@ +import pytest +import numpy as np + +import pandas as pd +import pandas.util.testing as tm +from pandas.core.indexes.extension import ExtensionIndex + +from .base import BaseExtensionTests + + +class BaseIndexTests(BaseExtensionTests): + """Tests for ExtensionIndex.""" + + def test_constructor(self, data): + result = ExtensionIndex(data, name='test') + assert result.name == 'test' + self.assert_extension_array_equal(data, result._values) + + def test_series_constructor(self, data): + result = pd.Series(range(len(data)), index=data) + assert isinstance(result.index, ExtensionIndex) + + def test_asarray(self, data): + idx = ExtensionIndex(data) + tm.assert_numpy_array_equal(np.array(idx), np.array(data)) + + def test_repr(self, data): + idx = ExtensionIndex(data, name='test') + repr(idx) + s = pd.Series(range(len(data)), index=data) + repr(s) + + def test_indexing_scalar(self, data): + s = pd.Series(range(len(data)), index=data) + label = data[1] + assert s[label] == 1 + assert s.iloc[1] == 1 + assert s.loc[label] == 1 + + def test_indexing_list(self, data): + s = pd.Series(range(len(data)), index=data) + labels = [data[1], data[3]] + exp = pd.Series([1, 3], index=data[[1, 3]]) + self.assert_series_equal(s[labels], exp) + self.assert_series_equal(s.loc[labels], exp) + self.assert_series_equal(s.iloc[[1, 3]], exp) + + def test_contains(self, data_missing, data_for_sorting, na_value): + idx = ExtensionIndex(data_missing) + assert data_missing[0] in idx + assert data_missing[1] in idx + assert na_value in idx + assert '__random' not in idx + idx = ExtensionIndex(data_for_sorting) + assert na_value not in idx + + def test_na(self, data_missing): + idx = ExtensionIndex(data_missing) + result = idx.isna() + expected = np.array([True, False], dtype=bool) + tm.assert_numpy_array_equal(result, expected) + result = idx.notna() + tm.assert_numpy_array_equal(result, ~expected) + assert idx.hasnans #is True + + def test_monotonic(self, data_for_sorting): + data = data_for_sorting + idx = ExtensionIndex(data) + assert idx.is_monotonic_increasing is False + assert idx.is_monotonic_decreasing is False + + idx = ExtensionIndex(data[[2, 0, 1]]) + assert idx.is_monotonic_increasing is True + assert idx.is_monotonic_decreasing is False + + idx = ExtensionIndex(data[[1, 0, 2]]) + assert idx.is_monotonic_increasing is False + assert idx.is_monotonic_decreasing is True + + def test_is_unique(self, data_for_sorting, data_for_grouping): + idx = ExtensionIndex(data_for_sorting) + assert idx.is_unique is True + + idx = ExtensionIndex(data_for_grouping) + assert idx.is_unique is False + + def test_take(self, data): + idx = ExtensionIndex(data) + expected = ExtensionIndex(data.take([0, 2, 3])) + result = idx.take([0, 2, 3]) + tm.assert_index_equal(result, expected) + + def test_getitem(self, data): + idx = ExtensionIndex(data) + assert idx[0] == data[0] + tm.assert_index_equal(idx[[0, 1]], ExtensionIndex(data[[0, 1]])) diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index 93b8ea786ef5b..d477b3c643644 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -275,3 +275,7 @@ def test_compare_array(self, data, all_compare_operators): other = pd.Series(data) * [decimal.Decimal(pow(2.0, i)) for i in alter] self._compare_other(s, data, op_name, other) + + +class TestIndex(base.BaseIndexTests): + pass diff --git a/pandas/tests/extension/test_integer.py b/pandas/tests/extension/test_integer.py index 7aa33006dadda..64880d7de48a7 100644 --- a/pandas/tests/extension/test_integer.py +++ b/pandas/tests/extension/test_integer.py @@ -216,3 +216,7 @@ def test_groupby_extension_no_sort(self, data_for_grouping): def test_groupby_extension_agg(self, as_index, data_for_grouping): super(TestGroupby, self).test_groupby_extension_agg( as_index, data_for_grouping) + + +class TestIndex(base.BaseIndexTests): + pass From 9e282c96fbf9cd0f0c5df14ecbd1e23fcdf4ba0d Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Thu, 18 Oct 2018 16:08:43 +0200 Subject: [PATCH 2/4] clean-up --- pandas/core/indexes/extension.py | 83 +++++++------------------------- 1 file changed, 18 insertions(+), 65 deletions(-) diff --git a/pandas/core/indexes/extension.py b/pandas/core/indexes/extension.py index 746196ff5cf02..b86b59c829b25 100644 --- a/pandas/core/indexes/extension.py +++ b/pandas/core/indexes/extension.py @@ -1,69 +1,19 @@ -from datetime import datetime, timedelta -import warnings -import operator -from textwrap import dedent - import numpy as np -from pandas._libs import (lib, index as libindex, tslibs, - algos as libalgos, join as libjoin, - Timedelta) -from pandas._libs.lib import is_datetime_array +from pandas._libs import index as libindex + +# from pandas._libs import (lib, index as libindex, tslibs, +# algos as libalgos, join as libjoin, +# Timedelta) -from pandas.compat import range, u, set_function_name from pandas.compat.numpy import function as nv -from pandas import compat -from pandas.core.accessor import CachedAccessor from pandas.core.arrays import ExtensionArray -from pandas.core.dtypes.generic import ( - ABCSeries, ABCDataFrame, - ABCMultiIndex, - ABCPeriodIndex, ABCTimedeltaIndex, - ABCDateOffset) -from pandas.core.dtypes.missing import isna, array_equivalent -from pandas.core.dtypes.cast import maybe_cast_to_integer_array from pandas.core.dtypes.common import ( - ensure_int64, - ensure_object, - ensure_categorical, ensure_platform_int, - is_integer, - is_float, - is_dtype_equal, - is_dtype_union_equal, - is_object_dtype, - is_categorical, - is_categorical_dtype, - is_interval_dtype, - is_period_dtype, - is_bool, - is_bool_dtype, - is_signed_integer_dtype, - is_unsigned_integer_dtype, - is_integer_dtype, is_float_dtype, - is_datetime64_any_dtype, - is_datetime64tz_dtype, - is_timedelta64_dtype, - is_extension_array_dtype, - is_hashable, - is_iterator, is_list_like, - is_scalar) - -from pandas.core.base import PandasObject, IndexOpsMixin -import pandas.core.common as com -from pandas.core import ops + is_integer_dtype, is_float_dtype) + from pandas.util._decorators import ( - Appender, Substitution, cache_readonly) -from pandas.core.indexes.frozen import FrozenList -import pandas.core.indexes.base as ibase -import pandas.core.dtypes.concat as _concat -import pandas.core.missing as missing -import pandas.core.algorithms as algos -import pandas.core.sorting as sorting -from pandas.io.formats.printing import ( - pprint_thing, default_pprint, format_object_summary, format_object_attrs) -from pandas.core.ops import make_invalid_op -from pandas.core.strings import StringMethods + Appender, cache_readonly) from .base import Index @@ -81,7 +31,8 @@ class ExtensionIndex(Index): """ - + Index class that holds an ExtensionArray. + """ _typ = 'extensionindex' _comparables = ['name'] @@ -92,7 +43,8 @@ class ExtensionIndex(Index): @property def _is_numeric_dtype(self): return self.dtype._is_numeric - + + # TODO # # would we like our indexing holder to defer to us # _defer_to_indexing = False @@ -104,8 +56,9 @@ def __new__(cls, *args, **kwargs): return object.__new__(cls) def __init__(self, array, name=None, copy=False, **kwargs): - # needs to accept and ignore kwargs eg for freq passed in Index._shallow_copy_with_infer - + # needs to accept and ignore kwargs eg for freq passed in + # Index._shallow_copy_with_infer + if isinstance(array, ExtensionIndex): array = array._data @@ -121,7 +74,7 @@ def __len__(self): return the length of the Index """ return len(self._data) - + @property def size(self): # EA does not have .size @@ -129,7 +82,7 @@ def size(self): def __array__(self, dtype=None): """ the array interface, return my values """ - return np.array(self._values) + return np.array(self._data) @cache_readonly def dtype(self): @@ -165,7 +118,7 @@ def _engine_type(self): # TODO add more else: return libindex.ObjectEngine - + @cache_readonly def _engine(self): # property, for now, slow to look up From 6c1d798d205c40c520441fdcbe05907f7bb6c208 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Fri, 19 Oct 2018 15:20:43 +0200 Subject: [PATCH 3/4] more robust constructor + add tests --- pandas/core/indexes/base.py | 23 +++----- pandas/core/indexes/extension.py | 38 +++++++++--- pandas/tests/indexes/test_extension.py | 81 ++++++++++++++++++++++++++ 3 files changed, 117 insertions(+), 25 deletions(-) create mode 100644 pandas/tests/indexes/test_extension.py diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index db8316bffcc3a..8609bea17f7a9 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -317,23 +317,14 @@ def __new__(cls, data=None, dtype=None, copy=False, name=None, else: return result - elif is_extension_array_dtype(data): + elif (is_extension_array_dtype(data) + or is_extension_array_dtype(dtype)): + if dtype is not None and is_object_dtype(dtype): + data = np.asarray(data) + return Index(data, dtype=object, copy=copy, name=name, + **kwargs) from pandas.core.indexes.extension import ExtensionIndex - return ExtensionIndex(data, name=name) - - # extension dtype - elif is_extension_array_dtype(data) or is_extension_array_dtype(dtype): - data = np.asarray(data) - if not (dtype is None or is_object_dtype(dtype)): - - # coerce to the provided dtype - data = dtype.construct_array_type()._from_sequence( - data, dtype=dtype, copy=False) - - # coerce to the object dtype - data = data.astype(object) - return Index(data, dtype=object, copy=copy, name=name, - **kwargs) + return ExtensionIndex(data, dtype=dtype, name=name) # index-like elif isinstance(data, (np.ndarray, Index, ABCSeries)): diff --git a/pandas/core/indexes/extension.py b/pandas/core/indexes/extension.py index b86b59c829b25..c90dec42de3c5 100644 --- a/pandas/core/indexes/extension.py +++ b/pandas/core/indexes/extension.py @@ -9,9 +9,15 @@ from pandas.core.arrays import ExtensionArray from pandas.core.dtypes.common import ( + pandas_dtype, ensure_platform_int, - is_integer_dtype, is_float_dtype) - + is_dtype_equal, + is_integer_dtype, + is_float_dtype, + is_extension_array_dtype) +from pandas.core.dtypes.generic import ( + ABCSeries, ABCIndex +) from pandas.util._decorators import ( Appender, cache_readonly) @@ -55,18 +61,32 @@ def _is_numeric_dtype(self): def __new__(cls, *args, **kwargs): return object.__new__(cls) - def __init__(self, array, name=None, copy=False, **kwargs): + def __init__(self, data, dtype=None, name=None, copy=False, **kwargs): # needs to accept and ignore kwargs eg for freq passed in # Index._shallow_copy_with_infer - if isinstance(array, ExtensionIndex): - array = array._data + # unbox containers that can contain ExtensionArray + if isinstance(data, (ABCSeries, ABCIndex)): + data = data._values + + # check dtype and coerce data to dtype if needed + if dtype is not None: + dtype = pandas_dtype(dtype) + if not is_extension_array_dtype(dtype): + raise ValueError( + "The passed dtype should be an ExtensionDtype") + if not is_dtype_equal(getattr(data, 'dtype', None), dtype): + data = dtype.construct_array_type()._from_sequence( + data, dtype=dtype, copy=False) + + if not isinstance(data, ExtensionArray): + raise ValueError("passed data should be an ExtensionArray, or the " + "passed dtype should be an ExtensionDtype") - if not isinstance(array, ExtensionArray): - raise TypeError() if copy: - array = array.copy() - self._data = array + data = data.copy() + + self._data = data self.name = name def __len__(self): diff --git a/pandas/tests/indexes/test_extension.py b/pandas/tests/indexes/test_extension.py new file mode 100644 index 0000000000000..1da6819a03b11 --- /dev/null +++ b/pandas/tests/indexes/test_extension.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- + +import pytest + +import pandas.util.testing as tm +from pandas.core.indexes.api import Index +from .common import Base + +import numpy as np + +from pandas.util.testing import ( + assert_extension_array_equal, assert_index_equal) + +from pandas.core.arrays import integer_array +from pandas.core.indexes.extension import ExtensionIndex + + +@pytest.fixture +def data(): + return integer_array([1, 2, 3, 4]) + + +def test_constructor(data): + result = ExtensionIndex(data, name='test') + assert result.name == 'test' + assert isinstance(result, ExtensionIndex) + assert_extension_array_equal(data, result._values) + + expected = ExtensionIndex(data, name='test') + # data and passed dtype match + result = ExtensionIndex(data, dtype=data.dtype, name='test') + assert_index_equal(result, expected) + # data is converted to passed dtype + result = ExtensionIndex(np.array(data), dtype=data.dtype, name='test') + assert_index_equal(result, expected) + # EA is converted to passed dtype + expected = ExtensionIndex(integer_array(data, dtype='Int32'), name='test') + result = ExtensionIndex(data, dtype=expected.dtype, name='test') + assert_index_equal(result, expected) + + # no ExtensionDtype passed + with pytest.raises(ValueError): + ExtensionIndex(data, dtype='int64', name='test') + + with pytest.raises(ValueError): + ExtensionIndex(data, dtype=object, name='test') + + # no ExtensionArray passed + with pytest.raises(ValueError): + ExtensionIndex(np.array(data), name='test') + + +def test_default_index_constructor(data): + result = Index(data, name='test') + expected = ExtensionIndex(data, name='test') + assert_index_equal(result, expected) + + result = Index(data, dtype=data.dtype, name='test') + assert_index_equal(result, expected) + + result = Index(np.array(data), dtype=data.dtype, name='test') + assert_index_equal(result, expected) + + result = Index(data, dtype=object, name='test') + expected = Index(np.array(data), dtype=object, name='test') + assert_index_equal(result, expected) + + +# class TestExtensionIndex(Base): +# _holder = ExtensionIndex + +# def setup_method(self, method): +# self.indices = dict( +# extIndex=ExtensionIndex(np.arange(100), dtype='Int64')) +# self.setup_indices() + +# # def create_index(self): +# # if categories is None: +# # categories = list('cab') +# # return CategoricalIndex( +# # list('aabbca'), categories=categories, ordered=ordered) From 00d4a167d8738532348ee06e8b763ddcc58b0f66 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Fri, 19 Oct 2018 15:40:48 +0200 Subject: [PATCH 4/4] add common tests --- pandas/core/indexes/extension.py | 7 +++++-- pandas/tests/indexes/test_extension.py | 27 +++++++++++++++----------- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/pandas/core/indexes/extension.py b/pandas/core/indexes/extension.py index c90dec42de3c5..37f2c88422d55 100644 --- a/pandas/core/indexes/extension.py +++ b/pandas/core/indexes/extension.py @@ -16,7 +16,7 @@ is_float_dtype, is_extension_array_dtype) from pandas.core.dtypes.generic import ( - ABCSeries, ABCIndex + ABCSeries, ABCIndexClass ) from pandas.util._decorators import ( Appender, cache_readonly) @@ -65,8 +65,11 @@ def __init__(self, data, dtype=None, name=None, copy=False, **kwargs): # needs to accept and ignore kwargs eg for freq passed in # Index._shallow_copy_with_infer + if name is None and hasattr(data, 'name'): + name = data.name + # unbox containers that can contain ExtensionArray - if isinstance(data, (ABCSeries, ABCIndex)): + if isinstance(data, (ABCSeries, ABCIndexClass)): data = data._values # check dtype and coerce data to dtype if needed diff --git a/pandas/tests/indexes/test_extension.py b/pandas/tests/indexes/test_extension.py index 1da6819a03b11..4d9977aef5f2a 100644 --- a/pandas/tests/indexes/test_extension.py +++ b/pandas/tests/indexes/test_extension.py @@ -13,6 +13,7 @@ from pandas.core.arrays import integer_array from pandas.core.indexes.extension import ExtensionIndex +from pandas.tests.extension.decimal import to_decimal, make_data @pytest.fixture @@ -66,16 +67,20 @@ def test_default_index_constructor(data): assert_index_equal(result, expected) -# class TestExtensionIndex(Base): -# _holder = ExtensionIndex +class TestExtensionIndex(Base): + _holder = ExtensionIndex + _compat_props = ['shape', 'ndim', 'nbytes'] # 'size' is not in EA -# def setup_method(self, method): -# self.indices = dict( -# extIndex=ExtensionIndex(np.arange(100), dtype='Int64')) -# self.setup_indices() + def setup_method(self, method): + self.indices = dict( + intIndex=ExtensionIndex(np.arange(100), dtype='Int64'), + decInd=ExtensionIndex(to_decimal(make_data()))) + self.setup_indices() -# # def create_index(self): -# # if categories is None: -# # categories = list('cab') -# # return CategoricalIndex( -# # list('aabbca'), categories=categories, ordered=ordered) + def create_index(self): + return ExtensionIndex(integer_array([0, 1, 2, 3])) + + def test_logical_compat(self): + idx = self.create_index() + assert idx.all() == np.array(idx).all() + assert idx.any() == np.array(idx).any()