Skip to content

Commit 8ba7d2f

Browse files
jbrockmendelyehoshuadimarsky
authored andcommitted
REF: simplify groupby nullable wrapping (pandas-dev#46236)
1 parent bb37493 commit 8ba7d2f

File tree

2 files changed

+12
-16
lines changed

2 files changed

+12
-16
lines changed

pandas/_libs/groupby.pyx

+1
Original file line numberDiff line numberDiff line change
@@ -1238,6 +1238,7 @@ def group_nth(
12381238
if nobs[i, j] < min_count:
12391239
if uses_mask:
12401240
result_mask[i, j] = True
1241+
out[i, j] = 0
12411242
elif iu_64_floating_obj_t is int64_t:
12421243
# TODO: only if datetimelike?
12431244
out[i, j] = NPY_NAT

pandas/core/groupby/ops.py

+11-16
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ def _reconstruct_ea_result(
367367
"""
368368
Construct an ExtensionArray result from an ndarray result.
369369
"""
370+
dtype: BaseMaskedDtype | StringDtype
370371

371372
if isinstance(values.dtype, StringDtype):
372373
dtype = values.dtype
@@ -375,19 +376,17 @@ def _reconstruct_ea_result(
375376

376377
elif isinstance(values.dtype, BaseMaskedDtype):
377378
new_dtype = self._get_result_dtype(values.dtype.numpy_dtype)
379+
dtype = BaseMaskedDtype.from_numpy_dtype(new_dtype)
378380
# error: Incompatible types in assignment (expression has type
379-
# "BaseMaskedDtype", variable has type "StringDtype")
380-
dtype = BaseMaskedDtype.from_numpy_dtype( # type: ignore[assignment]
381-
new_dtype
382-
)
383-
cls = dtype.construct_array_type()
381+
# "Type[BaseMaskedArray]", variable has type "Type[BaseStringArray]")
382+
cls = dtype.construct_array_type() # type: ignore[assignment]
384383
return cls._from_sequence(res_values, dtype=dtype)
385384

386-
elif needs_i8_conversion(values.dtype):
387-
assert res_values.dtype.kind != "f" # just to be on the safe side
388-
i8values = res_values.view("i8")
389-
# error: Too many arguments for "ExtensionArray"
390-
return type(values)(i8values, dtype=values.dtype) # type: ignore[call-arg]
385+
elif isinstance(values, (DatetimeArray, TimedeltaArray, PeriodArray)):
386+
# In to_cython_values we took a view as M8[ns]
387+
assert res_values.dtype == "M8[ns]"
388+
res_values = res_values.view(values._ndarray.dtype)
389+
return values._from_backing_data(res_values)
391390

392391
raise NotImplementedError
393392

@@ -425,12 +424,8 @@ def _masked_ea_wrap_cython_operation(
425424
**kwargs,
426425
)
427426

428-
new_dtype = self._get_result_dtype(orig_values.dtype.numpy_dtype)
429-
dtype = BaseMaskedDtype.from_numpy_dtype(new_dtype)
430-
# TODO: avoid cast as res_values *should* already have the right
431-
# dtype; last attempt ran into trouble on 32bit linux build
432-
res_values = res_values.astype(dtype.type, copy=False)
433-
427+
# res_values should already have the correct dtype, we just need to
428+
# wrap in a MaskedArray
434429
return orig_values._maybe_mask_result(res_values, result_mask)
435430

436431
@final

0 commit comments

Comments
 (0)