Skip to content

Commit e3c7f66

Browse files
jbrockmendelyehoshuadimarsky
authored andcommitted
REF: implement _reconstruct_ea_result, _ea_to_cython_values (pandas-dev#46172)
1 parent 25382fe commit e3c7f66

File tree

2 files changed

+36
-19
lines changed

2 files changed

+36
-19
lines changed

pandas/_libs/groupby.pyx

+8
Original file line numberDiff line numberDiff line change
@@ -1350,6 +1350,10 @@ cdef group_min_max(iu_64_floating_t[:, ::1] out,
13501350
else:
13511351
if uses_mask:
13521352
result_mask[i, j] = True
1353+
# set out[i, j] to 0 to be deterministic, as
1354+
# it was initialized with np.empty. Also ensures
1355+
# we can downcast out if appropriate.
1356+
out[i, j] = 0
13531357
else:
13541358
out[i, j] = nan_val
13551359
else:
@@ -1494,6 +1498,10 @@ cdef group_cummin_max(iu_64_floating_t[:, ::1] out,
14941498
if not skipna and na_possible and seen_na[lab, j]:
14951499
if uses_mask:
14961500
mask[i, j] = 1 # FIXME: shouldn't alter inplace
1501+
# Set to 0 ensures that we are deterministic and can
1502+
# downcast if appropriate
1503+
out[i, j] = 0
1504+
14971505
else:
14981506
out[i, j] = na_val
14991507
else:

pandas/core/groupby/ops.py

+28-19
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,26 @@ def _ea_wrap_cython_operation(
339339
**kwargs,
340340
)
341341

342+
npvalues = self._ea_to_cython_values(values)
343+
344+
res_values = self._cython_op_ndim_compat(
345+
npvalues,
346+
min_count=min_count,
347+
ngroups=ngroups,
348+
comp_ids=comp_ids,
349+
mask=None,
350+
**kwargs,
351+
)
352+
353+
if self.how in ["rank"]:
354+
# i.e. how in WrappedCythonOp.cast_blocklist, since
355+
# other cast_blocklist methods dont go through cython_operation
356+
return res_values
357+
358+
return self._reconstruct_ea_result(values, res_values)
359+
360+
def _ea_to_cython_values(self, values: ExtensionArray):
361+
# GH#43682
342362
if isinstance(values, (DatetimeArray, PeriodArray, TimedeltaArray)):
343363
# All of the functions implemented here are ordinal, so we can
344364
# operate on the tz-naive equivalents
@@ -356,22 +376,7 @@ def _ea_wrap_cython_operation(
356376
raise NotImplementedError(
357377
f"function is not implemented for this dtype: {values.dtype}"
358378
)
359-
360-
res_values = self._cython_op_ndim_compat(
361-
npvalues,
362-
min_count=min_count,
363-
ngroups=ngroups,
364-
comp_ids=comp_ids,
365-
mask=None,
366-
**kwargs,
367-
)
368-
369-
if self.how in ["rank"]:
370-
# i.e. how in WrappedCythonOp.cast_blocklist, since
371-
# other cast_blocklist methods dont go through cython_operation
372-
return res_values
373-
374-
return self._reconstruct_ea_result(values, res_values)
379+
return npvalues
375380

376381
def _reconstruct_ea_result(self, values, res_values):
377382
"""
@@ -387,6 +392,7 @@ def _reconstruct_ea_result(self, values, res_values):
387392
return cls._from_sequence(res_values, dtype=dtype)
388393

389394
elif needs_i8_conversion(values.dtype):
395+
assert res_values.dtype.kind != "f" # just to be on the safe side
390396
i8values = res_values.view("i8")
391397
return type(values)(i8values, dtype=values.dtype)
392398

@@ -577,9 +583,12 @@ def _call_cython_op(
577583
cutoff = max(1, min_count)
578584
empty_groups = counts < cutoff
579585
if empty_groups.any():
580-
# Note: this conversion could be lossy, see GH#40767
581-
result = result.astype("float64")
582-
result[empty_groups] = np.nan
586+
if result_mask is not None and self.uses_mask():
587+
assert result_mask[empty_groups].all()
588+
else:
589+
# Note: this conversion could be lossy, see GH#40767
590+
result = result.astype("float64")
591+
result[empty_groups] = np.nan
583592

584593
result = result.T
585594

0 commit comments

Comments
 (0)