@@ -123,14 +123,14 @@ def __init__(self, kind: str, how: str):
123
123
"min" : "group_min" ,
124
124
"max" : "group_max" ,
125
125
"mean" : "group_mean" ,
126
- "median" : "group_median " ,
126
+ "median" : "group_median_float64 " ,
127
127
"var" : "group_var" ,
128
128
"first" : "group_nth" ,
129
129
"last" : "group_last" ,
130
130
"ohlc" : "group_ohlc" ,
131
131
},
132
132
"transform" : {
133
- "cumprod" : "group_cumprod " ,
133
+ "cumprod" : "group_cumprod_float64 " ,
134
134
"cumsum" : "group_cumsum" ,
135
135
"cummin" : "group_cummin" ,
136
136
"cummax" : "group_cummax" ,
@@ -161,52 +161,54 @@ def _get_cython_function(
161
161
if is_numeric :
162
162
return f
163
163
elif dtype == object :
164
- if "object" not in f .__signatures__ :
164
+ if how in ["median" , "cumprod" ]:
165
+ # no fused types -> no __signatures__
166
+ raise NotImplementedError (
167
+ f"function is not implemented for this dtype: "
168
+ f"[how->{ how } ,dtype->{ dtype_str } ]"
169
+ )
170
+ elif "object" not in f .__signatures__ :
165
171
# raise NotImplementedError here rather than TypeError later
166
172
raise NotImplementedError (
167
173
f"function is not implemented for this dtype: "
168
174
f"[how->{ how } ,dtype->{ dtype_str } ]"
169
175
)
170
176
return f
177
+ else :
178
+ raise NotImplementedError (
179
+ "This should not be reached. Please report a bug at "
180
+ "github.com/pandas-dev/pandas/" ,
181
+ dtype ,
182
+ )
171
183
172
- def get_cython_func_and_vals (self , values : np .ndarray , is_numeric : bool ) :
184
+ def _get_cython_vals (self , values : np .ndarray ) -> np . ndarray :
173
185
"""
174
- Find the appropriate cython function, casting if necessary .
186
+ Cast numeric dtypes to float64 for functions that only support that .
175
187
176
188
Parameters
177
189
----------
178
190
values : np.ndarray
179
- is_numeric : bool
180
191
181
192
Returns
182
193
-------
183
- func : callable
184
194
values : np.ndarray
185
195
"""
186
196
how = self .how
187
- kind = self .kind
188
197
189
198
if how in ["median" , "cumprod" ]:
190
199
# these two only have float64 implementations
191
- if is_numeric :
192
- values = ensure_float64 (values )
193
- else :
194
- raise NotImplementedError (
195
- f"function is not implemented for this dtype: "
196
- f"[how->{ how } ,dtype->{ values .dtype .name } ]"
197
- )
198
- func = getattr (libgroupby , f"group_{ how } _float64" )
199
- return func , values
200
-
201
- func = self ._get_cython_function (kind , how , values .dtype , is_numeric )
200
+ # We should only get here with is_numeric, as non-numeric cases
201
+ # should raise in _get_cython_function
202
+ values = ensure_float64 (values )
202
203
203
- if values .dtype .kind in ["i" , "u" ]:
204
+ elif values .dtype .kind in ["i" , "u" ]:
204
205
if how in ["add" , "var" , "prod" , "mean" , "ohlc" ]:
205
206
# result may still include NaN, so we have to cast
206
207
values = ensure_float64 (values )
207
208
208
- return func , values
209
+ return values
209
210
211
+ # TODO: general case implementation overridable by EAs.
210
212
def _disallow_invalid_ops (self , dtype : DtypeObj , is_numeric : bool = False ):
211
213
"""
212
214
Check if we can do this operation with our cython functions.
@@ -235,6 +237,7 @@ def _disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False):
235
237
# are not setup for dim transforming
236
238
raise NotImplementedError (f"{ dtype } dtype not supported" )
237
239
elif is_datetime64_any_dtype (dtype ):
240
+ # TODO: same for period_dtype? no for these methods with Period
238
241
# we raise NotImplemented if this is an invalid operation
239
242
# entirely, e.g. adding datetimes
240
243
if how in ["add" , "prod" , "cumsum" , "cumprod" ]:
@@ -262,7 +265,7 @@ def _get_output_shape(self, ngroups: int, values: np.ndarray) -> Shape:
262
265
out_shape = (ngroups ,) + values .shape [1 :]
263
266
return out_shape
264
267
265
- def get_out_dtype (self , dtype : np .dtype ) -> np .dtype :
268
+ def _get_out_dtype (self , dtype : np .dtype ) -> np .dtype :
266
269
how = self .how
267
270
268
271
if how == "rank" :
@@ -282,6 +285,7 @@ def _get_result_dtype(self, dtype: np.dtype) -> np.dtype:
282
285
def _get_result_dtype (self , dtype : ExtensionDtype ) -> ExtensionDtype :
283
286
... # pragma: no cover
284
287
288
+ # TODO: general case implementation overridable by EAs.
285
289
def _get_result_dtype (self , dtype : DtypeObj ) -> DtypeObj :
286
290
"""
287
291
Get the desired dtype of a result based on the
@@ -329,7 +333,6 @@ def _ea_wrap_cython_operation(
329
333
If we have an ExtensionArray, unwrap, call _cython_operation, and
330
334
re-wrap if appropriate.
331
335
"""
332
- # TODO: general case implementation overridable by EAs.
333
336
if isinstance (values , BaseMaskedArray ) and self .uses_mask ():
334
337
return self ._masked_ea_wrap_cython_operation (
335
338
values ,
@@ -357,7 +360,8 @@ def _ea_wrap_cython_operation(
357
360
358
361
return self ._reconstruct_ea_result (values , res_values )
359
362
360
- def _ea_to_cython_values (self , values : ExtensionArray ):
363
+ # TODO: general case implementation overridable by EAs.
364
+ def _ea_to_cython_values (self , values : ExtensionArray ) -> np .ndarray :
361
365
# GH#43682
362
366
if isinstance (values , (DatetimeArray , PeriodArray , TimedeltaArray )):
363
367
# All of the functions implemented here are ordinal, so we can
@@ -378,23 +382,24 @@ def _ea_to_cython_values(self, values: ExtensionArray):
378
382
)
379
383
return npvalues
380
384
381
- def _reconstruct_ea_result (self , values , res_values ):
385
+ # TODO: general case implementation overridable by EAs.
386
+ def _reconstruct_ea_result (
387
+ self , values : ExtensionArray , res_values : np .ndarray
388
+ ) -> ExtensionArray :
382
389
"""
383
390
Construct an ExtensionArray result from an ndarray result.
384
391
"""
385
- # TODO: allow EAs to override this logic
386
392
387
- if isinstance (
388
- values .dtype , (BooleanDtype , IntegerDtype , FloatingDtype , StringDtype )
389
- ):
393
+ if isinstance (values .dtype , (BaseMaskedDtype , StringDtype )):
390
394
dtype = self ._get_result_dtype (values .dtype )
391
395
cls = dtype .construct_array_type ()
392
396
return cls ._from_sequence (res_values , dtype = dtype )
393
397
394
398
elif needs_i8_conversion (values .dtype ):
395
399
assert res_values .dtype .kind != "f" # just to be on the safe side
396
400
i8values = res_values .view ("i8" )
397
- return type (values )(i8values , dtype = values .dtype )
401
+ # error: Too many arguments for "ExtensionArray"
402
+ return type (values )(i8values , dtype = values .dtype ) # type: ignore[call-arg]
398
403
399
404
raise NotImplementedError
400
405
@@ -429,13 +434,16 @@ def _masked_ea_wrap_cython_operation(
429
434
)
430
435
431
436
dtype = self ._get_result_dtype (orig_values .dtype )
432
- assert isinstance (dtype , BaseMaskedDtype )
433
- cls = dtype .construct_array_type ()
437
+ # TODO: avoid cast as res_values *should* already have the right
438
+ # dtype; last attempt ran into trouble on 32bit linux build
439
+ res_values = res_values .astype (dtype .type , copy = False )
434
440
435
441
if self .kind != "aggregate" :
436
- return cls ( res_values . astype ( dtype . type , copy = False ), mask )
442
+ out_mask = mask
437
443
else :
438
- return cls (res_values .astype (dtype .type , copy = False ), result_mask )
444
+ out_mask = result_mask
445
+
446
+ return orig_values ._maybe_mask_result (res_values , out_mask )
439
447
440
448
@final
441
449
def _cython_op_ndim_compat (
@@ -521,8 +529,9 @@ def _call_cython_op(
521
529
result_mask = result_mask .T
522
530
523
531
out_shape = self ._get_output_shape (ngroups , values )
524
- func , values = self .get_cython_func_and_vals (values , is_numeric )
525
- out_dtype = self .get_out_dtype (values .dtype )
532
+ func = self ._get_cython_function (self .kind , self .how , values .dtype , is_numeric )
533
+ values = self ._get_cython_vals (values )
534
+ out_dtype = self ._get_out_dtype (values .dtype )
526
535
527
536
result = maybe_fill (np .empty (out_shape , dtype = out_dtype ))
528
537
if self .kind == "aggregate" :
0 commit comments