[WIP] Add basic ExtensionIndex class #23223

Closed
Changes from 3 commits
3 changes: 2 additions & 1 deletion pandas/core/dtypes/generic.py
@@ -39,7 +39,8 @@ def _check(cls, inst):
"float64index", "uint64index",
"multiindex", "datetimeindex",
"timedeltaindex", "periodindex",
"categoricalindex", "intervalindex"))
"categoricalindex", "intervalindex",
"extensionindex"))

ABCSeries = create_pandas_abc_type("ABCSeries", "_typ", ("series", ))
ABCDataFrame = create_pandas_abc_type("ABCDataFrame", "_typ", ("dataframe", ))
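Note: the ABC types defined in this module check only the _typ attribute, so adding "extensionindex" to the tuple above is all that is needed for the new index class to pass isinstance checks against the index ABC (ABCIndexClass). A minimal illustration; the Dummy class below is hypothetical and only for demonstration:

from pandas.core.dtypes.generic import ABCIndexClass

class Dummy(object):
    # any object whose _typ matches one of the registered strings
    # is treated as an index by the ABC instance check
    _typ = 'extensionindex'

isinstance(Dummy(), ABCIndexClass)  # True once "extensionindex" is registered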
6 changes: 6 additions & 0 deletions pandas/core/indexes/base.py
@@ -313,6 +313,10 @@ def __new__(cls, data=None, dtype=None, copy=False, name=None,
else:
return result

elif is_extension_array_dtype(data):
from pandas.core.indexes.extension import ExtensionIndex
return ExtensionIndex(data, name=name)

# extension dtype
elif is_extension_array_dtype(data) or is_extension_array_dtype(dtype):
data = np.asarray(data)
@@ -2408,6 +2412,8 @@ def to_native_types(self, slicer=None, **kwargs):
values = values[slicer]
return values._format_native_types(**kwargs)

# TODO(EA) potentially overwrite for better implementation
# or use _formatting_values
def _format_native_types(self, na_rep='', quoting=None, **kwargs):
""" actually format my specific types """
mask = isna(self)
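Note: with the new elif branch added to Index.__new__ above, constructing a pd.Index from an ExtensionArray is routed to the new ExtensionIndex class instead of being coerced to an object-dtype ndarray. A rough usage sketch, assuming a development build that includes this patch and the nullable integer array exposed as pandas.core.arrays.integer_array:

import pandas as pd
from pandas.core.arrays import integer_array
from pandas.core.indexes.extension import ExtensionIndex

arr = integer_array([1, 2, None])
idx = pd.Index(arr)              # dispatched by the new branch in __new__
isinstance(idx, ExtensionIndex)  # True
idx.dtype                        # the array's extension dtype (Int64Dtype)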
149 changes: 149 additions & 0 deletions pandas/core/indexes/extension.py
@@ -0,0 +1,149 @@
import numpy as np
from pandas._libs import index as libindex

# from pandas._libs import (lib, index as libindex, tslibs,
# algos as libalgos, join as libjoin,
# Timedelta)

from pandas.compat.numpy import function as nv

from pandas.core.arrays import ExtensionArray
from pandas.core.dtypes.common import (
ensure_platform_int,
is_integer_dtype, is_float_dtype)

from pandas.util._decorators import (
Appender, cache_readonly)

from .base import Index


# _index_doc_kwargs = dict(ibase._index_doc_kwargs)
# _index_doc_kwargs.update(
# dict(klass='IntervalIndex',
# target_klass='IntervalIndex or list of Intervals',
# name=textwrap.dedent("""\
# name : object, optional
# to be stored in the index.
# """),
# ))


class ExtensionIndex(Index):
"""
Index class that holds an ExtensionArray.

"""
_typ = 'extensionindex'
_comparables = ['name']
_attributes = ['name']

_can_hold_na = True

@property
def _is_numeric_dtype(self):
return self.dtype._is_numeric

# TODO
# # would we like our indexing holder to defer to us
# _defer_to_indexing = False

# # prioritize current class for _shallow_copy_with_infer,
# # used to infer integers as datetime-likes
# _infer_as_myclass = False

def __new__(cls, *args, **kwargs):
return object.__new__(cls)

def __init__(self, array, name=None, copy=False, **kwargs):
# needs to accept and ignore kwargs eg for freq passed in
# Index._shallow_copy_with_infer

if isinstance(array, ExtensionIndex):
array = array._data

if not isinstance(array, ExtensionArray):
raise TypeError("data must be an instance of ExtensionArray")
if copy:
array = array.copy()
self._data = array
self.name = name

def __len__(self):
"""
return the length of the Index
"""
return len(self._data)

@property
def size(self):
# EA does not have .size
return len(self._data)

def __array__(self, dtype=None):
""" the array interface, return my values """
return np.array(self._data)

@cache_readonly
def dtype(self):
""" return the dtype object of the underlying data """
return self._values.dtype

@cache_readonly
def dtype_str(self):
""" return the dtype str of the underlying data """
return str(self.dtype)

@property
def _values(self):
return self._data

@property
def values(self):
""" return the underlying data as an ndarray """
[Review comment, Contributor] ndarray -> extension array.
return self._values

@cache_readonly
def _isnan(self):
""" return if each value is nan"""
return self._values.isna()

@cache_readonly
def _engine_type(self):
values, na_value = self._values._values_for_factorize()
if is_integer_dtype(values):
return libindex.Int64Engine
elif is_float_dtype(values):
return libindex.Float64Engine
# TODO add more
else:
return libindex.ObjectEngine

@cache_readonly
def _engine(self):
# property, for now, slow to look up
values, na_value = self._values._values_for_factorize()
return self._engine_type(lambda: values, len(self))

def _format_with_header(self, header, **kwargs):
return header + list(self._format_native_types(**kwargs))

@Appender(Index.take.__doc__)
def take(self, indices, axis=0, allow_fill=True, fill_value=None,
**kwargs):
if kwargs:
nv.validate_take(tuple(), kwargs)
indices = ensure_platform_int(indices)

result = self._data.take(indices, allow_fill=allow_fill,
fill_value=fill_value)
attributes = self._get_attributes_dict()
return self._simple_new(result, **attributes)

def __getitem__(self, value):
result = self._data[value]
if isinstance(result, self._data.__class__):
return self._shallow_copy(result)
else:
# scalar
return result
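Note: a short sketch of the behaviour the methods above aim for, mirroring the tests added below and using the decimal test array that ships with the extension test suite; the values are illustrative:

from decimal import Decimal
from pandas.tests.extension.decimal import DecimalArray
from pandas.core.indexes.extension import ExtensionIndex

arr = DecimalArray([Decimal('1.1'), Decimal('2.2'), Decimal('3.3')])
idx = ExtensionIndex(arr, name='x')

len(idx)          # 3, via len(self._data)
idx.dtype         # the DecimalArray's extension dtype
idx.take([2, 0])  # new ExtensionIndex, delegating to ExtensionArray.take
idx[0]            # Decimal('1.1'), the scalar path in __getitem__
idx[[0, 1]]       # ExtensionIndex of the first two values, the array path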
2 changes: 2 additions & 0 deletions pandas/tests/extension/base/__init__.py
@@ -45,6 +45,8 @@ class TestMyDtype(BaseDtypeTests):
from .dtype import BaseDtypeTests # noqa
from .getitem import BaseGetitemTests # noqa
from .groupby import BaseGroupbyTests # noqa
from .index import BaseIndexTests # noqa

from .interface import BaseInterfaceTests # noqa
from .methods import BaseMethodsTests # noqa
from .ops import BaseArithmeticOpsTests, BaseComparisonOpsTests, BaseOpsUtil # noqa
96 changes: 96 additions & 0 deletions pandas/tests/extension/base/index.py
@@ -0,0 +1,96 @@
import pytest
import numpy as np

import pandas as pd
import pandas.util.testing as tm
from pandas.core.indexes.extension import ExtensionIndex

from .base import BaseExtensionTests


class BaseIndexTests(BaseExtensionTests):
"""Tests for ExtensionIndex."""

def test_constructor(self, data):
result = ExtensionIndex(data, name='test')
assert result.name == 'test'
self.assert_extension_array_equal(data, result._values)

def test_series_constructor(self, data):
result = pd.Series(range(len(data)), index=data)
assert isinstance(result.index, ExtensionIndex)

def test_asarray(self, data):
idx = ExtensionIndex(data)
tm.assert_numpy_array_equal(np.array(idx), np.array(data))

def test_repr(self, data):
idx = ExtensionIndex(data, name='test')
repr(idx)
s = pd.Series(range(len(data)), index=data)
repr(s)

def test_indexing_scalar(self, data):
s = pd.Series(range(len(data)), index=data)
label = data[1]
assert s[label] == 1
assert s.iloc[1] == 1
assert s.loc[label] == 1

def test_indexing_list(self, data):
s = pd.Series(range(len(data)), index=data)
labels = [data[1], data[3]]
exp = pd.Series([1, 3], index=data[[1, 3]])
self.assert_series_equal(s[labels], exp)
self.assert_series_equal(s.loc[labels], exp)
self.assert_series_equal(s.iloc[[1, 3]], exp)

def test_contains(self, data_missing, data_for_sorting, na_value):
idx = ExtensionIndex(data_missing)
assert data_missing[0] in idx
assert data_missing[1] in idx
assert na_value in idx
assert '__random' not in idx
idx = ExtensionIndex(data_for_sorting)
assert na_value not in idx

def test_na(self, data_missing):
idx = ExtensionIndex(data_missing)
result = idx.isna()
expected = np.array([True, False], dtype=bool)
tm.assert_numpy_array_equal(result, expected)
result = idx.notna()
tm.assert_numpy_array_equal(result, ~expected)
assert idx.hasnans  # is True

def test_monotonic(self, data_for_sorting):
data = data_for_sorting
idx = ExtensionIndex(data)
assert idx.is_monotonic_increasing is False
assert idx.is_monotonic_decreasing is False

idx = ExtensionIndex(data[[2, 0, 1]])
assert idx.is_monotonic_increasing is True
assert idx.is_monotonic_decreasing is False

idx = ExtensionIndex(data[[1, 0, 2]])
assert idx.is_monotonic_increasing is False
assert idx.is_monotonic_decreasing is True

def test_is_unique(self, data_for_sorting, data_for_grouping):
idx = ExtensionIndex(data_for_sorting)
assert idx.is_unique is True

idx = ExtensionIndex(data_for_grouping)
assert idx.is_unique is False

def test_take(self, data):
idx = ExtensionIndex(data)
expected = ExtensionIndex(data.take([0, 2, 3]))
result = idx.take([0, 2, 3])
tm.assert_index_equal(result, expected)

def test_getitem(self, data):
idx = ExtensionIndex(data)
assert idx[0] == data[0]
tm.assert_index_equal(idx[[0, 1]], ExtensionIndex(data[[0, 1]]))
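Note: like the other base extension tests, these are opted into by a concrete array's test module subclassing BaseIndexTests and providing the standard fixtures (data, data_missing, data_for_sorting, data_for_grouping, na_value). A rough sketch for a hypothetical third-party array; MyArray is a placeholder:

import pytest
from pandas.tests.extension import base


@pytest.fixture
def data():
    # length-100 instance of the custom ExtensionArray, as the base
    # fixtures expect; MyArray is a hypothetical placeholder
    return MyArray(range(100))


class TestIndex(base.BaseIndexTests):
    pass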
4 changes: 4 additions & 0 deletions pandas/tests/extension/decimal/test_decimal.py
@@ -304,6 +304,10 @@ def test_compare_array(self, data, all_compare_operators):
self._compare_other(s, data, op_name, other)


class TestIndex(base.BaseIndexTests):
pass


class DecimalArrayWithoutFromSequence(DecimalArray):
"""Helper class for testing error handling in _from_sequence."""
def _from_sequence(cls, scalars, dtype=None, copy=False):
4 changes: 4 additions & 0 deletions pandas/tests/extension/test_integer.py
@@ -216,3 +216,7 @@ class TestNumericReduce(base.BaseNumericReduceTests):

class TestBooleanReduce(base.BaseBooleanReduceTests):
pass


class TestIndex(base.BaseIndexTests):
pass
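Note: with these opt-in subclasses in place, the new index tests run as part of the existing extension test modules, e.g. pytest pandas/tests/extension/test_integer.py::TestIndex.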