Skip to content

BUG: fix TypeErrors raised within _python_agg_general #29425

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,10 +899,21 @@ def _python_agg_general(self, func, *args, **kwargs):
output = {}
for name, obj in self._iterate_slices():
try:
result, counts = self.grouper.agg_series(obj, f)
# if this function is invalid for this dtype, we will ignore it.
func(obj[:0])
except TypeError:
continue
else:
except AssertionError:
raise
except Exception:
# Our function depends on having a non-empty argument
# See test_groupby_agg_err_catching
pass

result, counts = self.grouper.agg_series(obj, f)
if result is not None:
# TODO: only 3 test cases get None here, do something
# in those cases
output[name] = self._try_cast(result, obj, numeric_only=True)

if len(output) == 0:
Expand Down
36 changes: 23 additions & 13 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ class BaseGrouper:

Parameters
----------
axis : int
the axis to group
axis : Index
groupings : array of grouping
all the grouping instances to handle in this grouper
for example for grouper list to groupby, need to pass the list
Expand All @@ -78,8 +77,15 @@ class BaseGrouper:
"""

def __init__(
self, axis, groupings, sort=True, group_keys=True, mutated=False, indexer=None
self,
axis: Index,
groupings,
sort=True,
group_keys=True,
mutated=False,
indexer=None,
):
assert isinstance(axis, Index), axis
self._filter_empty_groups = self.compressed = len(groupings) != 1
self.axis = axis
self.groupings = groupings
Expand Down Expand Up @@ -623,7 +629,7 @@ def _aggregate_series_pure_python(self, obj, func):
counts = np.zeros(ngroups, dtype=int)
result = None

splitter = get_splitter(obj, group_index, ngroups, axis=self.axis)
splitter = get_splitter(obj, group_index, ngroups, axis=0)

for label, group in splitter:
res = func(group)
Expand All @@ -635,8 +641,12 @@ def _aggregate_series_pure_python(self, obj, func):
counts[label] = group.shape[0]
result[label] = res

result = lib.maybe_convert_objects(result, try_float=0)
# TODO: try_cast back to EA?
if result is not None:
# if splitter is empty, result can be None, in which case
# maybe_convert_objects would raise TypeError
result = lib.maybe_convert_objects(result, try_float=0)
# TODO: try_cast back to EA?

return result, counts


Expand Down Expand Up @@ -781,6 +791,11 @@ def groupings(self):
]

def agg_series(self, obj: Series, func):
    """
    Aggregate ``func`` over the binned groups of ``obj``.

    Extension-array dtypes are routed to the pure-python path up front,
    since SeriesBinGrouper would otherwise raise TypeError on them.

    Returns
    -------
    result : ndarray or None
        NOTE: the pure-python fallback can return None — callers must
        handle that case.
    counts : ndarray of int
    """
    if is_extension_array_dtype(obj.dtype):
        # pre-empt SeriesBinGrouper from raising TypeError on EA dtypes
        # TODO: watch out, this can return None
        return self._aggregate_series_pure_python(obj, func)

    # An empty slice serves as the dummy object SeriesBinGrouper needs
    # to infer the output dtype.
    empty_slice = obj[:0]
    bin_grouper = libreduction.SeriesBinGrouper(obj, func, self.bins, empty_slice)
    return bin_grouper.get_result()
Expand Down Expand Up @@ -809,12 +824,13 @@ def _is_indexed_like(obj, axes) -> bool:


class DataSplitter:
def __init__(self, data, labels, ngroups, axis=0):
def __init__(self, data, labels, ngroups, axis: int = 0):
self.data = data
self.labels = ensure_int64(labels)
self.ngroups = ngroups

self.axis = axis
assert isinstance(axis, int), axis

@cache_readonly
def slabels(self):
Expand All @@ -837,12 +853,6 @@ def __iter__(self):
starts, ends = lib.generate_slices(self.slabels, self.ngroups)

for i, (start, end) in enumerate(zip(starts, ends)):
# Since I'm now compressing the group ids, it's now not "possible"
# to produce empty slices because such groups would not be observed
# in the data
# if start >= end:
# raise AssertionError('Start %s must be less than end %s'
# % (str(start), str(end)))
yield i, self._chop(sdata, slice(start, end))

def _get_sorted_data(self):
Expand Down
38 changes: 38 additions & 0 deletions pandas/tests/groupby/aggregate/test_other.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,3 +602,41 @@ def test_agg_lambda_with_timezone():
columns=["date"],
)
tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize(
    "err_cls",
    [
        NotImplementedError,
        RuntimeError,
        KeyError,
        IndexError,
        OSError,
        ValueError,
        ArithmeticError,
        AttributeError,
    ],
)
def test_groupby_agg_err_catching(err_cls):
    """
    Check that _python_agg_general suppresses exceptions raised by the
    user-supplied aggregation function on the empty dummy slice, as long
    as the exception is something other than TypeError or AssertionError,
    and still produces the correct aggregation result.
    """
    # make sure we suppress anything other than TypeError or AssertionError
    # in _python_agg_general

    # Use a non-standard EA to make sure we don't go down ndarray paths
    from pandas.tests.extension.decimal.array import DecimalArray, make_data, to_decimal

    data = make_data()[:5]
    df = pd.DataFrame(
        {"id1": [0, 0, 0, 1, 1], "id2": [0, 1, 0, 1, 1], "decimals": DecimalArray(data)}
    )

    # Groups are [rows 0-2] and [rows 3-4]; weird_func picks the first
    # element of each group, i.e. data[0] and data[3].
    expected = pd.Series(to_decimal([data[0], data[3]]))

    def weird_func(x):
        # weird function that raises something other than TypeError or
        # AssertionError in _python_agg_general
        if len(x) == 0:
            raise err_cls
        return x.iloc[0]

    result = df["decimals"].groupby(df["id1"]).agg(weird_func)
    tm.assert_series_equal(result, expected, check_names=False)