diff --git a/doc/source/whatsnew/v0.23.0.txt b/doc/source/whatsnew/v0.23.0.txt index c08e22af295f4..b42aaac3cef96 100644 --- a/doc/source/whatsnew/v0.23.0.txt +++ b/doc/source/whatsnew/v0.23.0.txt @@ -939,6 +939,7 @@ Indexing - Bug in :class:`IntervalIndex` where set operations that returned an empty ``IntervalIndex`` had the wrong dtype (:issue:`19101`) - Bug in :meth:`DataFrame.drop_duplicates` where no ``KeyError`` is raised when passing in columns that don't exist on the ``DataFrame`` (issue:`19726`) - Bug in ``Index`` subclasses constructors that ignore unexpected keyword arguments (:issue:`19348`) +- Bug in :meth:`Index.difference` when taking difference of an ``Index`` with itself (:issue:`20040`) MultiIndex diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index d5daece62cba8..6cd0b5d7697d2 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -457,7 +457,7 @@ def _simple_new(cls, values, name=None, dtype=None, **kwargs): Must be careful not to recurse. """ if not hasattr(values, 'dtype'): - if values is None and dtype is not None: + if (values is None or not len(values)) and dtype is not None: values = np.empty(0, dtype=dtype) else: values = np.array(values, copy=False) @@ -491,6 +491,8 @@ def _shallow_copy(self, values=None, **kwargs): values = self.values attributes = self._get_attributes_dict() attributes.update(kwargs) + if not len(values) and 'dtype' not in kwargs: + attributes['dtype'] = self.dtype return self._simple_new(values, **attributes) def _shallow_copy_with_infer(self, values=None, **kwargs): @@ -511,6 +513,8 @@ def _shallow_copy_with_infer(self, values=None, **kwargs): attributes = self._get_attributes_dict() attributes.update(kwargs) attributes['copy'] = False + if not len(values) and 'dtype' not in kwargs: + attributes['dtype'] = self.dtype if self._infer_as_myclass: try: return self._constructor(values, **attributes) @@ -2815,7 +2819,7 @@ def difference(self, other): self._assert_can_do_setop(other) if self.equals(other): - return Index([], name=self.name) + return self._shallow_copy([]) other, result_name = self._convert_can_do_setop(other) diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index 60eda70714da5..8226c4bcac494 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -2755,7 +2755,7 @@ def intersection(self, other): other_tuples = other._ndarray_values uniq_tuples = sorted(set(self_tuples) & set(other_tuples)) if len(uniq_tuples) == 0: - return MultiIndex(levels=[[]] * self.nlevels, + return MultiIndex(levels=self.levels, labels=[[]] * self.nlevels, names=result_names, verify_integrity=False) else: @@ -2777,7 +2777,7 @@ def difference(self, other): return self if self.equals(other): - return MultiIndex(levels=[[]] * self.nlevels, + return MultiIndex(levels=self.levels, labels=[[]] * self.nlevels, names=result_names, verify_integrity=False) diff --git a/pandas/tests/indexes/test_base.py b/pandas/tests/indexes/test_base.py index e8f05cb928cad..603fa254d5ca6 100644 --- a/pandas/tests/indexes/test_base.py +++ b/pandas/tests/indexes/test_base.py @@ -20,7 +20,7 @@ from pandas import (period_range, date_range, Series, DataFrame, Float64Index, Int64Index, UInt64Index, CategoricalIndex, DatetimeIndex, TimedeltaIndex, - PeriodIndex, isna) + PeriodIndex, RangeIndex, isna) from pandas.core.index import _get_combined_index, _ensure_index_from_sequences from pandas.util.testing import assert_almost_equal from pandas.compat.numpy import np_datetime64_compat @@ -44,7 +44,7 @@ def setup_method(self, method): tdIndex=tm.makeTimedeltaIndex(100), intIndex=tm.makeIntIndex(100), uintIndex=tm.makeUIntIndex(100), - rangeIndex=tm.makeIntIndex(100), + rangeIndex=tm.makeRangeIndex(100), floatIndex=tm.makeFloatIndex(100), boolIndex=Index([True, False]), catIndex=tm.makeCategoricalIndex(100), @@ -57,6 +57,15 @@ def setup_method(self, method): def create_index(self): return Index(list('abcde')) + def generate_index_types(self, skip_index_keys=[]): + """ + Return a generator of the various index types, leaving + out the ones with a key in skip_index_keys + """ + for key, idx in self.indices.items(): + if key not in skip_index_keys: + yield key, idx + def test_new_axis(self): new_index = self.dateIndex[None, :] assert new_index.ndim == 2 @@ -406,6 +415,27 @@ def test_constructor_dtypes_timedelta(self): pd.TimedeltaIndex(list(values), dtype=dtype)]: tm.assert_index_equal(res, idx) + def test_constructor_empty(self): + skip_index_keys = ["repeats", "periodIndex", "rangeIndex", + "tuples"] + for key, idx in self.generate_index_types(skip_index_keys): + empty = idx.__class__([]) + assert isinstance(empty, idx.__class__) + assert not len(empty) + + empty = PeriodIndex([], freq='B') + assert isinstance(empty, PeriodIndex) + assert not len(empty) + + empty = RangeIndex(step=1) + assert isinstance(empty, pd.RangeIndex) + assert not len(empty) + + empty = MultiIndex(levels=[[1, 2], ['blue', 'red']], + labels=[[], []]) + assert isinstance(empty, MultiIndex) + assert not len(empty) + def test_view_with_args(self): restricted = ['unicodeIndex', 'strIndex', 'catIndex', 'boolIndex', @@ -1034,6 +1064,27 @@ def test_symmetric_difference(self): assert tm.equalContents(result, expected) assert result.name == 'new_name' + def test_difference_type(self): + # GH 20040 + # If taking difference of a set and itself, it + # needs to preserve the type of the index + skip_index_keys = ['repeats'] + for key, idx in self.generate_index_types(skip_index_keys): + result = idx.difference(idx) + expected = idx.drop(idx) + tm.assert_index_equal(result, expected) + + def test_intersection_difference(self): + # GH 20040 + # Test that the intersection of an index with an + # empty index produces the same index as the difference + # of an index with itself. Test for all types + skip_index_keys = ['repeats'] + for key, idx in self.generate_index_types(skip_index_keys): + inter = idx.intersection(idx.drop(idx)) + diff = idx.difference(idx) + tm.assert_index_equal(inter, diff) + def test_is_numeric(self): assert not self.dateIndex.is_numeric() assert not self.strIndex.is_numeric()