Skip to content

Commit 059c8ba

Browse files
authored
REF: remove no-op casting in groupby.generic (#41048)
1 parent 770918a commit 059c8ba

File tree

3 files changed

+33
-39
lines changed

3 files changed

+33
-39
lines changed

pandas/core/groupby/generic.py

+11-14
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646

4747
from pandas.core.dtypes.cast import (
4848
find_common_type,
49-
maybe_cast_result_dtype,
5049
maybe_downcast_numeric,
5150
)
5251
from pandas.core.dtypes.common import (
@@ -58,7 +57,6 @@
5857
is_interval_dtype,
5958
is_numeric_dtype,
6059
is_scalar,
61-
needs_i8_conversion,
6260
)
6361
from pandas.core.dtypes.missing import (
6462
isna,
@@ -1104,13 +1102,11 @@ def _cython_agg_manager(
11041102

11051103
using_array_manager = isinstance(data, ArrayManager)
11061104

1107-
def cast_agg_result(result, values: ArrayLike, how: str) -> ArrayLike:
1105+
def cast_agg_result(
1106+
result: ArrayLike, values: ArrayLike, how: str
1107+
) -> ArrayLike:
11081108
# see if we can cast the values to the desired dtype
11091109
# this may not be the original dtype
1110-
assert not isinstance(result, DataFrame)
1111-
1112-
dtype = maybe_cast_result_dtype(values.dtype, how)
1113-
result = maybe_downcast_numeric(result, dtype)
11141110

11151111
if isinstance(values, Categorical) and isinstance(result, np.ndarray):
11161112
# If the Categorical op didn't raise, it is dtype-preserving
@@ -1125,6 +1121,7 @@ def cast_agg_result(result, values: ArrayLike, how: str) -> ArrayLike:
11251121
):
11261122
# We went through a SeriesGroupByPath and need to reshape
11271123
# GH#32223 includes case with IntegerArray values
1124+
# We only get here with values.dtype == object
11281125
result = result.reshape(1, -1)
11291126
# test_groupby_duplicate_columns gets here with
11301127
# result.dtype == int64, values.dtype=object, how="min"
@@ -1140,8 +1137,11 @@ def py_fallback(values: ArrayLike) -> ArrayLike:
11401137

11411138
# call our grouper again with only this block
11421139
if values.ndim == 1:
1140+
# We only get here with ExtensionArray
1141+
11431142
obj = Series(values)
11441143
else:
1144+
# We only get here with values.dtype == object
11451145
# TODO special case not needed with ArrayManager
11461146
obj = DataFrame(values.T)
11471147
if obj.shape[1] == 1:
@@ -1193,7 +1193,8 @@ def array_func(values: ArrayLike) -> ArrayLike:
11931193

11941194
result = py_fallback(values)
11951195

1196-
return cast_agg_result(result, values, how)
1196+
return cast_agg_result(result, values, how)
1197+
return result
11971198

11981199
# TypeError -> we may have an exception in trying to aggregate
11991200
# continue and exclude the block
@@ -1366,11 +1367,7 @@ def _wrap_applied_output_series(
13661367

13671368
# if we have date/time like in the original, then coerce dates
13681369
# as we are stacking can easily have object dtypes here
1369-
so = self._selected_obj
1370-
if so.ndim == 2 and so.dtypes.apply(needs_i8_conversion).any():
1371-
result = result._convert(datetime=True)
1372-
else:
1373-
result = result._convert(datetime=True)
1370+
result = result._convert(datetime=True)
13741371

13751372
if not self.as_index:
13761373
self._insert_inaxis_grouper_inplace(result)
@@ -1507,7 +1504,7 @@ def _choose_path(self, fast_path: Callable, slow_path: Callable, group: DataFram
15071504
try:
15081505
res_fast = fast_path(group)
15091506
except AssertionError:
1510-
raise
1507+
raise # pragma: no cover
15111508
except Exception:
15121509
# GH#29631 For user-defined function, we can't predict what may be
15131510
# raised; see test_transform.test_transform_fastpath_raises

pandas/core/groupby/groupby.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1033,7 +1033,7 @@ def _cython_transform(
10331033
result = self.grouper._cython_operation(
10341034
"transform", obj._values, how, axis, **kwargs
10351035
)
1036-
except NotImplementedError:
1036+
except (NotImplementedError, TypeError):
10371037
continue
10381038

10391039
key = base.OutputKey(label=name, position=idx)

pandas/core/groupby/ops.py

+21-24
Original file line numberDiff line numberDiff line change
@@ -136,23 +136,17 @@ def _get_cython_function(
136136

137137
# see if there is a fused-type version of function
138138
# only valid for numeric
139-
f = getattr(libgroupby, ftype, None)
140-
if f is not None:
141-
if is_numeric:
142-
return f
143-
elif dtype == object:
144-
if "object" not in f.__signatures__:
145-
# raise NotImplementedError here rather than TypeError later
146-
raise NotImplementedError(
147-
f"function is not implemented for this dtype: "
148-
f"[how->{how},dtype->{dtype_str}]"
149-
)
150-
return f
151-
152-
raise NotImplementedError(
153-
f"function is not implemented for this dtype: "
154-
f"[how->{how},dtype->{dtype_str}]"
155-
)
139+
f = getattr(libgroupby, ftype)
140+
if is_numeric:
141+
return f
142+
elif dtype == object:
143+
if "object" not in f.__signatures__:
144+
# raise NotImplementedError here rather than TypeError later
145+
raise NotImplementedError(
146+
f"function is not implemented for this dtype: "
147+
f"[how->{how},dtype->{dtype_str}]"
148+
)
149+
return f
156150

157151
def get_cython_func_and_vals(self, values: np.ndarray, is_numeric: bool):
158152
"""
@@ -208,22 +202,25 @@ def disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False):
208202
# never an invalid op for those dtypes, so return early as fastpath
209203
return
210204

211-
if is_categorical_dtype(dtype) or is_sparse(dtype):
205+
if is_categorical_dtype(dtype):
206+
# NotImplementedError for methods that can fall back to a
207+
# non-cython implementation.
208+
if how in ["add", "prod", "cumsum", "cumprod"]:
209+
raise TypeError(f"{dtype} type does not support {how} operations")
210+
raise NotImplementedError(f"{dtype} dtype not supported")
211+
212+
elif is_sparse(dtype):
212213
# categoricals are only 1d, so we
213214
# are not setup for dim transforming
214215
raise NotImplementedError(f"{dtype} dtype not supported")
215216
elif is_datetime64_any_dtype(dtype):
216217
# we raise NotImplemented if this is an invalid operation
217218
# entirely, e.g. adding datetimes
218219
if how in ["add", "prod", "cumsum", "cumprod"]:
219-
raise NotImplementedError(
220-
f"datetime64 type does not support {how} operations"
221-
)
220+
raise TypeError(f"datetime64 type does not support {how} operations")
222221
elif is_timedelta64_dtype(dtype):
223222
if how in ["prod", "cumprod"]:
224-
raise NotImplementedError(
225-
f"timedelta64 type does not support {how} operations"
226-
)
223+
raise TypeError(f"timedelta64 type does not support {how} operations")
227224

228225
def get_output_shape(self, ngroups: int, values: np.ndarray) -> Shape:
229226
how = self.how

0 commit comments

Comments
 (0)