Skip to content

Commit 55bfcec

Browse files
authored
REF: better use of fused_types for group_ohlc (#40668)
1 parent 5282bef commit 55bfcec

File tree

4 files changed

+14
-16
lines changed

4 files changed

+14
-16
lines changed

pandas/_libs/groupby.pyx

+6-11
Original file line numberDiff line numberDiff line change
@@ -681,18 +681,17 @@ group_mean_float64 = _group_mean['double']
681681

682682
@cython.wraparound(False)
683683
@cython.boundscheck(False)
684-
def _group_ohlc(floating[:, ::1] out,
685-
int64_t[::1] counts,
686-
ndarray[floating, ndim=2] values,
687-
const intp_t[:] labels,
688-
Py_ssize_t min_count=-1):
684+
def group_ohlc(floating[:, ::1] out,
685+
int64_t[::1] counts,
686+
ndarray[floating, ndim=2] values,
687+
const intp_t[:] labels,
688+
Py_ssize_t min_count=-1):
689689
"""
690690
Only aggregates on axis=0
691691
"""
692692
cdef:
693693
Py_ssize_t i, j, N, K, lab
694-
floating val, count
695-
Py_ssize_t ngroups = len(counts)
694+
floating val
696695

697696
assert min_count == -1, "'min_count' only used in add and prod"
698697

@@ -727,10 +726,6 @@ def _group_ohlc(floating[:, ::1] out,
727726
out[lab, 3] = val
728727

729728

730-
group_ohlc_float32 = _group_ohlc['float']
731-
group_ohlc_float64 = _group_ohlc['double']
732-
733-
734729
@cython.boundscheck(False)
735730
@cython.wraparound(False)
736731
def group_quantile(ndarray[float64_t] out,

pandas/core/groupby/ops.py

+6
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,12 @@ def _get_cython_func_and_vals(
486486
func = _get_cython_function(kind, how, values.dtype, is_numeric)
487487
else:
488488
raise
489+
else:
490+
if values.dtype.kind in ["i", "u"]:
491+
if how in ["ohlc"]:
492+
# The output may still include nans, so we have to cast
493+
values = ensure_float64(values)
494+
489495
return func, values
490496

491497
@final

pandas/core/missing.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -861,7 +861,4 @@ def _rolling_window(a: np.ndarray, window: int):
861861
# https://stackoverflow.com/a/6811241
862862
shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
863863
strides = a.strides + (a.strides[-1],)
864-
# error: Module has no attribute "stride_tricks"
865-
return np.lib.stride_tricks.as_strided( # type: ignore[attr-defined]
866-
a, shape=shape, strides=strides
867-
)
864+
return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)

pandas/tests/groupby/test_libgroupby.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def _check(dtype):
138138
counts = np.zeros(len(out), dtype=np.int64)
139139
labels = ensure_platform_int(np.repeat(np.arange(3), np.diff(np.r_[0, bins])))
140140

141-
func = getattr(libgroupby, f"group_ohlc_{dtype}")
141+
func = libgroupby.group_ohlc
142142
func(out, counts, obj[:, None], labels)
143143

144144
def _ohlc(group):

0 commit comments

Comments
 (0)