From 613ad1c83c766db8d60ff46408b89d9b2d67e4a0 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Tue, 5 Nov 2019 17:20:41 -0800 Subject: [PATCH 1/2] BUG: fix TypeErrors raised within _python_agg_general --- pandas/core/groupby/groupby.py | 15 ++++++-- pandas/core/groupby/ops.py | 36 ++++++++++++------- pandas/tests/groupby/aggregate/test_other.py | 38 ++++++++++++++++++++ 3 files changed, 74 insertions(+), 15 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 59b118431cfc9..639a92e4438f0 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -897,10 +897,21 @@ def _python_agg_general(self, func, *args, **kwargs): output = {} for name, obj in self._iterate_slices(): try: - result, counts = self.grouper.agg_series(obj, f) + # if this function is invalid for this dtype, we will ignore it + func(obj[:0]) except TypeError: continue - else: + except AssertionError: + raise + except Exception: + # Our function depends on having a non-empty argument + # See test_decimal.test_groupby_agg + pass + + result, counts = self.grouper.agg_series(obj, f) + if result is not None: + # TODO: only 3 test cases get None here, do something + # in those cases output[name] = self._try_cast(result, obj, numeric_only=True) if len(output) == 0: diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 9bbe73c1851b5..f9a1c20d2a7be 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -61,8 +61,7 @@ class BaseGrouper: Parameters ---------- - axis : int - the axis to group + axis : Index groupings : array of grouping all the grouping instances to handle in this grouper for example for grouper list to groupby, need to pass the list @@ -78,8 +77,15 @@ class BaseGrouper: """ def __init__( - self, axis, groupings, sort=True, group_keys=True, mutated=False, indexer=None + self, + axis: Index, + groupings, + sort=True, + group_keys=True, + mutated=False, + indexer=None, ): + assert isinstance(axis, Index), axis self._filter_empty_groups = self.compressed = len(groupings) != 1 self.axis = axis self.groupings = groupings @@ -614,7 +620,7 @@ def _aggregate_series_pure_python(self, obj, func): counts = np.zeros(ngroups, dtype=int) result = None - splitter = get_splitter(obj, group_index, ngroups, axis=self.axis) + splitter = get_splitter(obj, group_index, ngroups, axis=0) for label, group in splitter: res = func(group) @@ -626,8 +632,12 @@ def _aggregate_series_pure_python(self, obj, func): counts[label] = group.shape[0] result[label] = res - result = lib.maybe_convert_objects(result, try_float=0) - # TODO: try_cast back to EA? + if result is not None: + # if splitter is empty, result can be None, in which case + # maybe_convert_objects would raise TypeError + result = lib.maybe_convert_objects(result, try_float=0) + # TODO: try_cast back to EA? + return result, counts @@ -772,6 +782,11 @@ def groupings(self): ] def agg_series(self, obj, func): + if is_extension_array_dtype(obj.dtype): + # pre-empty SeriesBinGrouper from raising TypeError + # TODO: watch out, this can return None + return self._aggregate_series_pure_python(obj, func) + dummy = obj[:0] grouper = libreduction.SeriesBinGrouper(obj, func, self.bins, dummy) return grouper.get_result() @@ -800,12 +815,13 @@ def _is_indexed_like(obj, axes) -> bool: class DataSplitter: - def __init__(self, data, labels, ngroups, axis=0): + def __init__(self, data, labels, ngroups, axis: int = 0): self.data = data self.labels = ensure_int64(labels) self.ngroups = ngroups self.axis = axis + assert isinstance(axis, int), axis @cache_readonly def slabels(self): @@ -828,12 +844,6 @@ def __iter__(self): starts, ends = lib.generate_slices(self.slabels, self.ngroups) for i, (start, end) in enumerate(zip(starts, ends)): - # Since I'm now compressing the group ids, it's now not "possible" - # to produce empty slices because such groups would not be observed - # in the data - # if start >= end: - # raise AssertionError('Start %s must be less than end %s' - # % (str(start), str(end))) yield i, self._chop(sdata, slice(start, end)) def _get_sorted_data(self): diff --git a/pandas/tests/groupby/aggregate/test_other.py b/pandas/tests/groupby/aggregate/test_other.py index 5dad868c8c3aa..1c297f3e2ada3 100644 --- a/pandas/tests/groupby/aggregate/test_other.py +++ b/pandas/tests/groupby/aggregate/test_other.py @@ -602,3 +602,41 @@ def test_agg_lambda_with_timezone(): columns=["date"], ) tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "err_cls", + [ + NotImplementedError, + RuntimeError, + KeyError, + IndexError, + OSError, + ValueError, + ArithmeticError, + AttributeError, + ], +) +def test_groupby_agg_err_catching(err_cls): + # make sure we suppress anything other than TypeError or AssertionError + # in _python_agg_general + + # Use a non-standard EA to make sure we don't go down ndarray paths + from pandas.tests.extension.decimal.array import DecimalArray, make_data, to_decimal + + data = make_data()[:5] + df = pd.DataFrame( + {"id1": [0, 0, 0, 1, 1], "id2": [0, 1, 0, 1, 1], "decimals": DecimalArray(data)} + ) + + expected = pd.Series(to_decimal([data[0], data[3]])) + + def weird_func(x): + # weird function that raise something other than TypeError or IndexError + # in _python_agg_general + if len(x) == 0: + raise err_cls + return x.iloc[0] + + result = df["decimals"].groupby(df["id1"]).agg(weird_func) + tm.assert_series_equal(result, expected, check_names=False) From 58550750b263b63cb62001a79f58ca8342ceaf73 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Tue, 5 Nov 2019 17:21:54 -0800 Subject: [PATCH 2/2] update comment --- pandas/core/groupby/groupby.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 639a92e4438f0..55534f4549c47 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -897,7 +897,7 @@ def _python_agg_general(self, func, *args, **kwargs): output = {} for name, obj in self._iterate_slices(): try: - # if this function is invalid for this dtype, we will ignore it + # if this function is invalid for this dtype, we will ignore it. func(obj[:0]) except TypeError: continue @@ -905,7 +905,7 @@ def _python_agg_general(self, func, *args, **kwargs): raise except Exception: # Our function depends on having a non-empty argument - # See test_decimal.test_groupby_agg + # See test_groupby_agg_err_catching pass result, counts = self.grouper.agg_series(obj, f)