Skip to content

Commit fbcad98

Browse files
jbrockmendel authored and feefladder committed
PERF: GroupBy.any/all operate blockwise instead of column-wise (pandas-dev#42841)
1 parent f41cac3 commit fbcad98

File tree

4 files changed

+108
-44
lines changed

4 files changed

+108
-44
lines changed

asv_bench/benchmarks/groupby.py

+24 -11
Original file line number | Diff line number | Diff line change
@@ -403,7 +403,7 @@ def time_srs_bfill(self):
403403

404404
class GroupByMethods:
405405

406-
param_names = ["dtype", "method", "application"]
406+
param_names = ["dtype", "method", "application", "ncols"]
407407
params = [
408408
["int", "float", "object", "datetime", "uint"],
409409
[
@@ -443,15 +443,23 @@ class GroupByMethods:
443443
"var",
444444
],
445445
["direct", "transformation"],
446+
[1, 2, 5, 10],
446447
]
447448

448-
def setup(self, dtype, method, application):
449+
def setup(self, dtype, method, application, ncols):
449450
if method in method_blocklist.get(dtype, {}):
450451
raise NotImplementedError # skip benchmark
452+
453+
if ncols != 1 and method in ["value_counts", "unique"]:
454+
# DataFrameGroupBy doesn't have these methods
455+
raise NotImplementedError
456+
451457
ngroups = 1000
452458
size = ngroups * 2
453-
rng = np.arange(ngroups)
454-
values = rng.take(np.random.randint(0, ngroups, size=size))
459+
rng = np.arange(ngroups).reshape(-1, 1)
460+
rng = np.broadcast_to(rng, (len(rng), ncols))
461+
taker = np.random.randint(0, ngroups, size=size)
462+
values = rng.take(taker, axis=0)
455463
if dtype == "int":
456464
key = np.random.randint(0, size, size=size)
457465
elif dtype == "uint":
@@ -465,22 +473,27 @@ def setup(self, dtype, method, application):
465473
elif dtype == "datetime":
466474
key = date_range("1/1/2011", periods=size, freq="s")
467475

468-
df = DataFrame({"values": values, "key": key})
476+
cols = [f"values{n}" for n in range(ncols)]
477+
df = DataFrame(values, columns=cols)
478+
df["key"] = key
479+
480+
if len(cols) == 1:
481+
cols = cols[0]
469482

470483
if application == "transform":
471484
if method == "describe":
472485
raise NotImplementedError
473486

474-
self.as_group_method = lambda: df.groupby("key")["values"].transform(method)
475-
self.as_field_method = lambda: df.groupby("values")["key"].transform(method)
487+
self.as_group_method = lambda: df.groupby("key")[cols].transform(method)
488+
self.as_field_method = lambda: df.groupby(cols)["key"].transform(method)
476489
else:
477-
self.as_group_method = getattr(df.groupby("key")["values"], method)
478-
self.as_field_method = getattr(df.groupby("values")["key"], method)
490+
self.as_group_method = getattr(df.groupby("key")[cols], method)
491+
self.as_field_method = getattr(df.groupby(cols)["key"], method)
479492

480-
def time_dtype_as_group(self, dtype, method, application):
493+
def time_dtype_as_group(self, dtype, method, application, ncols):
481494
self.as_group_method()
482495

483-
def time_dtype_as_field(self, dtype, method, application):
496+
def time_dtype_as_field(self, dtype, method, application, ncols):
484497
self.as_field_method()
485498

486499

pandas/_libs/groupby.pyx

+23 -17
Original file line number | Diff line number | Diff line change
@@ -388,10 +388,10 @@ def group_fillna_indexer(ndarray[int64_t] out, ndarray[intp_t] labels,
388388

389389
@cython.boundscheck(False)
390390
@cython.wraparound(False)
391-
def group_any_all(int8_t[::1] out,
392-
const int8_t[::1] values,
391+
def group_any_all(int8_t[:, ::1] out,
392+
const int8_t[:, :] values,
393393
const intp_t[::1] labels,
394-
const uint8_t[::1] mask,
394+
const uint8_t[:, :] mask,
395395
str val_test,
396396
bint skipna,
397397
bint nullable) -> None:
@@ -426,9 +426,9 @@ def group_any_all(int8_t[::1] out,
426426
-1 to signify a masked position in the case of a nullable input.
427427
"""
428428
cdef:
429-
Py_ssize_t i, N = len(labels)
429+
Py_ssize_t i, j, N = len(labels), K = out.shape[1]
430430
intp_t lab
431-
int8_t flag_val
431+
int8_t flag_val, val
432432

433433
if val_test == 'all':
434434
# Because the 'all' value of an empty iterable in Python is True we can
@@ -448,21 +448,27 @@ def group_any_all(int8_t[::1] out,
448448
with nogil:
449449
for i in range(N):
450450
lab = labels[i]
451-
if lab < 0 or (skipna and mask[i]):
451+
if lab < 0:
452452
continue
453453

454-
if nullable and mask[i]:
455-
# Set the position as masked if `out[lab] != flag_val`, which
456-
# would indicate True/False has not yet been seen for any/all,
457-
# so by Kleene logic the result is currently unknown
458-
if out[lab] != flag_val:
459-
out[lab] = -1
460-
continue
454+
for j in range(K):
455+
if skipna and mask[i, j]:
456+
continue
457+
458+
if nullable and mask[i, j]:
459+
# Set the position as masked if `out[lab] != flag_val`, which
460+
# would indicate True/False has not yet been seen for any/all,
461+
# so by Kleene logic the result is currently unknown
462+
if out[lab, j] != flag_val:
463+
out[lab, j] = -1
464+
continue
465+
466+
val = values[i, j]
461467

462-
# If True and 'any' or False and 'all', the result is
463-
# already determined
464-
if values[i] == flag_val:
465-
out[lab] = flag_val
468+
# If True and 'any' or False and 'all', the result is
469+
# already determined
470+
if val == flag_val:
471+
out[lab, j] = flag_val
466472

467473

468474
# ----------------------------------------------------------------------

pandas/core/groupby/generic.py

+8 -5
Original file line number | Diff line number | Diff line change
@@ -1634,12 +1634,15 @@ def _wrap_aggregated_output(
16341634
-------
16351635
DataFrame
16361636
"""
1637-
indexed_output = {key.position: val for key, val in output.items()}
1638-
columns = Index([key.label for key in output])
1639-
columns._set_names(self._obj_with_exclusions._get_axis(1 - self.axis).names)
1637+
if isinstance(output, DataFrame):
1638+
result = output
1639+
else:
1640+
indexed_output = {key.position: val for key, val in output.items()}
1641+
columns = Index([key.label for key in output])
1642+
columns._set_names(self._obj_with_exclusions._get_axis(1 - self.axis).names)
16401643

1641-
result = self.obj._constructor(indexed_output)
1642-
result.columns = columns
1644+
result = self.obj._constructor(indexed_output)
1645+
result.columns = columns
16431646

16441647
if not self.as_index:
16451648
self._insert_inaxis_grouper_inplace(result)

pandas/core/groupby/groupby.py

+53 -11
Original file line number | Diff line number | Diff line change
@@ -1527,13 +1527,13 @@ def _obj_1d_constructor(self) -> type[Series]:
15271527
return self.obj._constructor
15281528

15291529
@final
1530-
def _bool_agg(self, val_test, skipna):
1530+
def _bool_agg(self, val_test: Literal["any", "all"], skipna: bool):
15311531
"""
15321532
Shared func to call any / all Cython GroupBy implementations.
15331533
"""
15341534

15351535
def objs_to_bool(vals: ArrayLike) -> tuple[np.ndarray, type]:
1536-
if is_object_dtype(vals):
1536+
if is_object_dtype(vals.dtype):
15371537
# GH#37501: don't raise on pd.NA when skipna=True
15381538
if skipna:
15391539
vals = np.array([bool(x) if not isna(x) else True for x in vals])
@@ -1542,7 +1542,7 @@ def objs_to_bool(vals: ArrayLike) -> tuple[np.ndarray, type]:
15421542
elif isinstance(vals, BaseMaskedArray):
15431543
vals = vals._data.astype(bool, copy=False)
15441544
else:
1545-
vals = vals.astype(bool)
1545+
vals = vals.astype(bool, copy=False)
15461546

15471547
return vals.view(np.int8), bool
15481548

@@ -1562,6 +1562,7 @@ def result_to_bool(
15621562
numeric_only=False,
15631563
cython_dtype=np.dtype(np.int8),
15641564
needs_values=True,
1565+
needs_2d=True,
15651566
needs_mask=True,
15661567
needs_nullable=True,
15671568
pre_processing=objs_to_bool,
@@ -2917,16 +2918,24 @@ def _get_cythonized_result(
29172918
if min_count is not None:
29182919
base_func = partial(base_func, min_count=min_count)
29192920

2921+
real_2d = how in ["group_any_all"]
2922+
29202923
def blk_func(values: ArrayLike) -> ArrayLike:
2924+
values = values.T
2925+
ncols = 1 if values.ndim == 1 else values.shape[1]
2926+
29212927
if aggregate:
29222928
result_sz = ngroups
29232929
else:
2924-
result_sz = len(values)
2930+
result_sz = values.shape[-1]
29252931

29262932
result: ArrayLike
2927-
result = np.zeros(result_sz, dtype=cython_dtype)
2933+
result = np.zeros(result_sz * ncols, dtype=cython_dtype)
29282934
if needs_2d:
2929-
result = result.reshape((-1, 1))
2935+
if real_2d:
2936+
result = result.reshape((result_sz, ncols))
2937+
else:
2938+
result = result.reshape(-1, 1)
29302939
func = partial(base_func, out=result)
29312940

29322941
inferences = None
@@ -2941,12 +2950,14 @@ def blk_func(values: ArrayLike) -> ArrayLike:
29412950
vals, inferences = pre_processing(vals)
29422951

29432952
vals = vals.astype(cython_dtype, copy=False)
2944-
if needs_2d:
2953+
if needs_2d and vals.ndim == 1:
29452954
vals = vals.reshape((-1, 1))
29462955
func = partial(func, values=vals)
29472956

29482957
if needs_mask:
29492958
mask = isna(values).view(np.uint8)
2959+
if needs_2d and mask.ndim == 1:
2960+
mask = mask.reshape(-1, 1)
29502961
func = partial(func, mask=mask)
29512962

29522963
if needs_nullable:
@@ -2955,20 +2966,51 @@ def blk_func(values: ArrayLike) -> ArrayLike:
29552966

29562967
func(**kwargs) # Call func to modify indexer values in place
29572968

2958-
if needs_2d:
2959-
result = result.reshape(-1)
2960-
29612969
if result_is_index:
29622970
result = algorithms.take_nd(values, result, fill_value=fill_value)
29632971

2972+
if real_2d and values.ndim == 1:
2973+
assert result.shape[1] == 1, result.shape
2974+
# error: Invalid index type "Tuple[slice, int]" for
2975+
# "Union[ExtensionArray, ndarray[Any, Any]]"; expected type
2976+
# "Union[int, integer[Any], slice, Sequence[int], ndarray[Any, Any]]"
2977+
result = result[:, 0] # type: ignore[index]
2978+
if needs_mask:
2979+
mask = mask[:, 0]
2980+
29642981
if post_processing:
29652982
pp_kwargs = {}
29662983
if needs_nullable:
29672984
pp_kwargs["nullable"] = isinstance(values, BaseMaskedArray)
29682985

29692986
result = post_processing(result, inferences, **pp_kwargs)
29702987

2971-
return result
2988+
if needs_2d and not real_2d:
2989+
if result.ndim == 2:
2990+
assert result.shape[1] == 1
2991+
# error: Invalid index type "Tuple[slice, int]" for
2992+
# "Union[ExtensionArray, Any, ndarray[Any, Any]]"; expected
2993+
# type "Union[int, integer[Any], slice, Sequence[int],
2994+
# ndarray[Any, Any]]"
2995+
result = result[:, 0] # type: ignore[index]
2996+
2997+
return result.T
2998+
2999+
obj = self._obj_with_exclusions
3000+
if obj.ndim == 2 and self.axis == 0 and needs_2d and real_2d:
3001+
# Operate block-wise instead of column-by-column
3002+
3003+
mgr = obj._mgr
3004+
if numeric_only:
3005+
mgr = mgr.get_numeric_data()
3006+
3007+
# setting ignore_failures=False for troubleshooting
3008+
res_mgr = mgr.grouped_reduce(blk_func, ignore_failures=False)
3009+
output = type(obj)(res_mgr)
3010+
if aggregate:
3011+
return self._wrap_aggregated_output(output)
3012+
else:
3013+
return self._wrap_transformed_output(output)
29723014

29733015
error_msg = ""
29743016
for idx, obj in enumerate(self._iterate_slices()):

0 commit comments

Comments
 (0)