Skip to content

Commit d9e2fb5

Browse files
authored
ENH: Add support for groupby.ohlc for ea dtypes (#48081)
* ENH: Add support for groupby.ohlc for ea dtypes * Fix type
1 parent e0cadc5 commit d9e2fb5

File tree

5 files changed

+65
-10
lines changed

5 files changed

+65
-10
lines changed

doc/source/whatsnew/v1.5.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,7 @@ Other enhancements
287287
- ``times`` argument in :class:`.ExponentialMovingWindow` now accepts ``np.timedelta64`` (:issue:`47003`)
288288
- :class:`.DataError`, :class:`.SpecificationError`, :class:`.SettingWithCopyError`, :class:`.SettingWithCopyWarning`, :class:`.NumExprClobberingError`, :class:`.UndefinedVariableError`, :class:`.IndexingError`, :class:`.PyperclipException`, :class:`.PyperclipWindowsException`, :class:`.CSSWarning`, :class:`.PossibleDataLossError`, :class:`.ClosedFileError`, :class:`.IncompatibilityWarning`, :class:`.AttributeConflictWarning`, :class:`.DatabaseError`, :class:`.PossiblePrecisionLoss`, :class:`.ValueLabelTypeMismatch`, :class:`.InvalidColumnName`, and :class:`.CategoricalConversionWarning` are now exposed in ``pandas.errors`` (:issue:`27656`)
289289
- Added ``check_like`` argument to :func:`testing.assert_series_equal` (:issue:`47247`)
290+
- Add support for :meth:`GroupBy.ohlc` for extension array dtypes (:issue:`37493`)
290291
- Allow reading compressed SAS files with :func:`read_sas` (e.g., ``.sas7bdat.gz`` files)
291292
- :meth:`DatetimeIndex.astype` now supports casting timezone-naive indexes to ``datetime64[s]``, ``datetime64[ms]``, and ``datetime64[us]``, and timezone-aware indexes to the corresponding ``datetime64[unit, tzname]`` dtypes (:issue:`47579`)
292293
- :class:`Series` reducers (e.g. ``min``, ``max``, ``sum``, ``mean``) will now successfully operate when the dtype is numeric and ``numeric_only=True`` is provided; previously this would raise a ``NotImplementedError`` (:issue:`47500`)

pandas/_libs/groupby.pyi

+4-2
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,13 @@ def group_mean(
8686
result_mask: np.ndarray | None = ...,
8787
) -> None: ...
8888
def group_ohlc(
89-
out: np.ndarray, # floating[:, ::1]
89+
out: np.ndarray, # int64float_t[:, ::1]
9090
counts: np.ndarray, # int64_t[::1]
91-
values: np.ndarray, # ndarray[floating, ndim=2]
91+
values: np.ndarray, # ndarray[int64float_t, ndim=2]
9292
labels: np.ndarray, # const intp_t[:]
9393
min_count: int = ...,
94+
mask: np.ndarray | None = ...,
95+
result_mask: np.ndarray | None = ...,
9496
) -> None: ...
9597
def group_quantile(
9698
out: npt.NDArray[np.float64],

pandas/_libs/groupby.pyx

+36-6
Original file line numberDiff line numberDiff line change
@@ -834,21 +834,32 @@ def group_mean(
834834
out[i, j] = sumx[i, j] / count
835835

836836

837+
ctypedef fused int64float_t:
838+
float32_t
839+
float64_t
840+
int64_t
841+
uint64_t
842+
843+
837844
@cython.wraparound(False)
838845
@cython.boundscheck(False)
839846
def group_ohlc(
840-
floating[:, ::1] out,
847+
int64float_t[:, ::1] out,
841848
int64_t[::1] counts,
842-
ndarray[floating, ndim=2] values,
849+
ndarray[int64float_t, ndim=2] values,
843850
const intp_t[::1] labels,
844851
Py_ssize_t min_count=-1,
852+
const uint8_t[:, ::1] mask=None,
853+
uint8_t[:, ::1] result_mask=None,
845854
) -> None:
846855
"""
847856
Only aggregates on axis=0
848857
"""
849858
cdef:
850859
Py_ssize_t i, j, N, K, lab
851-
floating val
860+
int64float_t val
861+
uint8_t[::1] first_element_set
862+
bint isna_entry, uses_mask = not mask is None
852863

853864
assert min_count == -1, "'min_count' only used in sum and prod"
854865

@@ -862,7 +873,15 @@ def group_ohlc(
862873

863874
if K > 1:
864875
raise NotImplementedError("Argument 'values' must have only one dimension")
865-
out[:] = np.nan
876+
877+
if int64float_t is float32_t or int64float_t is float64_t:
878+
out[:] = np.nan
879+
else:
880+
out[:] = 0
881+
882+
first_element_set = np.zeros((<object>counts).shape, dtype=np.uint8)
883+
if uses_mask:
884+
result_mask[:] = True
866885

867886
with nogil:
868887
for i in range(N):
@@ -872,11 +891,22 @@ def group_ohlc(
872891

873892
counts[lab] += 1
874893
val = values[i, 0]
875-
if val != val:
894+
895+
if uses_mask:
896+
isna_entry = mask[i, 0]
897+
elif int64float_t is float32_t or int64float_t is float64_t:
898+
isna_entry = val != val
899+
else:
900+
isna_entry = False
901+
902+
if isna_entry:
876903
continue
877904

878-
if out[lab, 0] != out[lab, 0]:
905+
if not first_element_set[lab]:
879906
out[lab, 0] = out[lab, 1] = out[lab, 2] = out[lab, 3] = val
907+
first_element_set[lab] = True
908+
if uses_mask:
909+
result_mask[lab] = False
880910
else:
881911
out[lab, 1] = max(out[lab, 1], val)
882912
out[lab, 2] = min(out[lab, 2], val)

pandas/core/groupby/ops.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
157157
"first",
158158
"rank",
159159
"sum",
160+
"ohlc",
160161
}
161162

162163
_cython_arity = {"ohlc": 4} # OHLC
@@ -219,13 +220,13 @@ def _get_cython_vals(self, values: np.ndarray) -> np.ndarray:
219220
values = ensure_float64(values)
220221

221222
elif values.dtype.kind in ["i", "u"]:
222-
if how in ["var", "prod", "mean", "ohlc"] or (
223+
if how in ["var", "prod", "mean"] or (
223224
self.kind == "transform" and self.has_dropped_na
224225
):
225226
# result may still include NaN, so we have to cast
226227
values = ensure_float64(values)
227228

228-
elif how == "sum":
229+
elif how in ["sum", "ohlc"]:
229230
# Avoid overflow during group op
230231
if values.dtype.kind == "i":
231232
values = ensure_int64(values)
@@ -480,6 +481,9 @@ def _masked_ea_wrap_cython_operation(
480481
**kwargs,
481482
)
482483

484+
if self.how == "ohlc":
485+
result_mask = np.tile(result_mask, (4, 1)).T
486+
483487
# res_values should already have the correct dtype, we just need to
484488
# wrap in a MaskedArray
485489
return orig_values._maybe_mask_result(res_values, result_mask)
@@ -592,6 +596,8 @@ def _call_cython_op(
592596
min_count=min_count,
593597
is_datetimelike=is_datetimelike,
594598
)
599+
elif self.how == "ohlc":
600+
func(result, counts, values, comp_ids, min_count, mask, result_mask)
595601
else:
596602
func(result, counts, values, comp_ids, min_count)
597603
else:

pandas/tests/groupby/aggregate/test_aggregate.py

+16
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,22 @@ def test_order_aggregate_multiple_funcs():
564564
tm.assert_index_equal(result, expected)
565565

566566

567+
def test_ohlc_ea_dtypes(any_numeric_ea_dtype):
568+
# GH#37493
569+
df = DataFrame(
570+
{"a": [1, 1, 2, 3, 4, 4], "b": [22, 11, pd.NA, 10, 20, pd.NA]},
571+
dtype=any_numeric_ea_dtype,
572+
)
573+
result = df.groupby("a").ohlc()
574+
expected = DataFrame(
575+
[[22, 22, 11, 11], [pd.NA] * 4, [10] * 4, [20] * 4],
576+
columns=MultiIndex.from_product([["b"], ["open", "high", "low", "close"]]),
577+
index=Index([1, 2, 3, 4], dtype=any_numeric_ea_dtype, name="a"),
578+
dtype=any_numeric_ea_dtype,
579+
)
580+
tm.assert_frame_equal(result, expected)
581+
582+
567583
@pytest.mark.parametrize("dtype", [np.int64, np.uint64])
568584
@pytest.mark.parametrize("how", ["first", "last", "min", "max", "mean", "median"])
569585
def test_uint64_type_handling(dtype, how):

0 commit comments

Comments
 (0)