Skip to content

Commit 243b8a6

Browse files
TYP: use TYPE_CHECKING for import_optional_dependency("numba") (#44273)
1 parent f711e71 commit 243b8a6

File tree

6 files changed

+60
-30
lines changed

6 files changed

+60
-30
lines changed

pandas/core/_numba/executor.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from __future__ import annotations
22

3-
from typing import Callable
3+
from typing import (
4+
TYPE_CHECKING,
5+
Callable,
6+
)
47

58
import numpy as np
69

@@ -42,10 +45,12 @@ def generate_shared_aggregator(
4245
if cache_key in NUMBA_FUNC_CACHE:
4346
return NUMBA_FUNC_CACHE[cache_key]
4447

45-
numba = import_optional_dependency("numba")
48+
if TYPE_CHECKING:
49+
import numba
50+
else:
51+
numba = import_optional_dependency("numba")
4652

47-
# error: Untyped decorator makes function "column_looper" untyped
48-
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) # type: ignore[misc]
53+
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
4954
def column_looper(
5055
values: np.ndarray,
5156
start: np.ndarray,

pandas/core/groupby/numba_.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import inspect
55
from typing import (
6+
TYPE_CHECKING,
67
Any,
78
Callable,
89
)
@@ -90,10 +91,12 @@ def generate_numba_agg_func(
9091
return NUMBA_FUNC_CACHE[cache_key]
9192

9293
numba_func = jit_user_function(func, nopython, nogil, parallel)
93-
numba = import_optional_dependency("numba")
94+
if TYPE_CHECKING:
95+
import numba
96+
else:
97+
numba = import_optional_dependency("numba")
9498

95-
# error: Untyped decorator makes function "group_agg" untyped
96-
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) # type: ignore[misc]
99+
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
97100
def group_agg(
98101
values: np.ndarray,
99102
index: np.ndarray,
@@ -152,10 +155,12 @@ def generate_numba_transform_func(
152155
return NUMBA_FUNC_CACHE[cache_key]
153156

154157
numba_func = jit_user_function(func, nopython, nogil, parallel)
155-
numba = import_optional_dependency("numba")
158+
if TYPE_CHECKING:
159+
import numba
160+
else:
161+
numba = import_optional_dependency("numba")
156162

157-
# error: Untyped decorator makes function "group_transform" untyped
158-
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) # type: ignore[misc]
163+
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
159164
def group_transform(
160165
values: np.ndarray,
161166
index: np.ndarray,

pandas/core/util/numba_.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
"""Common utilities for Numba operations"""
2-
# pyright: reportUntypedFunctionDecorator = false
32
from __future__ import annotations
43

54
import types
6-
from typing import Callable
5+
from typing import (
6+
TYPE_CHECKING,
7+
Callable,
8+
)
79

810
import numpy as np
911

@@ -84,7 +86,10 @@ def jit_user_function(
8486
function
8587
Numba JITed function
8688
"""
87-
numba = import_optional_dependency("numba")
89+
if TYPE_CHECKING:
90+
import numba
91+
else:
92+
numba = import_optional_dependency("numba")
8893

8994
if numba.extending.is_jitted(func):
9095
# Don't jit a user passed jitted function

pandas/core/window/numba_.py

+25-14
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
# pyright: reportUntypedFunctionDecorator = false
21
from __future__ import annotations
32

43
import functools
54
from typing import (
5+
TYPE_CHECKING,
66
Any,
77
Callable,
88
)
@@ -56,10 +56,12 @@ def generate_numba_apply_func(
5656
return NUMBA_FUNC_CACHE[cache_key]
5757

5858
numba_func = jit_user_function(func, nopython, nogil, parallel)
59-
numba = import_optional_dependency("numba")
59+
if TYPE_CHECKING:
60+
import numba
61+
else:
62+
numba = import_optional_dependency("numba")
6063

61-
# error: Untyped decorator makes function "roll_apply" untyped
62-
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) # type: ignore[misc]
64+
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
6365
def roll_apply(
6466
values: np.ndarray,
6567
begin: np.ndarray,
@@ -115,10 +117,12 @@ def generate_numba_ewm_func(
115117
if cache_key in NUMBA_FUNC_CACHE:
116118
return NUMBA_FUNC_CACHE[cache_key]
117119

118-
numba = import_optional_dependency("numba")
120+
if TYPE_CHECKING:
121+
import numba
122+
else:
123+
numba = import_optional_dependency("numba")
119124

120-
# error: Untyped decorator makes function "ewma" untyped
121-
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) # type: ignore[misc]
125+
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
122126
def ewm(
123127
values: np.ndarray,
124128
begin: np.ndarray,
@@ -217,10 +221,12 @@ def generate_numba_table_func(
217221
return NUMBA_FUNC_CACHE[cache_key]
218222

219223
numba_func = jit_user_function(func, nopython, nogil, parallel)
220-
numba = import_optional_dependency("numba")
224+
if TYPE_CHECKING:
225+
import numba
226+
else:
227+
numba = import_optional_dependency("numba")
221228

222-
# error: Untyped decorator makes function "roll_table" untyped
223-
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) # type: ignore[misc]
229+
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
224230
def roll_table(
225231
values: np.ndarray,
226232
begin: np.ndarray,
@@ -250,7 +256,10 @@ def roll_table(
250256
# https://github.com/numba/numba/issues/1269
251257
@functools.lru_cache(maxsize=None)
252258
def generate_manual_numpy_nan_agg_with_axis(nan_func):
253-
numba = import_optional_dependency("numba")
259+
if TYPE_CHECKING:
260+
import numba
261+
else:
262+
numba = import_optional_dependency("numba")
254263

255264
@numba.jit(nopython=True, nogil=True, parallel=True)
256265
def nan_agg_with_axis(table):
@@ -296,10 +305,12 @@ def generate_numba_ewm_table_func(
296305
if cache_key in NUMBA_FUNC_CACHE:
297306
return NUMBA_FUNC_CACHE[cache_key]
298307

299-
numba = import_optional_dependency("numba")
308+
if TYPE_CHECKING:
309+
import numba
310+
else:
311+
numba = import_optional_dependency("numba")
300312

301-
# error: Untyped decorator makes function "ewm_table" untyped
302-
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) # type: ignore[misc]
313+
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
303314
def ewm_table(
304315
values: np.ndarray,
305316
begin: np.ndarray,

pandas/core/window/online.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import (
2+
TYPE_CHECKING,
23
Dict,
34
Optional,
45
)
@@ -31,10 +32,12 @@ def generate_online_numba_ewma_func(engine_kwargs: Optional[Dict[str, bool]]):
3132
if cache_key in NUMBA_FUNC_CACHE:
3233
return NUMBA_FUNC_CACHE[cache_key]
3334

34-
numba = import_optional_dependency("numba")
35+
if TYPE_CHECKING:
36+
import numba
37+
else:
38+
numba = import_optional_dependency("numba")
3539

36-
# error: Untyped decorator makes function "online_ewma" untyped
37-
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) # type: ignore[misc]
40+
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
3841
def online_ewma(
3942
values: np.ndarray,
4043
deltas: np.ndarray,

typings/numba.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,4 @@ def jit(
4040
) -> Callable[[F], F]: ...
4141

4242
njit = jit
43+
generated_jit = jit

0 commit comments

Comments
 (0)