Skip to content

ENH: Add 'is_' method to Index for identity checks #4909

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 22, 2013
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/source/release.rst
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,10 @@ API Changes
data - allowing metadata changes.
- ``MultiIndex.astype()`` now only allows ``np.object_``-like dtypes and
now returns a ``MultiIndex`` rather than an ``Index``. (:issue:`4039`)
- Added ``is_`` method to ``Index`` that allows fast equality comparison of
views (similar to ``np.may_share_memory`` but with no false positives; the
identity changes when ``levels`` or ``labels`` are set on a ``MultiIndex``).
(:issue:`4859`, :issue:`4909`)

- Infer and downcast dtype if ``downcast='infer'`` is passed to ``fillna/ffill/bfill`` (:issue:`4604`)
- ``__nonzero__`` for all NDFrame objects, will now raise a ``ValueError``, this reverts back to (:issue:`1073`, :issue:`4633`)
Expand Down
44 changes: 41 additions & 3 deletions pandas/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def _shouldbe_timestamp(obj):
or tslib.is_timestamp_array(obj))


# Factory for identity tokens: calling ``_Identity()`` returns a brand-new
# ``object`` instance, which is what gets stored as an index's ``_id``.
_Identity = object


class Index(FrozenNDArray):
"""
Immutable ndarray implementing an ordered, sliceable set. The basic object
Expand Down Expand Up @@ -87,6 +90,35 @@ class Index(FrozenNDArray):

_engine_type = _index.ObjectEngine

def is_(self, other):
    """
    Identity check, similar to ``is``, that also holds across views.

    Two objects are treated as the same when they carry the very same
    ``_id`` token. Note: this is *not* the same as ``Index.identical()``,
    which checks that metadata is also the same.

    Parameters
    ----------
    other : object
        Object to compare against; it does not need to have an ``_id``
        attribute.

    Returns
    -------
    bool
        True if ``other`` carries the same ``_id`` object as ``self``,
        False otherwise.
    """
    own_token = self._id
    # Ellipsis (rather than None) is the fallback for objects lacking an
    # ``_id``, to make the sentinel's purpose clearer.
    other_token = getattr(other, '_id', Ellipsis)
    return own_token is other_token

def _reset_identity(self):
    """Initializes or resets the ``_id`` attribute with a new object."""
    fresh_token = _Identity()
    self._id = fresh_token

def view(self, *args, **kwargs):
    """
    Return a view via the parent class, propagating ``_id`` to it.

    All arguments are forwarded unchanged to the parent ``view``. When the
    resulting object is an ``Index``, it is given this object's ``_id``;
    any other result is returned untouched.
    """
    viewed = super(Index, self).view(*args, **kwargs)
    if not isinstance(viewed, Index):
        return viewed
    viewed._id = self._id
    return viewed

def __new__(cls, data, dtype=None, copy=False, name=None, fastpath=False,
**kwargs):

Expand Down Expand Up @@ -151,6 +183,7 @@ def __new__(cls, data, dtype=None, copy=False, name=None, fastpath=False,
return subarr

def __array_finalize__(self, obj):
self._reset_identity()
if not isinstance(obj, type(self)):
# Only relevant if array being created from an Index instance
return
Expand Down Expand Up @@ -279,6 +312,7 @@ def set_names(self, names, inplace=False):
raise TypeError("Must pass list-like as `names`.")
if inplace:
idx = self
idx._reset_identity()
else:
idx = self._shallow_copy()
idx._set_names(names)
Expand Down Expand Up @@ -554,7 +588,7 @@ def equals(self, other):
"""
Determines if two Index objects contain the same elements.
"""
if self is other:
if self.is_(other):
return True

if not isinstance(other, Index):
Expand Down Expand Up @@ -1536,7 +1570,7 @@ def equals(self, other):
"""
Determines if two Index objects contain the same elements.
"""
if self is other:
if self.is_(other):
return True

# if not isinstance(other, Int64Index):
Expand Down Expand Up @@ -1645,6 +1679,7 @@ def set_levels(self, levels, inplace=False):
idx = self
else:
idx = self._shallow_copy()
idx._reset_identity()
idx._set_levels(levels)
return idx

Expand Down Expand Up @@ -1683,6 +1718,7 @@ def set_labels(self, labels, inplace=False):
idx = self
else:
idx = self._shallow_copy()
idx._reset_identity()
idx._set_labels(labels)
return idx

Expand Down Expand Up @@ -1736,6 +1772,8 @@ def __array_finalize__(self, obj):
Update custom MultiIndex attributes when a new array is created by
numpy, e.g. when calling ndarray.view()
"""
# overridden if a view
self._reset_identity()
if not isinstance(obj, type(self)):
# Only relevant if this array is being created from an Index
# instance.
Expand Down Expand Up @@ -2754,7 +2792,7 @@ def equals(self, other):
--------
equal_levels
"""
if self is other:
if self.is_(other):
return True

if not isinstance(other, MultiIndex):
Expand Down
45 changes: 45 additions & 0 deletions pandas/tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,28 @@ def test_identical(self):
i2 = i2.rename('foo')
self.assert_(i1.identical(i2))

def test_is_(self):
# Exercises ``Index.is_``: expected True for the object itself and for
# views of it, False for independently built or copied indexes.
ind = Index(range(10))
self.assertTrue(ind.is_(ind))
# identity is expected to survive a chain of views
self.assertTrue(ind.is_(ind.view().view().view().view()))
# an equal-valued but separately built Index, copies (deep and shallow)
# and slices are all expected to be distinct
self.assertFalse(ind.is_(Index(range(10))))
self.assertFalse(ind.is_(ind.copy()))
self.assertFalse(ind.is_(ind.copy(deep=False)))
self.assertFalse(ind.is_(ind[:]))
# round-tripping through a plain ndarray view is expected to lose identity
self.assertFalse(ind.is_(ind.view(np.ndarray).view(Index)))
# comparison against a plain ndarray rather than an Index
self.assertFalse(ind.is_(np.array(range(10))))
self.assertTrue(ind.is_(ind.view().base)) # quasi-implementation dependent
# renaming a view is expected to leave identity intact, in both directions
ind2 = ind.view()
ind2.name = 'bob'
self.assertTrue(ind.is_(ind2))
self.assertTrue(ind2.is_(ind))
# doesn't matter if Indices are *actually* views of underlying data,
self.assertFalse(ind.is_(Index(ind.values)))
# two indexes built from the same array with copy=False are still
# expected to be distinct
arr = np.array(range(1, 11))
ind1 = Index(arr, copy=False)
ind2 = Index(arr, copy=False)
self.assertFalse(ind1.is_(ind2))

def test_asof(self):
d = self.dateIndex[0]
self.assert_(self.dateIndex.asof(d) is d)
Expand Down Expand Up @@ -1719,6 +1741,29 @@ def test_identical(self):
mi2 = mi2.set_names(['new1','new2'])
self.assert_(mi.identical(mi2))

def test_is_(self):
# Exercises ``MultiIndex.is_``: views and name changes are expected to
# keep identity, while setting levels is expected to change it.
mi = MultiIndex.from_tuples(lzip(range(10), range(10)))
self.assertTrue(mi.is_(mi))
self.assertTrue(mi.is_(mi.view()))
self.assertTrue(mi.is_(mi.view().view().view().view()))
mi2 = mi.view()
# names are metadata, they don't change id
mi2.names = ["A", "B"]
self.assertTrue(mi2.is_(mi))
self.assertTrue(mi.is_(mi2))
self.assertTrue(mi.is_(mi.set_names(["C", "D"])))
# levels are inherent properties, they change identity
mi3 = mi2.set_levels([lrange(10), lrange(10)])
self.assertFalse(mi3.is_(mi2))
# shouldn't change
self.assertTrue(mi2.is_(mi))
# setting levels in place on a view is expected to detach that view
mi4 = mi3.view()
mi4.set_levels([[1 for _ in range(10)], lrange(10)], inplace=True)
self.assertFalse(mi4.is_(mi3))
# even re-setting the very same levels in place is expected to detach it
mi5 = mi.view()
mi5.set_levels(mi5.levels, inplace=True)
self.assertFalse(mi5.is_(mi))

def test_union(self):
piece1 = self.index[:5][::-1]
piece2 = self.index[3:]
Expand Down
5 changes: 3 additions & 2 deletions pandas/tseries/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from pandas.core.common import (isnull, _NS_DTYPE, _INT64_DTYPE,
is_list_like,_values_from_object, _maybe_box)
from pandas.core.index import Index, Int64Index
from pandas.core.index import Index, Int64Index, _Identity
import pandas.compat as compat
from pandas.compat import u
from pandas.tseries.frequencies import (
Expand Down Expand Up @@ -1029,6 +1029,7 @@ def __array_finalize__(self, obj):
self.offset = getattr(obj, 'offset', None)
self.tz = getattr(obj, 'tz', None)
self.name = getattr(obj, 'name', None)
self._reset_identity()

def intersection(self, other):
"""
Expand Down Expand Up @@ -1446,7 +1447,7 @@ def equals(self, other):
"""
Determines if two Index objects contain the same elements.
"""
if self is other:
if self.is_(other):
return True

if (not hasattr(other, 'inferred_type') or
Expand Down
3 changes: 2 additions & 1 deletion pandas/tseries/period.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,7 @@ def equals(self, other):
"""
Determines if two Index objects contain the same elements.
"""
if self is other:
if self.is_(other):
return True

return np.array_equal(self.asi8, other.asi8)
Expand Down Expand Up @@ -1076,6 +1076,7 @@ def __array_finalize__(self, obj):

self.freq = getattr(obj, 'freq', None)
self.name = getattr(obj, 'name', None)
self._reset_identity()

def __repr__(self):
output = com.pprint_thing(self.__class__) + '\n'
Expand Down
22 changes: 19 additions & 3 deletions pandas/tseries/tests/test_period.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,9 +1054,6 @@ def test_conv_secondly(self):


class TestPeriodIndex(TestCase):
def __init__(self, *args, **kwds):
TestCase.__init__(self, *args, **kwds)

def setUp(self):
pass

Expand Down Expand Up @@ -1168,6 +1165,25 @@ def test_constructor_datetime64arr(self):

self.assertRaises(ValueError, PeriodIndex, vals, freq='D')

def test_is_(self):
# Exercises ``PeriodIndex.is_``: expected True for the object and its
# views, False for separately built or derived indexes.
create_index = lambda: PeriodIndex(freq='A', start='1/1/2001',
end='12/1/2009')
index = create_index()
self.assertTrue(index.is_(index))
# a second index built with identical arguments is expected to be distinct
self.assertFalse(index.is_(create_index()))
self.assertTrue(index.is_(index.view()))
self.assertTrue(index.is_(index.view().view().view().view().view()))
self.assertTrue(index.view().is_(index))
# renaming the original after taking a view is expected to keep identity
ind2 = index.view()
index.name = "Apple"
self.assertTrue(ind2.is_(index))
# slicing, frequency conversion (even to the same freq) and arithmetic
# (even by zero) are all expected to yield distinct objects
self.assertFalse(index.is_(index[:]))
self.assertFalse(index.is_(index.asfreq('M')))
self.assertFalse(index.is_(index.asfreq('A')))
self.assertFalse(index.is_(index - 2))
self.assertFalse(index.is_(index - 0))


def test_comp_period(self):
idx = period_range('2007-01', periods=20, freq='M')

Expand Down
8 changes: 7 additions & 1 deletion pandas/tseries/tests/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,12 @@ def assert_range_equal(left, right):
class TestTimeSeries(unittest.TestCase):
_multiprocess_can_split_ = True

def test_is_(self):
# Exercises ``DatetimeIndex.is_``: expected True for the object and a
# view of it, False for a copy.
dti = DatetimeIndex(start='1/1/2005', end='12/1/2005', freq='M')
self.assertTrue(dti.is_(dti))
self.assertTrue(dti.is_(dti.view()))
self.assertFalse(dti.is_(dti.copy()))

def test_dti_slicing(self):
dti = DatetimeIndex(start='1/1/2005', end='12/1/2005', freq='M')
dti2 = dti[[1, 3, 5]]
Expand Down Expand Up @@ -655,7 +661,7 @@ def test_index_astype_datetime64(self):
idx = Index([datetime(2012, 1, 1)], dtype=object)

if np.__version__ >= '1.7':
raise nose.SkipTest
raise nose.SkipTest("Test requires numpy < 1.7")

casted = idx.astype(np.dtype('M8[D]'))
expected = DatetimeIndex(idx.values)
Expand Down