Skip to content

Commit 1ec2d1d

Browse files
authored
REF: avoid unnecessary argument to groupby numba funcs (#43487)
1 parent 0072fa8 commit 1ec2d1d

File tree

3 files changed

+15
-15
lines changed

3 files changed

+15
-15
lines changed

pandas/core/groupby/generic.py

+4 -2
Original file line number | Diff line number | Diff line change
@@ -228,9 +228,10 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
228228
if maybe_use_numba(engine):
229229
with group_selection_context(self):
230230
data = self._selected_obj
231-
result, index = self._aggregate_with_numba(
231+
result = self._aggregate_with_numba(
232232
data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs
233233
)
234+
index = self._group_keys_index
234235
return self.obj._constructor(result.ravel(), index=index, name=data.name)
235236

236237
relabeling = func is None
@@ -924,9 +925,10 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
924925
if maybe_use_numba(engine):
925926
with group_selection_context(self):
926927
data = self._selected_obj
927-
result, index = self._aggregate_with_numba(
928+
result = self._aggregate_with_numba(
928929
data, func, *args, engine_kwargs=engine_kwargs, **kwargs
929930
)
931+
index = self._group_keys_index
930932
return self.obj._constructor(result, index=index, columns=data.columns)
931933

932934
relabeling, func, columns, order = reconstruct_func(func, **kwargs)

pandas/core/groupby/groupby.py

+1 -5
Original file line number | Diff line number | Diff line change
@@ -1269,7 +1269,6 @@ def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
12691269
data and indices into a Numba jitted function.
12701270
"""
12711271
starts, ends, sorted_index, sorted_data = self._numba_prep(func, data)
1272-
group_keys = self.grouper.group_keys_seq
12731272

12741273
numba_transform_func = numba_.generate_numba_transform_func(
12751274
kwargs, func, engine_kwargs
@@ -1279,7 +1278,6 @@ def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
12791278
sorted_index,
12801279
starts,
12811280
ends,
1282-
len(group_keys),
12831281
len(data.columns),
12841282
*args,
12851283
)
@@ -1302,15 +1300,13 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
13021300
data and indices into a Numba jitted function.
13031301
"""
13041302
starts, ends, sorted_index, sorted_data = self._numba_prep(func, data)
1305-
index = self._group_keys_index
13061303

13071304
numba_agg_func = numba_.generate_numba_agg_func(kwargs, func, engine_kwargs)
13081305
result = numba_agg_func(
13091306
sorted_data,
13101307
sorted_index,
13111308
starts,
13121309
ends,
1313-
len(index),
13141310
len(data.columns),
13151311
*args,
13161312
)
@@ -1319,7 +1315,7 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
13191315
if cache_key not in NUMBA_FUNC_CACHE:
13201316
NUMBA_FUNC_CACHE[cache_key] = numba_agg_func
13211317

1322-
return result, index
1318+
return result
13231319

13241320
# -----------------------------------------------------------------
13251321
# apply/agg/transform

pandas/core/groupby/numba_.py

+10 -8
Original file line number | Diff line number | Diff line change
@@ -59,9 +59,7 @@ def generate_numba_agg_func(
5959
kwargs: dict[str, Any],
6060
func: Callable[..., Scalar],
6161
engine_kwargs: dict[str, bool] | None,
62-
) -> Callable[
63-
[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int, Any], np.ndarray
64-
]:
62+
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
6563
"""
6664
Generate a numba jitted agg function specified by values from engine_kwargs.
6765
@@ -100,10 +98,13 @@ def group_agg(
10098
index: np.ndarray,
10199
begin: np.ndarray,
102100
end: np.ndarray,
103-
num_groups: int,
104101
num_columns: int,
105102
*args: Any,
106103
) -> np.ndarray:
104+
105+
assert len(begin) == len(end)
106+
num_groups = len(begin)
107+
107108
result = np.empty((num_groups, num_columns))
108109
for i in numba.prange(num_groups):
109110
group_index = index[begin[i] : end[i]]
@@ -119,9 +120,7 @@ def generate_numba_transform_func(
119120
kwargs: dict[str, Any],
120121
func: Callable[..., np.ndarray],
121122
engine_kwargs: dict[str, bool] | None,
122-
) -> Callable[
123-
[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int, Any], np.ndarray
124-
]:
123+
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
125124
"""
126125
Generate a numba jitted transform function specified by values from engine_kwargs.
127126
@@ -160,10 +159,13 @@ def group_transform(
160159
index: np.ndarray,
161160
begin: np.ndarray,
162161
end: np.ndarray,
163-
num_groups: int,
164162
num_columns: int,
165163
*args: Any,
166164
) -> np.ndarray:
165+
166+
assert len(begin) == len(end)
167+
num_groups = len(begin)
168+
167169
result = np.empty((len(values), num_columns))
168170
for i in numba.prange(num_groups):
169171
group_index = index[begin[i] : end[i]]

0 commit comments

Comments
 (0)