Skip to content

Commit 46f6a86

Browse files
committed
BUG: IntervalIndex set op bugs for empty results
1 parent 36a71eb commit 46f6a86

File tree

3 files changed

+87
-9
lines changed

3 files changed

+87
-9
lines changed

doc/source/whatsnew/v0.23.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ Indexing
398398
- Bug in ``__setitem__`` when indexing a :class:`DataFrame` with a 2-d boolean ndarray (:issue:`18582`)
399399
- Bug in :func:`MultiIndex.set_labels` which would cause casting (and potentially clipping) of the new labels if the ``level`` argument is not 0 or a list like [0, 1, ... ] (:issue:`19057`)
400400
- Bug in ``str.extractall`` where, when there were no matches, an empty :class:`Index` was returned instead of the appropriate :class:`MultiIndex` (:issue:`19034`)
401+
- Bug in :class:`IntervalIndex` where set operations that returned an empty ``IntervalIndex`` had the wrong dtype (:issue:`19101`)
401402

402403
I/O
403404
^^^

pandas/core/indexes/interval.py

+25-5
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pandas.core.dtypes.missing import notna, isna
66
from pandas.core.dtypes.generic import ABCDatetimeIndex, ABCPeriodIndex
77
from pandas.core.dtypes.dtypes import IntervalDtype
8-
from pandas.core.dtypes.cast import maybe_convert_platform
8+
from pandas.core.dtypes.cast import maybe_convert_platform, find_common_type
99
from pandas.core.dtypes.common import (
1010
_ensure_platform_int,
1111
is_list_like,
@@ -16,6 +16,7 @@
1616
is_integer_dtype,
1717
is_float_dtype,
1818
is_interval_dtype,
19+
is_object_dtype,
1920
is_scalar,
2021
is_float,
2122
is_number,
@@ -1284,21 +1285,40 @@ def equals(self, other):
12841285
self.right.equals(other.right) and
12851286
self.closed == other.closed)
12861287

1287-
def _setop(op_name):
1288+
def _setop(op_name, check_subtypes=False):
12881289
def func(self, other):
12891290
msg = ('can only do set operations between two IntervalIndex '
12901291
'objects that are closed on the same side')
12911292
other = self._as_like_interval_index(other, msg)
1293+
1294+
if check_subtypes:
1295+
# GH 19016: ensure set op will not return a prohibited dtype
1296+
subtypes = [self.dtype.subtype, other.dtype.subtype]
1297+
result_subtype = find_common_type(subtypes)
1298+
if is_object_dtype(result_subtype):
1299+
msg = ('can only do {op} between two IntervalIndex '
1300+
'objects that have compatible dtypes')
1301+
raise TypeError(msg.format(op=op_name))
1302+
else:
1303+
result_subtype = self.dtype.subtype
1304+
12921305
result = getattr(self._multiindex, op_name)(other._multiindex)
12931306
result_name = self.name if self.name == other.name else None
1294-
return type(self).from_tuples(result.values, closed=self.closed,
1307+
1308+
# GH 19101: ensure empty results have correct dtype
1309+
if result.empty:
1310+
result = result.values.astype(result_subtype)
1311+
else:
1312+
result = result.values
1313+
1314+
return type(self).from_tuples(result, closed=self.closed,
12951315
name=result_name)
12961316
return func
12971317

1298-
union = _setop('union')
1318+
union = _setop('union', check_subtypes=True)
12991319
intersection = _setop('intersection')
13001320
difference = _setop('difference')
1301-
symmetric_difference = _setop('symmetric_difference')
1321+
symmetric_difference = _setop('symmetric_difference', check_subtypes=True)
13021322

13031323
# TODO: arithmetic operations
13041324

pandas/tests/indexes/interval/test_interval.py

+61-4
Original file line numberDiff line numberDiff line change
@@ -880,6 +880,16 @@ def test_union(self, closed):
880880
tm.assert_index_equal(index.union(index), index)
881881
tm.assert_index_equal(index.union(index[:1]), index)
882882

883+
# GH 19101: empty result, same dtype
884+
index = IntervalIndex(np.array([], dtype='int64'), closed=closed)
885+
result = index.union(index)
886+
tm.assert_index_equal(result, index)
887+
888+
# GH 19101: empty result, different dtypes
889+
other = IntervalIndex(np.array([], dtype='float64'), closed=closed)
890+
result = index.union(other)
891+
tm.assert_index_equal(result, other)
892+
883893
def test_intersection(self, closed):
884894
index = self.create_index(closed=closed)
885895
other = IntervalIndex.from_breaks(range(5, 13), closed=closed)
@@ -893,14 +903,49 @@ def test_intersection(self, closed):
893903

894904
tm.assert_index_equal(index.intersection(index), index)
895905

906+
# GH 19101: empty result, same dtype
907+
other = IntervalIndex.from_breaks(range(300, 314), closed=closed)
908+
expected = IntervalIndex(np.array([], dtype='int64'), closed=closed)
909+
result = index.intersection(other)
910+
tm.assert_index_equal(result, expected)
911+
912+
# GH 19101: empty result, different dtypes
913+
breaks = np.arange(300, 314, dtype='float64')
914+
other = IntervalIndex.from_breaks(breaks, closed=closed)
915+
result = index.intersection(other)
916+
tm.assert_index_equal(result, expected)
917+
896918
def test_difference(self, closed):
897919
index = self.create_index(closed=closed)
898920
tm.assert_index_equal(index.difference(index[:1]), index[1:])
899921

922+
# GH 19101: empty result, same dtype
923+
result = index.difference(index)
924+
expected = IntervalIndex(np.array([], dtype='int64'), closed=closed)
925+
tm.assert_index_equal(result, expected)
926+
927+
# GH 19101: empty result, different dtypes
928+
other = IntervalIndex.from_arrays(index.left.astype('float64'),
929+
index.right, closed=closed)
930+
result = index.difference(other)
931+
tm.assert_index_equal(result, expected)
932+
900933
def test_symmetric_difference(self, closed):
901-
idx = self.create_index(closed=closed)
902-
result = idx[1:].symmetric_difference(idx[:-1])
903-
expected = IntervalIndex([idx[0], idx[-1]])
934+
index = self.create_index(closed=closed)
935+
result = index[1:].symmetric_difference(index[:-1])
936+
expected = IntervalIndex([index[0], index[-1]])
937+
tm.assert_index_equal(result, expected)
938+
939+
# GH 19101: empty result, same dtype
940+
result = index.symmetric_difference(index)
941+
expected = IntervalIndex(np.array([], dtype='int64'), closed=closed)
942+
tm.assert_index_equal(result, expected)
943+
944+
# GH 19101: empty result, different dtypes
945+
other = IntervalIndex.from_arrays(index.left.astype('float64'),
946+
index.right, closed=closed)
947+
result = index.symmetric_difference(other)
948+
expected = IntervalIndex(np.array([], dtype='float64'), closed=closed)
904949
tm.assert_index_equal(result, expected)
905950

906951
@pytest.mark.parametrize('op_name', [
@@ -909,17 +954,29 @@ def test_set_operation_errors(self, closed, op_name):
909954
index = self.create_index(closed=closed)
910955
set_op = getattr(index, op_name)
911956

912-
# test errors
957+
# non-IntervalIndex
913958
msg = ('can only do set operations between two IntervalIndex objects '
914959
'that are closed on the same side')
915960
with tm.assert_raises_regex(ValueError, msg):
916961
set_op(Index([1, 2, 3]))
917962

963+
# mixed closed
918964
for other_closed in {'right', 'left', 'both', 'neither'} - {closed}:
919965
other = self.create_index(closed=other_closed)
920966
with tm.assert_raises_regex(ValueError, msg):
921967
set_op(other)
922968

969+
# GH 19016: incompatible dtypes
970+
other = interval_range(Timestamp('20180101'), periods=9, closed=closed)
971+
if op_name in ('union', 'symmetric_difference'):
972+
msg = ('can only do {op} between two IntervalIndex objects '
973+
'that have compatible dtypes').format(op=op_name)
974+
with tm.assert_raises_regex(TypeError, msg):
975+
set_op(other)
976+
else:
977+
# should not raise
978+
set_op(other)
979+
923980
def test_isin(self, closed):
924981
index = self.create_index(closed=closed)
925982

0 commit comments

Comments
 (0)