Skip to content

Commit 46f77b5

Browse files
authored
CLN: numba.prange usage (#38661)
1 parent a943a09 commit 46f77b5

File tree

2 files changed

+6
-22
lines changed

2 files changed

+6
-22
lines changed

pandas/core/groupby/numba_.py

+4-12
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,6 @@ def generate_numba_agg_func(
8989

9090
numba_func = jit_user_function(func, nopython, nogil, parallel)
9191
numba = import_optional_dependency("numba")
92-
if parallel:
93-
loop_range = numba.prange
94-
else:
95-
loop_range = range
9692

9793
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
9894
def group_agg(
@@ -104,9 +100,9 @@ def group_agg(
104100
num_columns: int,
105101
) -> np.ndarray:
106102
result = np.empty((num_groups, num_columns))
107-
for i in loop_range(num_groups):
103+
for i in numba.prange(num_groups):
108104
group_index = index[begin[i] : end[i]]
109-
for j in loop_range(num_columns):
105+
for j in numba.prange(num_columns):
110106
group = values[begin[i] : end[i], j]
111107
result[i, j] = numba_func(group, group_index, *args)
112108
return result
@@ -153,10 +149,6 @@ def generate_numba_transform_func(
153149

154150
numba_func = jit_user_function(func, nopython, nogil, parallel)
155151
numba = import_optional_dependency("numba")
156-
if parallel:
157-
loop_range = numba.prange
158-
else:
159-
loop_range = range
160152

161153
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
162154
def group_transform(
@@ -168,9 +160,9 @@ def group_transform(
168160
num_columns: int,
169161
) -> np.ndarray:
170162
result = np.empty((len(values), num_columns))
171-
for i in loop_range(num_groups):
163+
for i in numba.prange(num_groups):
172164
group_index = index[begin[i] : end[i]]
173-
for j in loop_range(num_columns):
165+
for j in numba.prange(num_columns):
174166
group = values[begin[i] : end[i], j]
175167
result[begin[i] : end[i], j] = numba_func(group, group_index, *args)
176168
return result

pandas/core/window/numba_.py

+2-10
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,13 @@ def generate_numba_apply_func(
5050

5151
numba_func = jit_user_function(func, nopython, nogil, parallel)
5252
numba = import_optional_dependency("numba")
53-
if parallel:
54-
loop_range = numba.prange
55-
else:
56-
loop_range = range
5753

5854
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
5955
def roll_apply(
6056
values: np.ndarray, begin: np.ndarray, end: np.ndarray, minimum_periods: int
6157
) -> np.ndarray:
6258
result = np.empty(len(begin))
63-
for i in loop_range(len(result)):
59+
for i in numba.prange(len(result)):
6460
start = begin[i]
6561
stop = end[i]
6662
window = values[start:stop]
@@ -103,10 +99,6 @@ def generate_numba_groupby_ewma_func(
10399
return NUMBA_FUNC_CACHE[cache_key]
104100

105101
numba = import_optional_dependency("numba")
106-
if parallel:
107-
loop_range = numba.prange
108-
else:
109-
loop_range = range
110102

111103
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
112104
def groupby_ewma(
@@ -117,7 +109,7 @@ def groupby_ewma(
117109
) -> np.ndarray:
118110
result = np.empty(len(values))
119111
alpha = 1.0 / (1.0 + com)
120-
for i in loop_range(len(begin)):
112+
for i in numba.prange(len(begin)):
121113
start = begin[i]
122114
stop = end[i]
123115
window = values[start:stop]

0 commit comments

Comments
 (0)