|
26 | 26 | lib,
|
27 | 27 | )
|
28 | 28 | import pandas._libs.groupby as libgroupby
|
29 |
| -import pandas._libs.reduction as libreduction |
30 | 29 | from pandas._typing import (
|
31 | 30 | ArrayLike,
|
32 | 31 | AxisInt,
|
|
75 | 74 | from pandas.core.generic import NDFrame
|
76 | 75 |
|
77 | 76 |
|
| 77 | +def check_result_array(obj, dtype): |
| 78 | + # Our operation is supposed to be an aggregation/reduction. If |
| 79 | + # it returns an ndarray, this likely means an invalid operation has |
| 80 | + # been passed. See test_apply_without_aggregation, test_agg_must_agg |
| 81 | + if isinstance(obj, np.ndarray): |
| 82 | + if dtype != object: |
| 83 | + # If it is object dtype, the function can be a reduction/aggregation |
| 84 | + # and still return an ndarray e.g. test_agg_over_numpy_arrays |
| 85 | + raise ValueError("Must produce aggregated value") |
| 86 | + |
| 87 | + |
| 88 | +def extract_result(res): |
| 89 | + """ |
| 90 | + Extract the result object, it might be a 0-dim ndarray |
| 91 | + or a len-1 0-dim, or a scalar |
| 92 | + """ |
| 93 | + if hasattr(res, "_values"): |
| 94 | + # Preserve EA |
| 95 | + res = res._values |
| 96 | + if res.ndim == 1 and len(res) == 1: |
| 97 | + # see test_agg_lambda_with_timezone, test_resampler_grouper.py::test_apply |
| 98 | + res = res[0] |
| 99 | + return res |
| 100 | + |
| 101 | + |
78 | 102 | class WrappedCythonOp:
|
79 | 103 | """
|
80 | 104 | Dispatch logic for functions defined in _libs.groupby
|
@@ -836,11 +860,11 @@ def _aggregate_series_pure_python(
|
836 | 860 |
|
837 | 861 | for i, group in enumerate(splitter):
|
838 | 862 | res = func(group)
|
839 |
| - res = libreduction.extract_result(res) |
| 863 | + res = extract_result(res) |
840 | 864 |
|
841 | 865 | if not initialized:
|
842 | 866 | # We only do this validation on the first iteration
|
843 |
| - libreduction.check_result_array(res, group.dtype) |
| 867 | + check_result_array(res, group.dtype) |
844 | 868 | initialized = True
|
845 | 869 |
|
846 | 870 | result[i] = res
|
@@ -948,7 +972,7 @@ def __init__(
|
948 | 972 | self.indexer = indexer
|
949 | 973 |
|
950 | 974 | # These lengths must match, otherwise we could call agg_series
|
951 |
| - # with empty self.bins, which would raise in libreduction. |
| 975 | + # with empty self.bins, which would raise later. |
952 | 976 | assert len(self.binlabels) == len(self.bins)
|
953 | 977 |
|
954 | 978 | @cache_readonly
|
|
0 commit comments