From be83b7a036cb8e6bf9927689f8e83d7f7b2fb7a6 Mon Sep 17 00:00:00 2001 From: jschendel Date: Sun, 7 Jan 2018 01:20:31 -0700 Subject: [PATCH 1/2] BUG: IntervalIndex set op bugs for empty results --- doc/source/whatsnew/v0.23.0.txt | 1 + pandas/core/indexes/interval.py | 30 +++++++-- .../tests/indexes/interval/test_interval.py | 65 +++++++++++++++++-- 3 files changed, 87 insertions(+), 9 deletions(-) diff --git a/doc/source/whatsnew/v0.23.0.txt b/doc/source/whatsnew/v0.23.0.txt index dc305f36f32ec..33ca394db47ca 100644 --- a/doc/source/whatsnew/v0.23.0.txt +++ b/doc/source/whatsnew/v0.23.0.txt @@ -405,6 +405,7 @@ Indexing - Bug in :func:`MultiIndex.__contains__` where non-tuple keys would return ``True`` even if they had been dropped (:issue:`19027`) - Bug in :func:`MultiIndex.set_labels` which would cause casting (and potentially clipping) of the new labels if the ``level`` argument is not 0 or a list like [0, 1, ... ] (:issue:`19057`) - Bug in ``str.extractall`` when there were no matches empty :class:`Index` was returned instead of appropriate :class:`MultiIndex` (:issue:`19034`) +- Bug in :class:`IntervalIndex` where set operations that returned an empty ``IntervalIndex`` had the wrong dtype (:issue:`19101`) I/O ^^^ diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index 43bdc14106b00..5b06813c20dab 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -5,7 +5,7 @@ from pandas.core.dtypes.missing import notna, isna from pandas.core.dtypes.generic import ABCDatetimeIndex, ABCPeriodIndex from pandas.core.dtypes.dtypes import IntervalDtype -from pandas.core.dtypes.cast import maybe_convert_platform +from pandas.core.dtypes.cast import maybe_convert_platform, find_common_type from pandas.core.dtypes.common import ( _ensure_platform_int, is_list_like, @@ -16,6 +16,7 @@ is_integer_dtype, is_float_dtype, is_interval_dtype, + is_object_dtype, is_scalar, is_float, is_number, @@ -1284,21 +1285,40 @@ def equals(self, other): self.right.equals(other.right) and self.closed == other.closed) - def _setop(op_name): + def _setop(op_name, check_subtypes=False): def func(self, other): msg = ('can only do set operations between two IntervalIndex ' 'objects that are closed on the same side') other = self._as_like_interval_index(other, msg) + + if check_subtypes: + # GH 19016: ensure set op will not return a prohibited dtype + subtypes = [self.dtype.subtype, other.dtype.subtype] + result_subtype = find_common_type(subtypes) + if is_object_dtype(result_subtype): + msg = ('can only do {op} between two IntervalIndex ' + 'objects that have compatible dtypes') + raise TypeError(msg.format(op=op_name)) + else: + result_subtype = self.dtype.subtype + result = getattr(self._multiindex, op_name)(other._multiindex) result_name = self.name if self.name == other.name else None - return type(self).from_tuples(result.values, closed=self.closed, + + # GH 19101: ensure empty results have correct dtype + if result.empty: + result = result.values.astype(result_subtype) + else: + result = result.values + + return type(self).from_tuples(result, closed=self.closed, name=result_name) return func - union = _setop('union') + union = _setop('union', check_subtypes=True) intersection = _setop('intersection') difference = _setop('difference') - symmetric_difference = _setop('symmetric_difference') + symmetric_difference = _setop('symmetric_difference', check_subtypes=True) # TODO: arithmetic operations diff --git a/pandas/tests/indexes/interval/test_interval.py b/pandas/tests/indexes/interval/test_interval.py index 98db34a9f90f4..05db4d2cab940 100644 --- a/pandas/tests/indexes/interval/test_interval.py +++ b/pandas/tests/indexes/interval/test_interval.py @@ -880,6 +880,16 @@ def test_union(self, closed): tm.assert_index_equal(index.union(index), index) tm.assert_index_equal(index.union(index[:1]), index) + # GH 19101: empty result, same dtype + index = IntervalIndex(np.array([], dtype='int64'), closed=closed) + result = index.union(index) + tm.assert_index_equal(result, index) + + # GH 19101: empty result, different dtypes + other = IntervalIndex(np.array([], dtype='float64'), closed=closed) + result = index.union(other) + tm.assert_index_equal(result, other) + def test_intersection(self, closed): index = self.create_index(closed=closed) other = IntervalIndex.from_breaks(range(5, 13), closed=closed) @@ -893,14 +903,49 @@ def test_intersection(self, closed): tm.assert_index_equal(index.intersection(index), index) + # GH 19101: empty result, same dtype + other = IntervalIndex.from_breaks(range(300, 314), closed=closed) + expected = IntervalIndex(np.array([], dtype='int64'), closed=closed) + result = index.intersection(other) + tm.assert_index_equal(result, expected) + + # GH 19101: empty result, different dtypes + breaks = np.arange(300, 314, dtype='float64') + other = IntervalIndex.from_breaks(breaks, closed=closed) + result = index.intersection(other) + tm.assert_index_equal(result, expected) + def test_difference(self, closed): index = self.create_index(closed=closed) tm.assert_index_equal(index.difference(index[:1]), index[1:]) + # GH 19101: empty result, same dtype + result = index.difference(index) + expected = IntervalIndex(np.array([], dtype='int64'), closed=closed) + tm.assert_index_equal(result, expected) + + # GH 19101: empty result, different dtypes + other = IntervalIndex.from_arrays(index.left.astype('float64'), + index.right, closed=closed) + result = index.difference(other) + tm.assert_index_equal(result, expected) + def test_symmetric_difference(self, closed): - idx = self.create_index(closed=closed) - result = idx[1:].symmetric_difference(idx[:-1]) - expected = IntervalIndex([idx[0], idx[-1]]) + index = self.create_index(closed=closed) + result = index[1:].symmetric_difference(index[:-1]) + expected = IntervalIndex([index[0], index[-1]]) + tm.assert_index_equal(result, expected) + + # GH 19101: empty result, same dtype + result = index.symmetric_difference(index) + expected = IntervalIndex(np.array([], dtype='int64'), closed=closed) + tm.assert_index_equal(result, expected) + + # GH 19101: empty result, different dtypes + other = IntervalIndex.from_arrays(index.left.astype('float64'), + index.right, closed=closed) + result = index.symmetric_difference(other) + expected = IntervalIndex(np.array([], dtype='float64'), closed=closed) tm.assert_index_equal(result, expected) @pytest.mark.parametrize('op_name', [ @@ -909,17 +954,29 @@ def test_set_operation_errors(self, closed, op_name): index = self.create_index(closed=closed) set_op = getattr(index, op_name) - # test errors + # non-IntervalIndex msg = ('can only do set operations between two IntervalIndex objects ' 'that are closed on the same side') with tm.assert_raises_regex(ValueError, msg): set_op(Index([1, 2, 3])) + # mixed closed for other_closed in {'right', 'left', 'both', 'neither'} - {closed}: other = self.create_index(closed=other_closed) with tm.assert_raises_regex(ValueError, msg): set_op(other) + # GH 19016: incompatible dtypes + other = interval_range(Timestamp('20180101'), periods=9, closed=closed) + if op_name in ('union', 'symmetric_difference'): + msg = ('can only do {op} between two IntervalIndex objects ' + 'that have compatible dtypes').format(op=op_name) + with tm.assert_raises_regex(TypeError, msg): + set_op(other) + else: + # should not raise + set_op(other) + def test_isin(self, closed): index = self.create_index(closed=closed) From 49e0a496dc435f278f01b673bd04fde6217187e4 Mon Sep 17 00:00:00 2001 From: jschendel Date: Sun, 7 Jan 2018 16:56:25 -0700 Subject: [PATCH 2/2] always check subtypes --- pandas/core/indexes/interval.py | 25 ++++++++----------- .../tests/indexes/interval/test_interval.py | 13 +++------- 2 files changed, 15 insertions(+), 23 deletions(-) diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index 5b06813c20dab..baf80173d7362 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -1285,29 +1285,26 @@ def equals(self, other): self.right.equals(other.right) and self.closed == other.closed) - def _setop(op_name, check_subtypes=False): + def _setop(op_name): def func(self, other): msg = ('can only do set operations between two IntervalIndex ' 'objects that are closed on the same side') other = self._as_like_interval_index(other, msg) - if check_subtypes: - # GH 19016: ensure set op will not return a prohibited dtype - subtypes = [self.dtype.subtype, other.dtype.subtype] - result_subtype = find_common_type(subtypes) - if is_object_dtype(result_subtype): - msg = ('can only do {op} between two IntervalIndex ' - 'objects that have compatible dtypes') - raise TypeError(msg.format(op=op_name)) - else: - result_subtype = self.dtype.subtype + # GH 19016: ensure set op will not return a prohibited dtype + subtypes = [self.dtype.subtype, other.dtype.subtype] + common_subtype = find_common_type(subtypes) + if is_object_dtype(common_subtype): + msg = ('can only do {op} between two IntervalIndex ' + 'objects that have compatible dtypes') + raise TypeError(msg.format(op=op_name)) result = getattr(self._multiindex, op_name)(other._multiindex) result_name = self.name if self.name == other.name else None # GH 19101: ensure empty results have correct dtype if result.empty: - result = result.values.astype(result_subtype) + result = result.values.astype(self.dtype.subtype) else: result = result.values @@ -1315,10 +1312,10 @@ def func(self, other): name=result_name) return func - union = _setop('union', check_subtypes=True) + union = _setop('union') intersection = _setop('intersection') difference = _setop('difference') - symmetric_difference = _setop('symmetric_difference', check_subtypes=True) + symmetric_difference = _setop('symmetric_difference') # TODO: arithmetic operations diff --git a/pandas/tests/indexes/interval/test_interval.py b/pandas/tests/indexes/interval/test_interval.py index 05db4d2cab940..b6d49c9e7ba19 100644 --- a/pandas/tests/indexes/interval/test_interval.py +++ b/pandas/tests/indexes/interval/test_interval.py @@ -888,7 +888,7 @@ def test_union(self, closed): # GH 19101: empty result, different dtypes other = IntervalIndex(np.array([], dtype='float64'), closed=closed) result = index.union(other) - tm.assert_index_equal(result, other) + tm.assert_index_equal(result, index) def test_intersection(self, closed): index = self.create_index(closed=closed) @@ -945,7 +945,6 @@ def test_symmetric_difference(self, closed): other = IntervalIndex.from_arrays(index.left.astype('float64'), index.right, closed=closed) result = index.symmetric_difference(other) - expected = IntervalIndex(np.array([], dtype='float64'), closed=closed) tm.assert_index_equal(result, expected) @pytest.mark.parametrize('op_name', [ @@ -968,13 +967,9 @@ def test_set_operation_errors(self, closed, op_name): # GH 19016: incompatible dtypes other = interval_range(Timestamp('20180101'), periods=9, closed=closed) - if op_name in ('union', 'symmetric_difference'): - msg = ('can only do {op} between two IntervalIndex objects ' - 'that have compatible dtypes').format(op=op_name) - with tm.assert_raises_regex(TypeError, msg): - set_op(other) - else: - # should not raise + msg = ('can only do {op} between two IntervalIndex objects that have ' + 'compatible dtypes').format(op=op_name) + with tm.assert_raises_regex(TypeError, msg): set_op(other) def test_isin(self, closed):