@@ -341,95 +341,54 @@ def _ea_wrap_cython_operation(
341
341
comp_ids = comp_ids ,
342
342
** kwargs ,
343
343
)
344
- orig_values = values
345
344
346
- if isinstance (orig_values , (DatetimeArray , PeriodArray )):
345
+ if isinstance (values , (DatetimeArray , PeriodArray , TimedeltaArray )):
347
346
# All of the functions implemented here are ordinal, so we can
348
347
# operate on the tz-naive equivalents
349
- npvalues = orig_values ._ndarray .view ("M8[ns]" )
350
- res_values = self ._cython_op_ndim_compat (
351
- npvalues ,
352
- min_count = min_count ,
353
- ngroups = ngroups ,
354
- comp_ids = comp_ids ,
355
- mask = None ,
356
- ** kwargs ,
357
- )
358
- if self .how in ["rank" ]:
359
- # i.e. how in WrappedCythonOp.cast_blocklist, since
360
- # other cast_blocklist methods dont go through cython_operation
361
- # preserve float64 dtype
362
- return res_values
363
-
364
- res_values = res_values .view ("i8" )
365
- result = type (orig_values )(res_values , dtype = orig_values .dtype )
366
- return result
367
-
368
- elif isinstance (orig_values , TimedeltaArray ):
369
- # We have an ExtensionArray but not ExtensionDtype
370
- res_values = self ._cython_op_ndim_compat (
371
- orig_values ._ndarray ,
372
- min_count = min_count ,
373
- ngroups = ngroups ,
374
- comp_ids = comp_ids ,
375
- mask = None ,
376
- ** kwargs ,
377
- )
378
- if self .how in ["rank" ]:
379
- # i.e. how in WrappedCythonOp.cast_blocklist, since
380
- # other cast_blocklist methods dont go through cython_operation
381
- # preserve float64 dtype
382
- return res_values
383
-
384
- # otherwise res_values has the same dtype as original values
385
- return type (orig_values )(res_values )
386
-
348
+ npvalues = values ._ndarray .view ("M8[ns]" )
387
349
elif isinstance (values .dtype , (BooleanDtype , _IntegerDtype )):
388
350
# IntegerArray or BooleanArray
389
351
npvalues = values .to_numpy ("float64" , na_value = np .nan )
390
- res_values = self ._cython_op_ndim_compat (
391
- npvalues ,
392
- min_count = min_count ,
393
- ngroups = ngroups ,
394
- comp_ids = comp_ids ,
395
- mask = None ,
396
- ** kwargs ,
397
- )
398
- if self .how in ["rank" ]:
399
- # i.e. how in WrappedCythonOp.cast_blocklist, since
400
- # other cast_blocklist methods dont go through cython_operation
401
- return res_values
402
-
403
- dtype = self ._get_result_dtype (orig_values .dtype )
404
- cls = dtype .construct_array_type ()
405
- return cls ._from_sequence (res_values , dtype = dtype )
406
-
407
352
elif isinstance (values .dtype , FloatingDtype ):
408
353
# FloatingArray
409
- npvalues = values .to_numpy (
410
- values .dtype .numpy_dtype ,
411
- na_value = np .nan ,
412
- )
413
- res_values = self ._cython_op_ndim_compat (
414
- npvalues ,
415
- min_count = min_count ,
416
- ngroups = ngroups ,
417
- comp_ids = comp_ids ,
418
- mask = None ,
419
- ** kwargs ,
354
+ npvalues = values .to_numpy (values .dtype .numpy_dtype , na_value = np .nan )
355
+ else :
356
+ raise NotImplementedError (
357
+ f"function is not implemented for this dtype: { values .dtype } "
420
358
)
421
- if self .how in ["rank" ]:
422
- # i.e. how in WrappedCythonOp.cast_blocklist, since
423
- # other cast_blocklist methods dont go through cython_operation
424
- return res_values
425
359
426
- dtype = self ._get_result_dtype (orig_values .dtype )
360
+ res_values = self ._cython_op_ndim_compat (
361
+ npvalues ,
362
+ min_count = min_count ,
363
+ ngroups = ngroups ,
364
+ comp_ids = comp_ids ,
365
+ mask = None ,
366
+ ** kwargs ,
367
+ )
368
+
369
+ if self .how in ["rank" ]:
370
+ # i.e. how in WrappedCythonOp.cast_blocklist, since
371
+ # other cast_blocklist methods dont go through cython_operation
372
+ return res_values
373
+
374
+ return self ._reconstruct_ea_result (values , res_values )
375
+
376
+ def _reconstruct_ea_result (self , values , res_values ):
377
+ """
378
+ Construct an ExtensionArray result from an ndarray result.
379
+ """
380
+ # TODO: allow EAs to override this logic
381
+
382
+ if isinstance (values .dtype , (BooleanDtype , _IntegerDtype , FloatingDtype )):
383
+ dtype = self ._get_result_dtype (values .dtype )
427
384
cls = dtype .construct_array_type ()
428
385
return cls ._from_sequence (res_values , dtype = dtype )
429
386
430
- raise NotImplementedError (
431
- f"function is not implemented for this dtype: { values .dtype } "
432
- )
387
+ elif needs_i8_conversion (values .dtype ):
388
+ i8values = res_values .view ("i8" )
389
+ return type (values )(i8values , dtype = values .dtype )
390
+
391
+ raise NotImplementedError
433
392
434
393
@final
435
394
def _masked_ea_wrap_cython_operation (
@@ -478,6 +437,8 @@ def _cython_op_ndim_compat(
478
437
if values .ndim == 1 :
479
438
# expand to 2d, dispatch, then squeeze if appropriate
480
439
values2d = values [None , :]
440
+ if mask is not None :
441
+ mask = mask [None , :]
481
442
res = self ._call_cython_op (
482
443
values2d ,
483
444
min_count = min_count ,
@@ -533,9 +494,8 @@ def _call_cython_op(
533
494
values = ensure_float64 (values )
534
495
535
496
values = values .T
536
-
537
497
if mask is not None :
538
- mask = mask .reshape ( values . shape , order = "C" )
498
+ mask = mask .T
539
499
540
500
out_shape = self ._get_output_shape (ngroups , values )
541
501
func , values = self .get_cython_func_and_vals (values , is_numeric )
0 commit comments