Skip to content

Commit 893a33b

Browse files
jbrockmendelMateusz Górski
authored and
Mateusz Górski
committed
BUG: fix TypeErrors raised within _python_agg_general (pandas-dev#29425)
1 parent ab37330 commit 893a33b

File tree

3 files changed

+74
-15
lines changed

3 files changed

+74
-15
lines changed

pandas/core/groupby/groupby.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -899,10 +899,21 @@ def _python_agg_general(self, func, *args, **kwargs):
899899
output = {}
900900
for name, obj in self._iterate_slices():
901901
try:
902-
result, counts = self.grouper.agg_series(obj, f)
902+
# if this function is invalid for this dtype, we will ignore it.
903+
func(obj[:0])
903904
except TypeError:
904905
continue
905-
else:
906+
except AssertionError:
907+
raise
908+
except Exception:
909+
# Our function depends on having a non-empty argument
910+
# See test_groupby_agg_err_catching
911+
pass
912+
913+
result, counts = self.grouper.agg_series(obj, f)
914+
if result is not None:
915+
# TODO: only 3 test cases get None here, do something
916+
# in those cases
906917
output[name] = self._try_cast(result, obj, numeric_only=True)
907918

908919
if len(output) == 0:

pandas/core/groupby/ops.py

+23-13
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,7 @@ class BaseGrouper:
6161
6262
Parameters
6363
----------
64-
axis : int
65-
the axis to group
64+
axis : Index
6665
groupings : array of grouping
6766
all the grouping instances to handle in this grouper
6867
for example for grouper list to groupby, need to pass the list
@@ -78,8 +77,15 @@ class BaseGrouper:
7877
"""
7978

8079
def __init__(
81-
self, axis, groupings, sort=True, group_keys=True, mutated=False, indexer=None
80+
self,
81+
axis: Index,
82+
groupings,
83+
sort=True,
84+
group_keys=True,
85+
mutated=False,
86+
indexer=None,
8287
):
88+
assert isinstance(axis, Index), axis
8389
self._filter_empty_groups = self.compressed = len(groupings) != 1
8490
self.axis = axis
8591
self.groupings = groupings
@@ -623,7 +629,7 @@ def _aggregate_series_pure_python(self, obj, func):
623629
counts = np.zeros(ngroups, dtype=int)
624630
result = None
625631

626-
splitter = get_splitter(obj, group_index, ngroups, axis=self.axis)
632+
splitter = get_splitter(obj, group_index, ngroups, axis=0)
627633

628634
for label, group in splitter:
629635
res = func(group)
@@ -635,8 +641,12 @@ def _aggregate_series_pure_python(self, obj, func):
635641
counts[label] = group.shape[0]
636642
result[label] = res
637643

638-
result = lib.maybe_convert_objects(result, try_float=0)
639-
# TODO: try_cast back to EA?
644+
if result is not None:
645+
# if splitter is empty, result can be None, in which case
646+
# maybe_convert_objects would raise TypeError
647+
result = lib.maybe_convert_objects(result, try_float=0)
648+
# TODO: try_cast back to EA?
649+
640650
return result, counts
641651

642652

@@ -781,6 +791,11 @@ def groupings(self):
781791
]
782792

783793
def agg_series(self, obj: Series, func):
794+
if is_extension_array_dtype(obj.dtype):
795+
# pre-empt SeriesBinGrouper from raising TypeError
796+
# TODO: watch out, this can return None
797+
return self._aggregate_series_pure_python(obj, func)
798+
784799
dummy = obj[:0]
785800
grouper = libreduction.SeriesBinGrouper(obj, func, self.bins, dummy)
786801
return grouper.get_result()
@@ -809,12 +824,13 @@ def _is_indexed_like(obj, axes) -> bool:
809824

810825

811826
class DataSplitter:
812-
def __init__(self, data, labels, ngroups, axis=0):
827+
def __init__(self, data, labels, ngroups, axis: int = 0):
813828
self.data = data
814829
self.labels = ensure_int64(labels)
815830
self.ngroups = ngroups
816831

817832
self.axis = axis
833+
assert isinstance(axis, int), axis
818834

819835
@cache_readonly
820836
def slabels(self):
@@ -837,12 +853,6 @@ def __iter__(self):
837853
starts, ends = lib.generate_slices(self.slabels, self.ngroups)
838854

839855
for i, (start, end) in enumerate(zip(starts, ends)):
840-
# Since I'm now compressing the group ids, it's now not "possible"
841-
# to produce empty slices because such groups would not be observed
842-
# in the data
843-
# if start >= end:
844-
# raise AssertionError('Start %s must be less than end %s'
845-
# % (str(start), str(end)))
846856
yield i, self._chop(sdata, slice(start, end))
847857

848858
def _get_sorted_data(self):

pandas/tests/groupby/aggregate/test_other.py

+38
Original file line numberDiff line numberDiff line change
@@ -602,3 +602,41 @@ def test_agg_lambda_with_timezone():
602602
columns=["date"],
603603
)
604604
tm.assert_frame_equal(result, expected)
605+
606+
607+
@pytest.mark.parametrize(
608+
"err_cls",
609+
[
610+
NotImplementedError,
611+
RuntimeError,
612+
KeyError,
613+
IndexError,
614+
OSError,
615+
ValueError,
616+
ArithmeticError,
617+
AttributeError,
618+
],
619+
)
620+
def test_groupby_agg_err_catching(err_cls):
621+
# make sure we suppress anything other than TypeError or AssertionError
622+
# in _python_agg_general
623+
624+
# Use a non-standard EA to make sure we don't go down ndarray paths
625+
from pandas.tests.extension.decimal.array import DecimalArray, make_data, to_decimal
626+
627+
data = make_data()[:5]
628+
df = pd.DataFrame(
629+
{"id1": [0, 0, 0, 1, 1], "id2": [0, 1, 0, 1, 1], "decimals": DecimalArray(data)}
630+
)
631+
632+
expected = pd.Series(to_decimal([data[0], data[3]]))
633+
634+
def weird_func(x):
635+
# weird function that raises something other than TypeError or AssertionError
636+
# in _python_agg_general
637+
if len(x) == 0:
638+
raise err_cls
639+
return x.iloc[0]
640+
641+
result = df["decimals"].groupby(df["id1"]).agg(weird_func)
642+
tm.assert_series_equal(result, expected, check_names=False)

0 commit comments

Comments
 (0)