Skip to content

Commit 3a973b6

Browse files
authored
TYP: fix ignores in core.groupby (pandas-dev#40643)
1 parent 7c71498 commit 3a973b6

File tree

3 files changed

+21
-26
lines changed

3 files changed

+21
-26
lines changed

pandas/core/groupby/generic.py

+8 -9
Original file line number | Diff line number | Diff line change
@@ -355,15 +355,15 @@ def _aggregate_multiple_funcs(self, arg):
355355
# TODO: index should not be Optional - see GH 35490
356356
def _wrap_series_output(
357357
self,
358-
output: Mapping[base.OutputKey, Union[Series, np.ndarray]],
358+
output: Mapping[base.OutputKey, Union[Series, ArrayLike]],
359359
index: Optional[Index],
360360
) -> FrameOrSeriesUnion:
361361
"""
362362
Wraps the output of a SeriesGroupBy operation into the expected result.
363363
364364
Parameters
365365
----------
366-
output : Mapping[base.OutputKey, Union[Series, np.ndarray]]
366+
output : Mapping[base.OutputKey, Union[Series, np.ndarray, ExtensionArray]]
367367
Data to wrap.
368368
index : pd.Index or None
369369
Index to apply to the output.
@@ -420,14 +420,14 @@ def _wrap_aggregated_output(
420420
return self._reindex_output(result)
421421

422422
def _wrap_transformed_output(
423-
self, output: Mapping[base.OutputKey, Union[Series, np.ndarray]]
423+
self, output: Mapping[base.OutputKey, Union[Series, ArrayLike]]
424424
) -> Series:
425425
"""
426426
Wraps the output of a SeriesGroupBy aggregation into the expected result.
427427
428428
Parameters
429429
----------
430-
output : dict[base.OutputKey, Union[Series, np.ndarray]]
430+
output : dict[base.OutputKey, Union[Series, np.ndarray, ExtensionArray]]
431431
Dict with a sole key of 0 and a value of the result values.
432432
433433
Returns
@@ -1119,6 +1119,7 @@ def cast_agg_result(result, values: ArrayLike, how: str) -> ArrayLike:
11191119

11201120
if isinstance(values, Categorical) and isinstance(result, np.ndarray):
11211121
# If the Categorical op didn't raise, it is dtype-preserving
1122+
# We get here with how="first", "last", "min", "max"
11221123
result = type(values)._from_sequence(result.ravel(), dtype=values.dtype)
11231124
# Note this will have result.dtype == dtype from above
11241125

@@ -1195,9 +1196,7 @@ def array_func(values: ArrayLike) -> ArrayLike:
11951196
assert how == "ohlc"
11961197
raise
11971198

1198-
# error: Incompatible types in assignment (expression has type
1199-
# "ExtensionArray", variable has type "ndarray")
1200-
result = py_fallback(values) # type: ignore[assignment]
1199+
result = py_fallback(values)
12011200

12021201
return cast_agg_result(result, values, how)
12031202

@@ -1753,14 +1752,14 @@ def _wrap_aggregated_output(
17531752
return self._reindex_output(result)
17541753

17551754
def _wrap_transformed_output(
1756-
self, output: Mapping[base.OutputKey, Union[Series, np.ndarray]]
1755+
self, output: Mapping[base.OutputKey, Union[Series, ArrayLike]]
17571756
) -> DataFrame:
17581757
"""
17591758
Wraps the output of DataFrameGroupBy transformations into the expected result.
17601759
17611760
Parameters
17621761
----------
1763-
output : Mapping[base.OutputKey, Union[Series, np.ndarray]]
1762+
output : Mapping[base.OutputKey, Union[Series, np.ndarray, ExtensionArray]]
17641763
Data to wrap.
17651764
17661765
Returns

pandas/core/groupby/groupby.py

+3 -4
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,6 @@ class providing the base-class of operations.
8484
import pandas.core.algorithms as algorithms
8585
from pandas.core.arrays import (
8686
Categorical,
87-
DatetimeArray,
8887
ExtensionArray,
8988
)
9089
from pandas.core.base import (
@@ -1026,7 +1025,7 @@ def _cumcount_array(self, ascending: bool = True):
10261025
def _cython_transform(
10271026
self, how: str, numeric_only: bool = True, axis: int = 0, **kwargs
10281027
):
1029-
output: Dict[base.OutputKey, np.ndarray] = {}
1028+
output: Dict[base.OutputKey, ArrayLike] = {}
10301029

10311030
for idx, obj in enumerate(self._iterate_slices()):
10321031
name = obj.name
@@ -1054,7 +1053,7 @@ def _wrap_aggregated_output(
10541053
):
10551054
raise AbstractMethodError(self)
10561055

1057-
def _wrap_transformed_output(self, output: Mapping[base.OutputKey, np.ndarray]):
1056+
def _wrap_transformed_output(self, output: Mapping[base.OutputKey, ArrayLike]):
10581057
raise AbstractMethodError(self)
10591058

10601059
def _wrap_applied_output(self, data, keys, values, not_indexed_same: bool = False):
@@ -1099,7 +1098,7 @@ def _agg_general(
10991098
def _cython_agg_general(
11001099
self, how: str, alt=None, numeric_only: bool = True, min_count: int = -1
11011100
):
1102-
output: Dict[base.OutputKey, Union[np.ndarray, DatetimeArray]] = {}
1101+
output: Dict[base.OutputKey, ArrayLike] = {}
11031102
# Ideally we would be able to enumerate self._iterate_slices and use
11041103
# the index from enumeration as the key of output, but ohlc in particular
11051104
# returns a (n x 4) array. Output requires 1D ndarrays as values, so we

pandas/core/groupby/ops.py

+10 -13
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131
import pandas._libs.groupby as libgroupby
3232
import pandas._libs.reduction as libreduction
3333
from pandas._typing import (
34+
ArrayLike,
3435
DtypeObj,
3536
F,
3637
FrameOrSeries,
@@ -524,7 +525,7 @@ def _disallow_invalid_ops(
524525
@final
525526
def _ea_wrap_cython_operation(
526527
self, kind: str, values, how: str, axis: int, min_count: int = -1, **kwargs
527-
) -> np.ndarray:
528+
) -> ArrayLike:
528529
"""
529530
If we have an ExtensionArray, unwrap, call _cython_operation, and
530531
re-wrap if appropriate.
@@ -576,7 +577,7 @@ def _ea_wrap_cython_operation(
576577
@final
577578
def _cython_operation(
578579
self, kind: str, values, how: str, axis: int, min_count: int = -1, **kwargs
579-
) -> np.ndarray:
580+
) -> ArrayLike:
580581
"""
581582
Returns the values of a cython operation.
582583
"""
@@ -683,11 +684,11 @@ def _cython_operation(
683684
# e.g. if we are int64 and need to restore to datetime64/timedelta64
684685
# "rank" is the only member of cython_cast_blocklist we get here
685686
dtype = maybe_cast_result_dtype(orig_values.dtype, how)
686-
# error: Incompatible types in assignment (expression has type
687-
# "Union[ExtensionArray, ndarray]", variable has type "ndarray")
688-
result = maybe_downcast_to_dtype(result, dtype) # type: ignore[assignment]
687+
op_result = maybe_downcast_to_dtype(result, dtype)
688+
else:
689+
op_result = result
689690

690-
return result
691+
return op_result
691692

692693
@final
693694
def _aggregate(
@@ -784,14 +785,10 @@ def _aggregate_series_pure_python(self, obj: Series, func: F):
784785
counts[label] = group.shape[0]
785786
result[label] = res
786787

787-
result = lib.maybe_convert_objects(result, try_float=False)
788-
# error: Incompatible types in assignment (expression has type
789-
# "Union[ExtensionArray, ndarray]", variable has type "ndarray")
790-
result = maybe_cast_result( # type: ignore[assignment]
791-
result, obj, numeric_only=True
792-
)
788+
out = lib.maybe_convert_objects(result, try_float=False)
789+
out = maybe_cast_result(out, obj, numeric_only=True)
793790

794-
return result, counts
791+
return out, counts
795792

796793

797794
class BinGrouper(BaseGrouper):

0 commit comments

Comments
 (0)