Skip to content

Commit c9144ca

Browse files
authored
CLN: Unify signatures in _libs.groupby (pandas-dev#34372)
1 parent 3c959fc commit c9144ca

File tree

3 files changed

+58
-20
lines changed

3 files changed

+58
-20
lines changed

pandas/_libs/groupby.pyx

+6-5
Original file line numberDiff line numberDiff line change
@@ -378,8 +378,8 @@ def group_fillna_indexer(ndarray[int64_t] out, ndarray[int64_t] labels,
378378
@cython.boundscheck(False)
379379
@cython.wraparound(False)
380380
def group_any_all(uint8_t[:] out,
381-
const int64_t[:] labels,
382381
const uint8_t[:] values,
382+
const int64_t[:] labels,
383383
const uint8_t[:] mask,
384384
object val_test,
385385
bint skipna):
@@ -560,7 +560,8 @@ def _group_var(floating[:, :] out,
560560
int64_t[:] counts,
561561
floating[:, :] values,
562562
const int64_t[:] labels,
563-
Py_ssize_t min_count=-1):
563+
Py_ssize_t min_count=-1,
564+
int64_t ddof=1):
564565
cdef:
565566
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
566567
floating val, ct, oldmean
@@ -600,10 +601,10 @@ def _group_var(floating[:, :] out,
600601
for i in range(ncounts):
601602
for j in range(K):
602603
ct = nobs[i, j]
603-
if ct < 2:
604+
if ct <= ddof:
604605
out[i, j] = NAN
605606
else:
606-
out[i, j] /= (ct - 1)
607+
out[i, j] /= (ct - ddof)
607608

608609

609610
group_var_float32 = _group_var['float']
@@ -715,8 +716,8 @@ group_ohlc_float64 = _group_ohlc['double']
715716
@cython.boundscheck(False)
716717
@cython.wraparound(False)
717718
def group_quantile(ndarray[float64_t] out,
718-
ndarray[int64_t] labels,
719719
numeric[:] values,
720+
ndarray[int64_t] labels,
720721
ndarray[uint8_t] mask,
721722
float64_t q,
722723
object interpolation):

pandas/core/groupby/generic.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1736,7 +1736,8 @@ def _wrap_aggregated_output(
17361736
DataFrame
17371737
"""
17381738
indexed_output = {key.position: val for key, val in output.items()}
1739-
columns = Index(key.label for key in output)
1739+
name = self._obj_with_exclusions._get_axis(1 - self.axis).name
1740+
columns = Index([key.label for key in output], name=name)
17401741

17411742
result = self.obj._constructor(indexed_output)
17421743
result.columns = columns

pandas/core/groupby/groupby.py

+50-14
Original file line numberDiff line numberDiff line change
@@ -1277,6 +1277,7 @@ def result_to_bool(result: np.ndarray, inference: Type) -> np.ndarray:
12771277
return self._get_cythonized_result(
12781278
"group_any_all",
12791279
aggregate=True,
1280+
numeric_only=False,
12801281
cython_dtype=np.dtype(np.uint8),
12811282
needs_values=True,
12821283
needs_mask=True,
@@ -1433,18 +1434,16 @@ def std(self, ddof: int = 1):
14331434
Series or DataFrame
14341435
Standard deviation of values within each group.
14351436
"""
1436-
result = self.var(ddof=ddof)
1437-
if result.ndim == 1:
1438-
result = np.sqrt(result)
1439-
else:
1440-
cols = result.columns.get_indexer_for(
1441-
result.columns.difference(self.exclusions).unique()
1442-
)
1443-
# TODO(GH-22046) - setting with iloc broken if labels are not unique
1444-
# .values to remove labels
1445-
result.iloc[:, cols] = np.sqrt(result.iloc[:, cols]).values
1446-
1447-
return result
1437+
return self._get_cythonized_result(
1438+
"group_var_float64",
1439+
aggregate=True,
1440+
needs_counts=True,
1441+
needs_values=True,
1442+
needs_2d=True,
1443+
cython_dtype=np.dtype(np.float64),
1444+
post_processing=lambda vals, inference: np.sqrt(vals),
1445+
ddof=ddof,
1446+
)
14481447

14491448
@Substitution(name="groupby")
14501449
@Appender(_common_see_also)
@@ -1778,6 +1777,7 @@ def _fill(self, direction, limit=None):
17781777

17791778
return self._get_cythonized_result(
17801779
"group_fillna_indexer",
1780+
numeric_only=False,
17811781
needs_mask=True,
17821782
cython_dtype=np.dtype(np.int64),
17831783
result_is_index=True,
@@ -2078,6 +2078,7 @@ def post_processor(vals: np.ndarray, inference: Optional[Type]) -> np.ndarray:
20782078
return self._get_cythonized_result(
20792079
"group_quantile",
20802080
aggregate=True,
2081+
numeric_only=False,
20812082
needs_values=True,
20822083
needs_mask=True,
20832084
cython_dtype=np.dtype(np.float64),
@@ -2367,7 +2368,11 @@ def _get_cythonized_result(
23672368
how: str,
23682369
cython_dtype: np.dtype,
23692370
aggregate: bool = False,
2371+
numeric_only: bool = True,
2372+
needs_counts: bool = False,
23702373
needs_values: bool = False,
2374+
needs_2d: bool = False,
2375+
min_count: Optional[int] = None,
23712376
needs_mask: bool = False,
23722377
needs_ngroups: bool = False,
23732378
result_is_index: bool = False,
@@ -2386,9 +2391,18 @@ def _get_cythonized_result(
23862391
aggregate : bool, default False
23872392
Whether the result should be aggregated to match the number of
23882393
groups
2394+
numeric_only : bool, default True
2395+
Whether only numeric datatypes should be computed
2396+
needs_counts : bool, default False
2397+
Whether the counts should be a part of the Cython call
23892398
needs_values : bool, default False
23902399
Whether the values should be a part of the Cython call
23912400
signature
2401+
needs_2d : bool, default False
2402+
Whether the values and result of the Cython call signature
2403+
are at least 2-dimensional.
2404+
min_count : int, default None
2405+
When not None, min_count for the Cython call
23922406
needs_mask : bool, default False
23932407
Whether boolean mask needs to be part of the Cython call
23942408
signature
@@ -2418,7 +2432,7 @@ def _get_cythonized_result(
24182432
if result_is_index and aggregate:
24192433
raise ValueError("'result_is_index' and 'aggregate' cannot both be True!")
24202434
if post_processing:
2421-
if not callable(pre_processing):
2435+
if not callable(post_processing):
24222436
raise ValueError("'post_processing' must be a callable!")
24232437
if pre_processing:
24242438
if not callable(pre_processing):
@@ -2438,21 +2452,39 @@ def _get_cythonized_result(
24382452
name = obj.name
24392453
values = obj._values
24402454

2455+
if numeric_only and not is_numeric_dtype(values):
2456+
continue
2457+
24412458
if aggregate:
24422459
result_sz = ngroups
24432460
else:
24442461
result_sz = len(values)
24452462

24462463
result = np.zeros(result_sz, dtype=cython_dtype)
2447-
func = partial(base_func, result, labels)
2464+
if needs_2d:
2465+
result = result.reshape((-1, 1))
2466+
func = partial(base_func, result)
2467+
24482468
inferences = None
24492469

2470+
if needs_counts:
2471+
counts = np.zeros(self.ngroups, dtype=np.int64)
2472+
func = partial(func, counts)
2473+
24502474
if needs_values:
24512475
vals = values
24522476
if pre_processing:
24532477
vals, inferences = pre_processing(vals)
2478+
if needs_2d:
2479+
vals = vals.reshape((-1, 1))
2480+
vals = vals.astype(cython_dtype, copy=False)
24542481
func = partial(func, vals)
24552482

2483+
func = partial(func, labels)
2484+
2485+
if min_count is not None:
2486+
func = partial(func, min_count)
2487+
24562488
if needs_mask:
24572489
mask = isna(values).view(np.uint8)
24582490
func = partial(func, mask)
@@ -2462,6 +2494,9 @@ def _get_cythonized_result(
24622494

24632495
func(**kwargs) # Call func to modify indexer values in place
24642496

2497+
if needs_2d:
2498+
result = result.reshape(-1)
2499+
24652500
if result_is_index:
24662501
result = algorithms.take_nd(values, result)
24672502

@@ -2512,6 +2547,7 @@ def shift(self, periods=1, freq=None, axis=0, fill_value=None):
25122547

25132548
return self._get_cythonized_result(
25142549
"group_shift_indexer",
2550+
numeric_only=False,
25152551
cython_dtype=np.dtype(np.int64),
25162552
needs_ngroups=True,
25172553
result_is_index=True,

0 commit comments

Comments
 (0)