Skip to content

Commit a6e83a1

Browse files
jschendelproost
authored andcommitted
BUG: Support categorical targets in IntervalIndex.get_indexer (pandas-dev#30181)
1 parent 7bbd9ee commit a6e83a1

File tree

4 files changed

+44
-2
lines changed

4 files changed

+44
-2
lines changed

doc/source/whatsnew/v1.0.0.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,7 @@ Numeric
713713
- Bug in :class:`NumericIndex` construction that caused indexing to fail when integers in the ``np.uint64`` range were used (:issue:`28023`)
714714
- Bug in :class:`NumericIndex` construction that caused :class:`UInt64Index` to be casted to :class:`Float64Index` when integers in the ``np.uint64`` range were used to index a :class:`DataFrame` (:issue:`28279`)
715715
- Bug in :meth:`Series.interpolate` when using method=`index` with an unsorted index, would previously return incorrect results. (:issue:`21037`)
716+
- Bug in :meth:`DataFrame.round` where a :class:`DataFrame` with a :class:`CategoricalIndex` of :class:`IntervalIndex` columns would incorrectly raise a ``TypeError`` (:issue:`30063`)
716717

717718
Conversion
718719
^^^^^^^^^^
@@ -730,7 +731,7 @@ Strings
730731
Interval
731732
^^^^^^^^
732733

733-
-
734+
- Bug in :meth:`IntervalIndex.get_indexer` where a :class:`Categorical` or :class:`CategoricalIndex` ``target`` would incorrectly raise a ``TypeError`` (:issue:`30063`)
734735
-
735736

736737
Indexing

pandas/core/indexes/interval.py

+6
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
)
2020
from pandas.core.dtypes.common import (
2121
ensure_platform_int,
22+
is_categorical,
2223
is_datetime64tz_dtype,
2324
is_datetime_or_timedelta_dtype,
2425
is_dtype_equal,
@@ -36,6 +37,7 @@
3637
from pandas.core.dtypes.missing import isna
3738

3839
from pandas._typing import AnyArrayLike
40+
from pandas.core.algorithms import take_1d
3941
from pandas.core.arrays.interval import IntervalArray, _interval_shared_docs
4042
import pandas.core.common as com
4143
import pandas.core.indexes.base as ibase
@@ -958,6 +960,10 @@ def get_indexer(
958960
left_indexer = self.left.get_indexer(target_as_index.left)
959961
right_indexer = self.right.get_indexer(target_as_index.right)
960962
indexer = np.where(left_indexer == right_indexer, left_indexer, -1)
963+
elif is_categorical(target_as_index):
964+
# get an indexer for unique categories then propogate to codes via take_1d
965+
categories_indexer = self.get_indexer(target_as_index.categories)
966+
indexer = take_1d(categories_indexer, target_as_index.codes, fill_value=-1)
961967
elif not is_object_dtype(target_as_index):
962968
# homogeneous scalar index: use IntervalTree
963969
target_as_index = self._maybe_convert_i8(target_as_index)

pandas/tests/frame/test_analytics.py

+9
Original file line numberDiff line numberDiff line change
@@ -2272,6 +2272,15 @@ def test_round_nonunique_categorical(self):
22722272

22732273
tm.assert_frame_equal(result, expected)
22742274

2275+
def test_round_interval_category_columns(self):
2276+
# GH 30063
2277+
columns = pd.CategoricalIndex(pd.interval_range(0, 2))
2278+
df = DataFrame([[0.66, 1.1], [0.3, 0.25]], columns=columns)
2279+
2280+
result = df.round()
2281+
expected = DataFrame([[1.0, 1.0], [0.0, 0.0]], columns=columns)
2282+
tm.assert_frame_equal(result, expected)
2283+
22752284
# ---------------------------------------------------------------------
22762285
# Clip
22772286

pandas/tests/indexes/interval/test_indexing.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,14 @@
33
import numpy as np
44
import pytest
55

6-
from pandas import Interval, IntervalIndex, Timedelta, date_range, timedelta_range
6+
from pandas import (
7+
CategoricalIndex,
8+
Interval,
9+
IntervalIndex,
10+
Timedelta,
11+
date_range,
12+
timedelta_range,
13+
)
714
from pandas.core.indexes.base import InvalidIndexError
815
import pandas.util.testing as tm
916

@@ -231,6 +238,25 @@ def test_get_indexer_length_one_interval(self, size, closed):
231238
expected = np.array([0] * size, dtype="intp")
232239
tm.assert_numpy_array_equal(result, expected)
233240

241+
@pytest.mark.parametrize(
242+
"target",
243+
[
244+
IntervalIndex.from_tuples([(7, 8), (1, 2), (3, 4), (0, 1)]),
245+
IntervalIndex.from_tuples([(0, 1), (1, 2), (3, 4), np.nan]),
246+
IntervalIndex.from_tuples([(0, 1), (1, 2), (3, 4)], closed="both"),
247+
[-1, 0, 0.5, 1, 2, 2.5, np.nan],
248+
["foo", "foo", "bar", "baz"],
249+
],
250+
)
251+
def test_get_indexer_categorical(self, target, ordered_fixture):
252+
# GH 30063: categorical and non-categorical results should be consistent
253+
index = IntervalIndex.from_tuples([(0, 1), (1, 2), (3, 4)])
254+
categorical_target = CategoricalIndex(target, ordered=ordered_fixture)
255+
256+
result = index.get_indexer(categorical_target)
257+
expected = index.get_indexer(target)
258+
tm.assert_numpy_array_equal(result, expected)
259+
234260
@pytest.mark.parametrize(
235261
"tuples, closed",
236262
[

0 commit comments

Comments
 (0)