diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 873a31e658625..fa4a184e8f7a4 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -899,10 +899,21 @@ def _python_agg_general(self, func, *args, **kwargs): output = {} for name, obj in self._iterate_slices(): try: - result, counts = self.grouper.agg_series(obj, f) + # if this function is invalid for this dtype, we will ignore it. + func(obj[:0]) except TypeError: continue - else: + except AssertionError: + raise + except Exception: + # Our function depends on having a non-empty argument + # See test_groupby_agg_err_catching + pass + + result, counts = self.grouper.agg_series(obj, f) + if result is not None: + # TODO: only 3 test cases get None here, do something + # in those cases output[name] = self._try_cast(result, obj, numeric_only=True) if len(output) == 0: diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 2cc0e5fde2290..5bad73bf40ff5 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -61,8 +61,7 @@ class BaseGrouper: Parameters ---------- - axis : int - the axis to group + axis : Index groupings : array of grouping all the grouping instances to handle in this grouper for example for grouper list to groupby, need to pass the list @@ -78,8 +77,15 @@ class BaseGrouper: """ def __init__( - self, axis, groupings, sort=True, group_keys=True, mutated=False, indexer=None + self, + axis: Index, + groupings, + sort=True, + group_keys=True, + mutated=False, + indexer=None, ): + assert isinstance(axis, Index), axis self._filter_empty_groups = self.compressed = len(groupings) != 1 self.axis = axis self.groupings = groupings @@ -623,7 +629,7 @@ def _aggregate_series_pure_python(self, obj, func): counts = np.zeros(ngroups, dtype=int) result = None - splitter = get_splitter(obj, group_index, ngroups, axis=self.axis) + splitter = get_splitter(obj, group_index, ngroups, axis=0) for label, group in splitter: res = func(group) @@ -635,8 +641,12 @@ def _aggregate_series_pure_python(self, obj, func): counts[label] = group.shape[0] result[label] = res - result = lib.maybe_convert_objects(result, try_float=0) - # TODO: try_cast back to EA? + if result is not None: + # if splitter is empty, result can be None, in which case + # maybe_convert_objects would raise TypeError + result = lib.maybe_convert_objects(result, try_float=0) + # TODO: try_cast back to EA? + return result, counts @@ -781,6 +791,11 @@ def groupings(self): ] def agg_series(self, obj: Series, func): + if is_extension_array_dtype(obj.dtype): + # pre-empty SeriesBinGrouper from raising TypeError + # TODO: watch out, this can return None + return self._aggregate_series_pure_python(obj, func) + dummy = obj[:0] grouper = libreduction.SeriesBinGrouper(obj, func, self.bins, dummy) return grouper.get_result() @@ -809,12 +824,13 @@ def _is_indexed_like(obj, axes) -> bool: class DataSplitter: - def __init__(self, data, labels, ngroups, axis=0): + def __init__(self, data, labels, ngroups, axis: int = 0): self.data = data self.labels = ensure_int64(labels) self.ngroups = ngroups self.axis = axis + assert isinstance(axis, int), axis @cache_readonly def slabels(self): @@ -837,12 +853,6 @@ def __iter__(self): starts, ends = lib.generate_slices(self.slabels, self.ngroups) for i, (start, end) in enumerate(zip(starts, ends)): - # Since I'm now compressing the group ids, it's now not "possible" - # to produce empty slices because such groups would not be observed - # in the data - # if start >= end: - # raise AssertionError('Start %s must be less than end %s' - # % (str(start), str(end))) yield i, self._chop(sdata, slice(start, end)) def _get_sorted_data(self): diff --git a/pandas/tests/groupby/aggregate/test_other.py b/pandas/tests/groupby/aggregate/test_other.py index 5dad868c8c3aa..1c297f3e2ada3 100644 --- a/pandas/tests/groupby/aggregate/test_other.py +++ b/pandas/tests/groupby/aggregate/test_other.py @@ -602,3 +602,41 @@ def test_agg_lambda_with_timezone(): columns=["date"], ) tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "err_cls", + [ + NotImplementedError, + RuntimeError, + KeyError, + IndexError, + OSError, + ValueError, + ArithmeticError, + AttributeError, + ], +) +def test_groupby_agg_err_catching(err_cls): + # make sure we suppress anything other than TypeError or AssertionError + # in _python_agg_general + + # Use a non-standard EA to make sure we don't go down ndarray paths + from pandas.tests.extension.decimal.array import DecimalArray, make_data, to_decimal + + data = make_data()[:5] + df = pd.DataFrame( + {"id1": [0, 0, 0, 1, 1], "id2": [0, 1, 0, 1, 1], "decimals": DecimalArray(data)} + ) + + expected = pd.Series(to_decimal([data[0], data[3]])) + + def weird_func(x): + # weird function that raise something other than TypeError or IndexError + # in _python_agg_general + if len(x) == 0: + raise err_cls + return x.iloc[0] + + result = df["decimals"].groupby(df["id1"]).agg(weird_func) + tm.assert_series_equal(result, expected, check_names=False)