Skip to content

Commit 20b8d47

Browse files
committed
BUG: Index.difference of itself doesn't preserve type
1 parent 92c2910 commit 20b8d47

File tree

4 files changed

+62
-6
lines changed

4 files changed

+62
-6
lines changed

doc/source/whatsnew/v0.23.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -939,6 +939,7 @@ Indexing
939939
- Bug in :class:`IntervalIndex` where set operations that returned an empty ``IntervalIndex`` had the wrong dtype (:issue:`19101`)
940940
- Bug in :meth:`DataFrame.drop_duplicates` where no ``KeyError`` is raised when passing in columns that don't exist on the ``DataFrame`` (issue:`19726`)
941941
- Bug in ``Index`` subclasses constructors that ignore unexpected keyword arguments (:issue:`19348`)
942+
- Bug in :meth:`Index.difference` when taking difference of an ``Index`` with itself (:issue:`20040`)
942943

943944

944945
MultiIndex

pandas/core/indexes/base.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ def _simple_new(cls, values, name=None, dtype=None, **kwargs):
458458
Must be careful not to recurse.
459459
"""
460460
if not hasattr(values, 'dtype'):
461-
if values is None and dtype is not None:
461+
if (values is None or not len(values)) and dtype is not None:
462462
values = np.empty(0, dtype=dtype)
463463
else:
464464
values = np.array(values, copy=False)
@@ -492,6 +492,8 @@ def _shallow_copy(self, values=None, **kwargs):
492492
values = self.values
493493
attributes = self._get_attributes_dict()
494494
attributes.update(kwargs)
495+
if not len(values) and 'dtype' not in kwargs:
496+
attributes['dtype'] = self.dtype
495497
return self._simple_new(values, **attributes)
496498

497499
def _shallow_copy_with_infer(self, values=None, **kwargs):
@@ -512,6 +514,8 @@ def _shallow_copy_with_infer(self, values=None, **kwargs):
512514
attributes = self._get_attributes_dict()
513515
attributes.update(kwargs)
514516
attributes['copy'] = False
517+
if not len(values) and 'dtype' not in kwargs:
518+
attributes['dtype'] = self.dtype
515519
if self._infer_as_myclass:
516520
try:
517521
return self._constructor(values, **attributes)
@@ -2816,7 +2820,7 @@ def difference(self, other):
28162820
self._assert_can_do_setop(other)
28172821

28182822
if self.equals(other):
2819-
return Index([], name=self.name)
2823+
return self._shallow_copy([])
28202824

28212825
other, result_name = self._convert_can_do_setop(other)
28222826

pandas/core/indexes/multi.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2756,7 +2756,7 @@ def intersection(self, other):
27562756
other_tuples = other._ndarray_values
27572757
uniq_tuples = sorted(set(self_tuples) & set(other_tuples))
27582758
if len(uniq_tuples) == 0:
2759-
return MultiIndex(levels=[[]] * self.nlevels,
2759+
return MultiIndex(levels=self.levels,
27602760
labels=[[]] * self.nlevels,
27612761
names=result_names, verify_integrity=False)
27622762
else:
@@ -2778,7 +2778,7 @@ def difference(self, other):
27782778
return self
27792779

27802780
if self.equals(other):
2781-
return MultiIndex(levels=[[]] * self.nlevels,
2781+
return MultiIndex(levels=self.levels,
27822782
labels=[[]] * self.nlevels,
27832783
names=result_names, verify_integrity=False)
27842784

pandas/tests/indexes/test_base.py

+53-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from pandas import (period_range, date_range, Series,
2121
DataFrame, Float64Index, Int64Index, UInt64Index,
2222
CategoricalIndex, DatetimeIndex, TimedeltaIndex,
23-
PeriodIndex, isna)
23+
PeriodIndex, RangeIndex, isna)
2424
from pandas.core.index import _get_combined_index, _ensure_index_from_sequences
2525
from pandas.util.testing import assert_almost_equal
2626
from pandas.compat.numpy import np_datetime64_compat
@@ -44,7 +44,7 @@ def setup_method(self, method):
4444
tdIndex=tm.makeTimedeltaIndex(100),
4545
intIndex=tm.makeIntIndex(100),
4646
uintIndex=tm.makeUIntIndex(100),
47-
rangeIndex=tm.makeIntIndex(100),
47+
rangeIndex=tm.makeRangeIndex(100),
4848
floatIndex=tm.makeFloatIndex(100),
4949
boolIndex=Index([True, False]),
5050
catIndex=tm.makeCategoricalIndex(100),
@@ -57,6 +57,15 @@ def setup_method(self, method):
5757
def create_index(self):
5858
return Index(list('abcde'))
5959

60+
def generate_index_types(self, skip_index_keys=[]):
61+
"""
62+
Return a generator of the various index types, leaving
63+
out the ones with a key in skip_index_keys
64+
"""
65+
for key, idx in self.indices.items():
66+
if key not in skip_index_keys:
67+
yield key, idx
68+
6069
def test_new_axis(self):
6170
new_index = self.dateIndex[None, :]
6271
assert new_index.ndim == 2
@@ -406,6 +415,27 @@ def test_constructor_dtypes_timedelta(self):
406415
pd.TimedeltaIndex(list(values), dtype=dtype)]:
407416
tm.assert_index_equal(res, idx)
408417

418+
def test_constructor_empty(self):
419+
skip_index_keys = ["repeats", "periodIndex", "rangeIndex",
420+
"tuples"]
421+
for key, idx in self.generate_index_types(skip_index_keys):
422+
empty = idx.__class__([])
423+
assert isinstance(empty, idx.__class__)
424+
assert not len(empty)
425+
426+
empty = PeriodIndex([], freq='B')
427+
assert isinstance(empty, PeriodIndex)
428+
assert not len(empty)
429+
430+
empty = RangeIndex(step=1)
431+
assert isinstance(empty, pd.RangeIndex)
432+
assert not len(empty)
433+
434+
empty = MultiIndex(levels=[[1, 2], ['blue', 'red']],
435+
labels=[[], []])
436+
assert isinstance(empty, MultiIndex)
437+
assert not len(empty)
438+
409439
def test_view_with_args(self):
410440

411441
restricted = ['unicodeIndex', 'strIndex', 'catIndex', 'boolIndex',
@@ -1034,6 +1064,27 @@ def test_symmetric_difference(self):
10341064
assert tm.equalContents(result, expected)
10351065
assert result.name == 'new_name'
10361066

1067+
def test_difference_type(self):
1068+
# GH 20040
1069+
# If taking difference of a set and itself, it
1070+
# needs to preserve the type of the index
1071+
skip_index_keys = ['repeats']
1072+
for key, idx in self.generate_index_types(skip_index_keys):
1073+
result = idx.difference(idx)
1074+
expected = idx.drop(idx)
1075+
tm.assert_index_equal(result, expected)
1076+
1077+
def test_intersection_difference(self):
1078+
# GH 20040
1079+
# Test that the intersection of an index with an
1080+
# empty index produces the same index as the difference
1081+
# of an index with itself. Test for all types
1082+
skip_index_keys = ['repeats']
1083+
for key, idx in self.generate_index_types(skip_index_keys):
1084+
inter = idx.intersection(idx.drop(idx))
1085+
diff = idx.difference(idx)
1086+
tm.assert_index_equal(inter, diff)
1087+
10371088
def test_is_numeric(self):
10381089
assert not self.dateIndex.is_numeric()
10391090
assert not self.strIndex.is_numeric()

0 commit comments

Comments
 (0)