Skip to content

Commit 743c2ba

Browse files
committed
BUG: IntervalIndex set op bugs for empty results
1 parent 36a71eb commit 743c2ba

File tree

3 files changed

+78
-6
lines changed

3 files changed

+78
-6
lines changed

doc/source/whatsnew/v0.23.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ Indexing
398398
- Bug in ``__setitem__`` when indexing a :class:`DataFrame` with a 2-d boolean ndarray (:issue:`18582`)
399399
- Bug in :func:`MultiIndex.set_labels` which would cause casting (and potentially clipping) of the new labels if the ``level`` argument is not 0 or a list like [0, 1, ... ] (:issue:`19057`)
400400
- Bug in ``str.extractall`` when there were no matches empty :class:`Index` was returned instead of appropriate :class:`MultiIndex` (:issue:`19034`)
401+
- Bug in :class:`IntervalIndex` where set operations that returned an empty ``IntervalIndex`` had the wrong dtype (:issue:`19101`)
401402

402403
I/O
403404
^^^

pandas/core/indexes/interval.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pandas.core.dtypes.missing import notna, isna
66
from pandas.core.dtypes.generic import ABCDatetimeIndex, ABCPeriodIndex
77
from pandas.core.dtypes.dtypes import IntervalDtype
8-
from pandas.core.dtypes.cast import maybe_convert_platform
8+
from pandas.core.dtypes.cast import maybe_convert_platform, find_common_type
99
from pandas.core.dtypes.common import (
1010
_ensure_platform_int,
1111
is_list_like,
@@ -16,6 +16,7 @@
1616
is_integer_dtype,
1717
is_float_dtype,
1818
is_interval_dtype,
19+
is_object_dtype,
1920
is_scalar,
2021
is_float,
2122
is_number,
@@ -1289,9 +1290,24 @@ def func(self, other):
12891290
msg = ('can only do set operations between two IntervalIndex '
12901291
'objects that are closed on the same side')
12911292
other = self._as_like_interval_index(other, msg)
1293+
1294+
# GH 19016: ensure set op won't return a prohibited dtype
1295+
dtype = find_common_type([self.dtype.subtype, other.dtype.subtype])
1296+
if is_object_dtype(dtype):
1297+
msg = ('can only do set operations between two IntervalIndex '
1298+
'objects that have compatible dtypes')
1299+
raise TypeError(msg)
1300+
12921301
result = getattr(self._multiindex, op_name)(other._multiindex)
12931302
result_name = self.name if self.name == other.name else None
1294-
return type(self).from_tuples(result.values, closed=self.closed,
1303+
1304+
# GH 19101: ensure empty results have correct dtype
1305+
if result.empty:
1306+
result = result.values.astype(dtype)
1307+
else:
1308+
result = result.values
1309+
1310+
return type(self).from_tuples(result, closed=self.closed,
12951311
name=result_name)
12961312
return func
12971313

pandas/tests/indexes/interval/test_interval.py

+59-4
Original file line numberDiff line numberDiff line change
@@ -880,6 +880,16 @@ def test_union(self, closed):
880880
tm.assert_index_equal(index.union(index), index)
881881
tm.assert_index_equal(index.union(index[:1]), index)
882882

883+
# GH 19101: empty result, same dtype
884+
index = IntervalIndex(np.array([], dtype='int64'), closed=closed)
885+
result = index.union(index)
886+
tm.assert_index_equal(result, index)
887+
888+
# GH 19101: empty result, different dtypes
889+
other = IntervalIndex(np.array([], dtype='float64'), closed=closed)
890+
result = index.union(other)
891+
tm.assert_index_equal(result, other)
892+
883893
def test_intersection(self, closed):
884894
index = self.create_index(closed=closed)
885895
other = IntervalIndex.from_breaks(range(5, 13), closed=closed)
@@ -893,14 +903,51 @@ def test_intersection(self, closed):
893903

894904
tm.assert_index_equal(index.intersection(index), index)
895905

906+
# GH 19101: empty result, same dtype
907+
other = IntervalIndex.from_breaks(range(300, 314), closed=closed)
908+
expected = IntervalIndex(np.array([], dtype='int64'), closed=closed)
909+
result = index.intersection(other)
910+
tm.assert_index_equal(result, expected)
911+
912+
# GH 19101: empty result, different dtypes
913+
breaks = np.arange(300, 314, dtype='float64')
914+
other = IntervalIndex.from_breaks(breaks, closed=closed)
915+
expected = IntervalIndex(np.array([], dtype='float64'), closed=closed)
916+
result = index.intersection(other)
917+
tm.assert_index_equal(result, expected)
918+
896919
def test_difference(self, closed):
897920
index = self.create_index(closed=closed)
898921
tm.assert_index_equal(index.difference(index[:1]), index[1:])
899922

923+
# GH 19101: empty result, same dtype
924+
result = index.difference(index)
925+
expected = IntervalIndex(np.array([], dtype='int64'), closed=closed)
926+
tm.assert_index_equal(result, expected)
927+
928+
# GH 19101: empty result, different dtypes
929+
other = IntervalIndex.from_arrays(index.left.astype('float64'),
930+
index.right, closed=closed)
931+
result = index.difference(other)
932+
expected = IntervalIndex(np.array([], dtype='float64'), closed=closed)
933+
tm.assert_index_equal(result, expected)
934+
900935
def test_symmetric_difference(self, closed):
901-
idx = self.create_index(closed=closed)
902-
result = idx[1:].symmetric_difference(idx[:-1])
903-
expected = IntervalIndex([idx[0], idx[-1]])
936+
index = self.create_index(closed=closed)
937+
result = index[1:].symmetric_difference(index[:-1])
938+
expected = IntervalIndex([index[0], index[-1]])
939+
tm.assert_index_equal(result, expected)
940+
941+
# GH 19101: empty result, same dtype
942+
result = index.symmetric_difference(index)
943+
expected = IntervalIndex(np.array([], dtype='int64'), closed=closed)
944+
tm.assert_index_equal(result, expected)
945+
946+
# GH 19101: empty result, different dtypes
947+
other = IntervalIndex.from_arrays(index.left.astype('float64'),
948+
index.right, closed=closed)
949+
result = index.symmetric_difference(other)
950+
expected = IntervalIndex(np.array([], dtype='float64'), closed=closed)
904951
tm.assert_index_equal(result, expected)
905952

906953
@pytest.mark.parametrize('op_name', [
@@ -909,17 +956,25 @@ def test_set_operation_errors(self, closed, op_name):
909956
index = self.create_index(closed=closed)
910957
set_op = getattr(index, op_name)
911958

912-
# test errors
959+
# non-IntervalIndex
913960
msg = ('can only do set operations between two IntervalIndex objects '
914961
'that are closed on the same side')
915962
with tm.assert_raises_regex(ValueError, msg):
916963
set_op(Index([1, 2, 3]))
917964

965+
# mixed closed
918966
for other_closed in {'right', 'left', 'both', 'neither'} - {closed}:
919967
other = self.create_index(closed=other_closed)
920968
with tm.assert_raises_regex(ValueError, msg):
921969
set_op(other)
922970

971+
# GH 19016: incompatible dtypes
972+
other = interval_range(Timestamp('20180101'), periods=9, closed=closed)
973+
msg = ('can only do set operations between two IntervalIndex objects '
974+
'that have compatible dtypes')
975+
with tm.assert_raises_regex(TypeError, msg):
976+
set_op(other)
977+
923978
def test_isin(self, closed):
924979
index = self.create_index(closed=closed)
925980

0 commit comments

Comments
 (0)