Skip to content

Commit 82acdee

Browse files
jschendel authored and jreback committed
REGR: Prevent indexes that aren't directly backed by numpy from entering libreduction code paths (#31238)
1 parent ddb3427 commit 82acdee

File tree

9 files changed

+68
-8
lines changed

9 files changed

+68
-8
lines changed

pandas/core/apply.py

+3 -4
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
is_list_like,
1515
is_sequence,
1616
)
17-
from pandas.core.dtypes.generic import ABCMultiIndex, ABCSeries
17+
from pandas.core.dtypes.generic import ABCSeries
1818

1919
from pandas.core.construction import create_series_with_explicit_dtype
2020

@@ -278,9 +278,8 @@ def apply_standard(self):
278278
if (
279279
self.result_type in ["reduce", None]
280280
and not self.dtypes.apply(is_extension_array_dtype).any()
281-
# Disallow complex_internals since libreduction shortcut
282-
# cannot handle MultiIndex
283-
and not isinstance(self.agg_axis, ABCMultiIndex)
281+
# Disallow complex_internals since libreduction shortcut raises a TypeError
282+
and not self.agg_axis._has_complex_internals
284283
):
285284

286285
values = self.values

pandas/core/groupby/ops.py

+4 -4
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,8 @@ def apply(self, f, data: FrameOrSeries, axis: int = 0):
164164
com.get_callable_name(f) not in base.plotting_methods
165165
and isinstance(splitter, FrameSplitter)
166166
and axis == 0
167-
# apply_frame_axis0 doesn't allow MultiIndex
168-
and not isinstance(sdata.index, MultiIndex)
167+
# fast_apply/libreduction doesn't allow non-numpy backed indexes
168+
and not sdata.index._has_complex_internals
169169
):
170170
try:
171171
result_values, mutated = splitter.fast_apply(f, group_keys)
@@ -616,8 +616,8 @@ def agg_series(self, obj: Series, func):
616616
# TODO: can we get a performant workaround for EAs backed by ndarray?
617617
return self._aggregate_series_pure_python(obj, func)
618618

619-
elif isinstance(obj.index, MultiIndex):
620-
# MultiIndex; Pre-empt TypeError in _aggregate_series_fast
619+
elif obj.index._has_complex_internals:
620+
# Pre-empt TypeError in _aggregate_series_fast
621621
return self._aggregate_series_pure_python(obj, func)
622622

623623
try:

pandas/core/indexes/base.py

+8
Original file line numberDiff line numberDiff line change
@@ -4109,6 +4109,14 @@ def _assert_can_do_op(self, value):
41094109
if not is_scalar(value):
41104110
raise TypeError(f"'value' must be a scalar, passed: {type(value).__name__}")
41114111

4112+
@property
4113+
def _has_complex_internals(self):
4114+
"""
4115+
Indicates if an index is not directly backed by a numpy array
4116+
"""
4117+
# used to avoid libreduction code paths, which raise or require conversion
4118+
return False
4119+
41124120
def _is_memory_usage_qualified(self) -> bool:
41134121
"""
41144122
Return a boolean if we need a qualified .info display.

pandas/core/indexes/category.py

+5
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,11 @@ def values(self):
380380
""" return the underlying data, which is a Categorical """
381381
return self._data
382382

383+
@property
384+
def _has_complex_internals(self):
385+
# used to avoid libreduction code paths, which raise or require conversion
386+
return True
387+
383388
def _wrap_setop_result(self, other, result):
384389
name = get_op_result_name(self, other)
385390
# We use _shallow_copy rather than the Index implementation

pandas/core/indexes/interval.py

+5
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,11 @@ def values(self):
404404
"""
405405
return self._data
406406

407+
@property
408+
def _has_complex_internals(self):
409+
# used to avoid libreduction code paths, which raise or require conversion
410+
return True
411+
407412
def __array_wrap__(self, result, context=None):
408413
# we don't want the superclass implementation
409414
return result

pandas/core/indexes/multi.py

+5
Original file line numberDiff line numberDiff line change
@@ -1346,6 +1346,11 @@ def values(self):
13461346
self._tuples = lib.fast_zip(values)
13471347
return self._tuples
13481348

1349+
@property
1350+
def _has_complex_internals(self):
1351+
# used to avoid libreduction code paths, which raise or require conversion
1352+
return True
1353+
13491354
@cache_readonly
13501355
def is_monotonic_increasing(self) -> bool:
13511356
"""

pandas/core/indexes/period.py

+5
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,11 @@ def _simple_new(cls, values, name=None, freq=None, **kwargs):
255255
def values(self):
256256
return np.asarray(self)
257257

258+
@property
259+
def _has_complex_internals(self):
260+
# used to avoid libreduction code paths, which raise or require conversion
261+
return True
262+
258263
def _shallow_copy(self, values=None, **kwargs):
259264
# TODO: simplify, figure out type of values
260265
if values is None:

pandas/tests/groupby/aggregate/test_aggregate.py

+17
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,23 @@ def test_func_duplicates_raises():
360360
df.groupby("A").agg(["min", "min"])
361361

362362

363+
@pytest.mark.parametrize(
364+
"index",
365+
[
366+
pd.CategoricalIndex(list("abc")),
367+
pd.interval_range(0, 3),
368+
pd.period_range("2020", periods=3, freq="D"),
369+
pd.MultiIndex.from_tuples([("a", 0), ("a", 1), ("b", 0)]),
370+
],
371+
)
372+
def test_agg_index_has_complex_internals(index):
373+
# GH 31223
374+
df = DataFrame({"group": [1, 1, 2], "value": [0, 1, 0]}, index=index)
375+
result = df.groupby("group").agg({"value": Series.nunique})
376+
expected = DataFrame({"group": [1, 2], "value": [2, 1]}).set_index("group")
377+
tm.assert_frame_equal(result, expected)
378+
379+
363380
class TestNamedAggregationSeries:
364381
def test_series_named_agg(self):
365382
df = pd.Series([1, 2, 3, 4])

pandas/tests/groupby/test_apply.py

+16
Original file line numberDiff line numberDiff line change
@@ -811,3 +811,19 @@ def test_groupby_apply_datetime_result_dtypes():
811811
index=["observation", "color", "mood", "intensity", "score"],
812812
)
813813
tm.assert_series_equal(result, expected)
814+
815+
816+
@pytest.mark.parametrize(
817+
"index",
818+
[
819+
pd.CategoricalIndex(list("abc")),
820+
pd.interval_range(0, 3),
821+
pd.period_range("2020", periods=3, freq="D"),
822+
pd.MultiIndex.from_tuples([("a", 0), ("a", 1), ("b", 0)]),
823+
],
824+
)
825+
def test_apply_index_has_complex_internals(index):
826+
# GH 31248
827+
df = DataFrame({"group": [1, 1, 2], "value": [0, 1, 0]}, index=index)
828+
result = df.groupby("group").apply(lambda x: x)
829+
tm.assert_frame_equal(result, df)

0 commit comments

Comments
 (0)