72
72
PeriodArray ,
73
73
TimedeltaArray ,
74
74
)
75
- from pandas .core .arrays .boolean import BooleanDtype
76
- from pandas .core .arrays .floating import FloatingDtype
77
- from pandas .core .arrays .integer import IntegerDtype
78
75
from pandas .core .arrays .masked import (
79
76
BaseMaskedArray ,
80
77
BaseMaskedDtype ,
@@ -147,26 +144,6 @@ def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
147
144
},
148
145
}
149
146
150
- # "group_any" and "group_all" are also support masks, but don't go
151
- # through WrappedCythonOp
152
- _MASKED_CYTHON_FUNCTIONS = {
153
- "cummin" ,
154
- "cummax" ,
155
- "min" ,
156
- "max" ,
157
- "last" ,
158
- "first" ,
159
- "rank" ,
160
- "sum" ,
161
- "ohlc" ,
162
- "cumprod" ,
163
- "cumsum" ,
164
- "prod" ,
165
- "mean" ,
166
- "var" ,
167
- "median" ,
168
- }
169
-
170
147
_cython_arity = {"ohlc" : 4 } # OHLC
171
148
172
149
# Note: we make this a classmethod and pass kind+how so that caching
@@ -220,8 +197,8 @@ def _get_cython_vals(self, values: np.ndarray) -> np.ndarray:
220
197
"""
221
198
how = self .how
222
199
223
- if how in [ "median" ] :
224
- # these two only have float64 implementations
200
+ if how == "median" :
201
+ # median only has a float64 implementation
225
202
# We should only get here with is_numeric, as non-numeric cases
226
203
# should raise in _get_cython_function
227
204
values = ensure_float64 (values )
@@ -293,7 +270,7 @@ def _get_output_shape(self, ngroups: int, values: np.ndarray) -> Shape:
293
270
294
271
out_shape : Shape
295
272
if how == "ohlc" :
296
- out_shape = (ngroups , 4 )
273
+ out_shape = (ngroups , arity )
297
274
elif arity > 1 :
298
275
raise NotImplementedError (
299
276
"arity of more than 1 is not supported for the 'how' argument"
@@ -342,9 +319,6 @@ def _get_result_dtype(self, dtype: np.dtype) -> np.dtype:
342
319
return np .dtype (np .float64 )
343
320
return dtype
344
321
345
- def uses_mask (self ) -> bool :
346
- return self .how in self ._MASKED_CYTHON_FUNCTIONS
347
-
348
322
@final
349
323
def _ea_wrap_cython_operation (
350
324
self ,
@@ -358,7 +332,7 @@ def _ea_wrap_cython_operation(
358
332
If we have an ExtensionArray, unwrap, call _cython_operation, and
359
333
re-wrap if appropriate.
360
334
"""
361
- if isinstance (values , BaseMaskedArray ) and self . uses_mask () :
335
+ if isinstance (values , BaseMaskedArray ):
362
336
return self ._masked_ea_wrap_cython_operation (
363
337
values ,
364
338
min_count = min_count ,
@@ -367,7 +341,7 @@ def _ea_wrap_cython_operation(
367
341
** kwargs ,
368
342
)
369
343
370
- elif isinstance (values , Categorical ) and self . uses_mask () :
344
+ elif isinstance (values , Categorical ):
371
345
assert self .how == "rank" # the only one implemented ATM
372
346
assert values .ordered # checked earlier
373
347
mask = values .isna ()
@@ -398,7 +372,7 @@ def _ea_wrap_cython_operation(
398
372
)
399
373
400
374
if self .how in self .cast_blocklist :
401
- # i.e. how in ["rank"], since other cast_blocklist methods dont go
375
+ # i.e. how in ["rank"], since other cast_blocklist methods don't go
402
376
# through cython_operation
403
377
return res_values
404
378
@@ -411,12 +385,6 @@ def _ea_to_cython_values(self, values: ExtensionArray) -> np.ndarray:
411
385
# All of the functions implemented here are ordinal, so we can
412
386
# operate on the tz-naive equivalents
413
387
npvalues = values ._ndarray .view ("M8[ns]" )
414
- elif isinstance (values .dtype , (BooleanDtype , IntegerDtype )):
415
- # IntegerArray or BooleanArray
416
- npvalues = values .to_numpy ("float64" , na_value = np .nan )
417
- elif isinstance (values .dtype , FloatingDtype ):
418
- # FloatingArray
419
- npvalues = values .to_numpy (values .dtype .numpy_dtype , na_value = np .nan )
420
388
elif isinstance (values .dtype , StringDtype ):
421
389
# StringArray
422
390
npvalues = values .to_numpy (object , na_value = np .nan )
@@ -440,12 +408,6 @@ def _reconstruct_ea_result(
440
408
string_array_cls = dtype .construct_array_type ()
441
409
return string_array_cls ._from_sequence (res_values , dtype = dtype )
442
410
443
- elif isinstance (values .dtype , BaseMaskedDtype ):
444
- new_dtype = self ._get_result_dtype (values .dtype .numpy_dtype )
445
- dtype = BaseMaskedDtype .from_numpy_dtype (new_dtype )
446
- masked_array_cls = dtype .construct_array_type ()
447
- return masked_array_cls ._from_sequence (res_values , dtype = dtype )
448
-
449
411
elif isinstance (values , (DatetimeArray , TimedeltaArray , PeriodArray )):
450
412
# In to_cython_values we took a view as M8[ns]
451
413
assert res_values .dtype == "M8[ns]"
@@ -489,7 +451,8 @@ def _masked_ea_wrap_cython_operation(
489
451
)
490
452
491
453
if self .how == "ohlc" :
492
- result_mask = np .tile (result_mask , (4 , 1 )).T
454
+ arity = self ._cython_arity .get (self .how , 1 )
455
+ result_mask = np .tile (result_mask , (arity , 1 )).T
493
456
494
457
# res_values should already have the correct dtype, we just need to
495
458
# wrap in a MaskedArray
@@ -580,7 +543,7 @@ def _call_cython_op(
580
543
result = maybe_fill (np .empty (out_shape , dtype = out_dtype ))
581
544
if self .kind == "aggregate" :
582
545
counts = np .zeros (ngroups , dtype = np .int64 )
583
- if self .how in ["min" , "max" , "mean" , "last" , "first" ]:
546
+ if self .how in ["min" , "max" , "mean" , "last" , "first" , "sum" ]:
584
547
func (
585
548
out = result ,
586
549
counts = counts ,
@@ -591,18 +554,6 @@ def _call_cython_op(
591
554
result_mask = result_mask ,
592
555
is_datetimelike = is_datetimelike ,
593
556
)
594
- elif self .how in ["sum" ]:
595
- # We support datetimelike
596
- func (
597
- out = result ,
598
- counts = counts ,
599
- values = values ,
600
- labels = comp_ids ,
601
- mask = mask ,
602
- result_mask = result_mask ,
603
- min_count = min_count ,
604
- is_datetimelike = is_datetimelike ,
605
- )
606
557
elif self .how in ["var" , "ohlc" , "prod" , "median" ]:
607
558
func (
608
559
result ,
@@ -615,31 +566,21 @@ def _call_cython_op(
615
566
** kwargs ,
616
567
)
617
568
else :
618
- func ( result , counts , values , comp_ids , min_count )
569
+ raise NotImplementedError ( f" { self . how } is not implemented" )
619
570
else :
620
571
# TODO: min_count
621
- if self .uses_mask ():
622
- if self .how != "rank" :
623
- # TODO: should rank take result_mask?
624
- kwargs ["result_mask" ] = result_mask
625
- func (
626
- out = result ,
627
- values = values ,
628
- labels = comp_ids ,
629
- ngroups = ngroups ,
630
- is_datetimelike = is_datetimelike ,
631
- mask = mask ,
632
- ** kwargs ,
633
- )
634
- else :
635
- func (
636
- out = result ,
637
- values = values ,
638
- labels = comp_ids ,
639
- ngroups = ngroups ,
640
- is_datetimelike = is_datetimelike ,
641
- ** kwargs ,
642
- )
572
+ if self .how != "rank" :
573
+ # TODO: should rank take result_mask?
574
+ kwargs ["result_mask" ] = result_mask
575
+ func (
576
+ out = result ,
577
+ values = values ,
578
+ labels = comp_ids ,
579
+ ngroups = ngroups ,
580
+ is_datetimelike = is_datetimelike ,
581
+ mask = mask ,
582
+ ** kwargs ,
583
+ )
643
584
644
585
if self .kind == "aggregate" :
645
586
# i.e. counts is defined. Locations where count<min_count
@@ -650,7 +591,7 @@ def _call_cython_op(
650
591
cutoff = max (0 if self .how in ["sum" , "prod" ] else 1 , min_count )
651
592
empty_groups = counts < cutoff
652
593
if empty_groups .any ():
653
- if result_mask is not None and self . uses_mask () :
594
+ if result_mask is not None :
654
595
assert result_mask [empty_groups ].all ()
655
596
else :
656
597
# Note: this conversion could be lossy, see GH#40767
0 commit comments