diff --git a/doc/source/whatsnew/v0.23.0.txt b/doc/source/whatsnew/v0.23.0.txt index dc305f36f32ec..33ca394db47ca 100644 --- a/doc/source/whatsnew/v0.23.0.txt +++ b/doc/source/whatsnew/v0.23.0.txt @@ -405,6 +405,7 @@ Indexing - Bug in :func:`MultiIndex.__contains__` where non-tuple keys would return ``True`` even if they had been dropped (:issue:`19027`) - Bug in :func:`MultiIndex.set_labels` which would cause casting (and potentially clipping) of the new labels if the ``level`` argument is not 0 or a list like [0, 1, ... ] (:issue:`19057`) - Bug in ``str.extractall`` when there were no matches empty :class:`Index` was returned instead of appropriate :class:`MultiIndex` (:issue:`19034`) +- Bug in :class:`IntervalIndex` where set operations that returned an empty ``IntervalIndex`` had the wrong dtype (:issue:`19101`) I/O ^^^ diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index 43bdc14106b00..baf80173d7362 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -5,7 +5,7 @@ from pandas.core.dtypes.missing import notna, isna from pandas.core.dtypes.generic import ABCDatetimeIndex, ABCPeriodIndex from pandas.core.dtypes.dtypes import IntervalDtype -from pandas.core.dtypes.cast import maybe_convert_platform +from pandas.core.dtypes.cast import maybe_convert_platform, find_common_type from pandas.core.dtypes.common import ( _ensure_platform_int, is_list_like, @@ -16,6 +16,7 @@ is_integer_dtype, is_float_dtype, is_interval_dtype, + is_object_dtype, is_scalar, is_float, is_number, @@ -1289,9 +1290,25 @@ def func(self, other): msg = ('can only do set operations between two IntervalIndex ' 'objects that are closed on the same side') other = self._as_like_interval_index(other, msg) + + # GH 19016: ensure set op will not return a prohibited dtype + subtypes = [self.dtype.subtype, other.dtype.subtype] + common_subtype = find_common_type(subtypes) + if is_object_dtype(common_subtype): + msg = ('can only do {op} between two IntervalIndex ' + 'objects that have compatible dtypes') + raise TypeError(msg.format(op=op_name)) + result = getattr(self._multiindex, op_name)(other._multiindex) result_name = self.name if self.name == other.name else None - return type(self).from_tuples(result.values, closed=self.closed, + + # GH 19101: ensure empty results have correct dtype + if result.empty: + result = result.values.astype(self.dtype.subtype) + else: + result = result.values + + return type(self).from_tuples(result, closed=self.closed, name=result_name) return func diff --git a/pandas/tests/indexes/interval/test_interval.py b/pandas/tests/indexes/interval/test_interval.py index 98db34a9f90f4..b6d49c9e7ba19 100644 --- a/pandas/tests/indexes/interval/test_interval.py +++ b/pandas/tests/indexes/interval/test_interval.py @@ -880,6 +880,16 @@ def test_union(self, closed): tm.assert_index_equal(index.union(index), index) tm.assert_index_equal(index.union(index[:1]), index) + # GH 19101: empty result, same dtype + index = IntervalIndex(np.array([], dtype='int64'), closed=closed) + result = index.union(index) + tm.assert_index_equal(result, index) + + # GH 19101: empty result, different dtypes + other = IntervalIndex(np.array([], dtype='float64'), closed=closed) + result = index.union(other) + tm.assert_index_equal(result, index) + def test_intersection(self, closed): index = self.create_index(closed=closed) other = IntervalIndex.from_breaks(range(5, 13), closed=closed) @@ -893,14 +903,48 @@ def test_intersection(self, closed): tm.assert_index_equal(index.intersection(index), index) + # GH 19101: empty result, same dtype + other = IntervalIndex.from_breaks(range(300, 314), closed=closed) + expected = IntervalIndex(np.array([], dtype='int64'), closed=closed) + result = index.intersection(other) + tm.assert_index_equal(result, expected) + + # GH 19101: empty result, different dtypes + breaks = np.arange(300, 314, dtype='float64') + other = IntervalIndex.from_breaks(breaks, closed=closed) + result = index.intersection(other) + tm.assert_index_equal(result, expected) + def test_difference(self, closed): index = self.create_index(closed=closed) tm.assert_index_equal(index.difference(index[:1]), index[1:]) + # GH 19101: empty result, same dtype + result = index.difference(index) + expected = IntervalIndex(np.array([], dtype='int64'), closed=closed) + tm.assert_index_equal(result, expected) + + # GH 19101: empty result, different dtypes + other = IntervalIndex.from_arrays(index.left.astype('float64'), + index.right, closed=closed) + result = index.difference(other) + tm.assert_index_equal(result, expected) + def test_symmetric_difference(self, closed): - idx = self.create_index(closed=closed) - result = idx[1:].symmetric_difference(idx[:-1]) - expected = IntervalIndex([idx[0], idx[-1]]) + index = self.create_index(closed=closed) + result = index[1:].symmetric_difference(index[:-1]) + expected = IntervalIndex([index[0], index[-1]]) + tm.assert_index_equal(result, expected) + + # GH 19101: empty result, same dtype + result = index.symmetric_difference(index) + expected = IntervalIndex(np.array([], dtype='int64'), closed=closed) + tm.assert_index_equal(result, expected) + + # GH 19101: empty result, different dtypes + other = IntervalIndex.from_arrays(index.left.astype('float64'), + index.right, closed=closed) + result = index.symmetric_difference(other) tm.assert_index_equal(result, expected) @pytest.mark.parametrize('op_name', [ @@ -909,17 +953,25 @@ def test_set_operation_errors(self, closed, op_name): index = self.create_index(closed=closed) set_op = getattr(index, op_name) - # test errors + # non-IntervalIndex msg = ('can only do set operations between two IntervalIndex objects ' 'that are closed on the same side') with tm.assert_raises_regex(ValueError, msg): set_op(Index([1, 2, 3])) + # mixed closed for other_closed in {'right', 'left', 'both', 'neither'} - {closed}: other = self.create_index(closed=other_closed) with tm.assert_raises_regex(ValueError, msg): set_op(other) + # GH 19016: incompatible dtypes + other = interval_range(Timestamp('20180101'), periods=9, closed=closed) + msg = ('can only do {op} between two IntervalIndex objects that have ' + 'compatible dtypes').format(op=op_name) + with tm.assert_raises_regex(TypeError, msg): + set_op(other) + def test_isin(self, closed): index = self.create_index(closed=closed)