Skip to content

Commit b905f2b

Browse files
jschendel authored and jreback committed
Backport PR pandas-dev#31238: REGR: Prevent indexes that aren't directly backed by numpy from entering libreduction code paths (pandas-dev#31378)
1 parent 161f3f7 commit b905f2b

File tree

9 files changed: +68 -8 lines changed

pandas/core/apply.py

+3 -4
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
is_list_like,
1414
is_sequence,
1515
)
16-
from pandas.core.dtypes.generic import ABCMultiIndex, ABCSeries
16+
from pandas.core.dtypes.generic import ABCSeries
1717

1818
from pandas.core.construction import create_series_with_explicit_dtype
1919

@@ -277,9 +277,8 @@ def apply_standard(self):
277277
if (
278278
self.result_type in ["reduce", None]
279279
and not self.dtypes.apply(is_extension_array_dtype).any()
280-
# Disallow complex_internals since libreduction shortcut
281-
# cannot handle MultiIndex
282-
and not isinstance(self.agg_axis, ABCMultiIndex)
280+
# Disallow complex_internals since libreduction shortcut raises a TypeError
281+
and not self.agg_axis._has_complex_internals
283282
):
284283

285284
values = self.values

pandas/core/groupby/ops.py

+4 -4
Original file line number | Diff line number | Diff line change
@@ -164,8 +164,8 @@ def apply(self, f, data: FrameOrSeries, axis: int = 0):
164164
com.get_callable_name(f) not in base.plotting_methods
165165
and isinstance(splitter, FrameSplitter)
166166
and axis == 0
167-
# apply_frame_axis0 doesn't allow MultiIndex
168-
and not isinstance(sdata.index, MultiIndex)
167+
# fast_apply/libreduction doesn't allow non-numpy backed indexes
168+
and not sdata.index._has_complex_internals
169169
):
170170
try:
171171
result_values, mutated = splitter.fast_apply(f, group_keys)
@@ -616,8 +616,8 @@ def agg_series(self, obj: Series, func):
616616
# TODO: can we get a performant workaround for EAs backed by ndarray?
617617
return self._aggregate_series_pure_python(obj, func)
618618

619-
elif isinstance(obj.index, MultiIndex):
620-
# MultiIndex; Pre-empt TypeError in _aggregate_series_fast
619+
elif obj.index._has_complex_internals:
620+
# Pre-empt TypeError in _aggregate_series_fast
621621
return self._aggregate_series_pure_python(obj, func)
622622

623623
try:

pandas/core/indexes/base.py

+8
Original file line number | Diff line number | Diff line change
@@ -3825,6 +3825,14 @@ def _assert_can_do_op(self, value):
38253825
if not is_scalar(value):
38263826
raise TypeError(f"'value' must be a scalar, passed: {type(value).__name__}")
38273827

3828+
@property
3829+
def _has_complex_internals(self):
3830+
"""
3831+
Indicates if an index is not directly backed by a numpy array
3832+
"""
3833+
# used to avoid libreduction code paths, which raise or require conversion
3834+
return False
3835+
38283836
def _is_memory_usage_qualified(self) -> bool:
38293837
"""
38303838
Return a boolean if we need a qualified .info display.

pandas/core/indexes/category.py

+5
Original file line number | Diff line number | Diff line change
@@ -380,6 +380,11 @@ def values(self):
380380
""" return the underlying data, which is a Categorical """
381381
return self._data
382382

383+
@property
384+
def _has_complex_internals(self):
385+
# used to avoid libreduction code paths, which raise or require conversion
386+
return True
387+
383388
def _wrap_setop_result(self, other, result):
384389
name = get_op_result_name(self, other)
385390
# We use _shallow_copy rather than the Index implementation

pandas/core/indexes/interval.py

+5
Original file line number | Diff line number | Diff line change
@@ -402,6 +402,11 @@ def values(self):
402402
def _values(self):
403403
return self._data
404404

405+
@property
406+
def _has_complex_internals(self):
407+
# used to avoid libreduction code paths, which raise or require conversion
408+
return True
409+
405410
def __array_wrap__(self, result, context=None):
406411
# we don't want the superclass implementation
407412
return result

pandas/core/indexes/multi.py

+5
Original file line number | Diff line number | Diff line change
@@ -1346,6 +1346,11 @@ def values(self):
13461346
self._tuples = lib.fast_zip(values)
13471347
return self._tuples
13481348

1349+
@property
1350+
def _has_complex_internals(self):
1351+
# used to avoid libreduction code paths, which raise or require conversion
1352+
return True
1353+
13491354
@cache_readonly
13501355
def is_monotonic_increasing(self) -> bool:
13511356
"""

pandas/core/indexes/period.py

+5
Original file line number | Diff line number | Diff line change
@@ -266,6 +266,11 @@ def _simple_new(cls, values, name=None, freq=None, **kwargs):
266266
def values(self):
267267
return np.asarray(self)
268268

269+
@property
270+
def _has_complex_internals(self):
271+
# used to avoid libreduction code paths, which raise or require conversion
272+
return True
273+
269274
def _shallow_copy(self, values=None, **kwargs):
270275
# TODO: simplify, figure out type of values
271276
if values is None:

pandas/tests/groupby/aggregate/test_aggregate.py

+17
Original file line number | Diff line number | Diff line change
@@ -361,6 +361,23 @@ def test_func_duplicates_raises():
361361
df.groupby("A").agg(["min", "min"])
362362

363363

364+
@pytest.mark.parametrize(
365+
"index",
366+
[
367+
pd.CategoricalIndex(list("abc")),
368+
pd.interval_range(0, 3),
369+
pd.period_range("2020", periods=3, freq="D"),
370+
pd.MultiIndex.from_tuples([("a", 0), ("a", 1), ("b", 0)]),
371+
],
372+
)
373+
def test_agg_index_has_complex_internals(index):
374+
# GH 31223
375+
df = DataFrame({"group": [1, 1, 2], "value": [0, 1, 0]}, index=index)
376+
result = df.groupby("group").agg({"value": Series.nunique})
377+
expected = DataFrame({"group": [1, 2], "value": [2, 1]}).set_index("group")
378+
tm.assert_frame_equal(result, expected)
379+
380+
364381
class TestNamedAggregationSeries:
365382
def test_series_named_agg(self):
366383
df = pd.Series([1, 2, 3, 4])

pandas/tests/groupby/test_apply.py

+16
Original file line number | Diff line number | Diff line change
@@ -769,3 +769,19 @@ def test_apply_multi_level_name(category):
769769
)
770770
tm.assert_frame_equal(result, expected)
771771
assert df.index.names == ["A", "B"]
772+
773+
774+
@pytest.mark.parametrize(
775+
"index",
776+
[
777+
pd.CategoricalIndex(list("abc")),
778+
pd.interval_range(0, 3),
779+
pd.period_range("2020", periods=3, freq="D"),
780+
pd.MultiIndex.from_tuples([("a", 0), ("a", 1), ("b", 0)]),
781+
],
782+
)
783+
def test_apply_index_has_complex_internals(index):
784+
# GH 31248
785+
df = DataFrame({"group": [1, 1, 2], "value": [0, 1, 0]}, index=index)
786+
result = df.groupby("group").apply(lambda x: x)
787+
tm.assert_frame_equal(result, df)

0 commit comments

Comments
 (0)