Skip to content

Commit e036caf

Browse files
committed
Update numpy deprecated imports
- replaced np.AxisError with np.exceptions.AxisError - the `numpy.core` submodule has been renamed to `numpy._core` - some parts of `numpy.core` have been moved to `numpy.lib.array_utils` Except for `AxisError`, the updated imports are conditional on the version of numpy, so the imports should work for numpy >= 1.26. The conditional imports have been added to `npy_2_compat.py`, so the imports elsewhere are unconditonal.
1 parent bbe663d commit e036caf

18 files changed

+311
-35
lines changed

pytensor/link/c/basic.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
from io import StringIO
1111
from typing import TYPE_CHECKING, Any, Optional
1212

13-
import numpy as np
14-
1513
from pytensor.compile.compilelock import lock_ctx
1614
from pytensor.configdefaults import config
1715
from pytensor.graph.basic import (
@@ -33,6 +31,7 @@
3331
from pytensor.link.c.cmodule import get_module_cache as _get_module_cache
3432
from pytensor.link.c.interface import CLinkerObject, CLinkerOp, CLinkerType
3533
from pytensor.link.utils import gc_helper, map_storage, raise_with_op, streamline
34+
from pytensor.npy_2_compat import ndarray_c_version
3635
from pytensor.utils import difference, uniq
3736

3837

@@ -1367,10 +1366,6 @@ def cmodule_key_(
13671366

13681367
# We must always add the numpy ABI version here as
13691368
# DynamicModule always add the include <numpy/arrayobject.h>
1370-
if np.lib.NumpyVersion(np.__version__) < "1.16.0a":
1371-
ndarray_c_version = np.core.multiarray._get_ndarray_c_version()
1372-
else:
1373-
ndarray_c_version = np.core._multiarray_umath._get_ndarray_c_version()
13741369
sig.append(f"NPY_ABI_VERSION=0x{ndarray_c_version:X}")
13751370
if c_compiler:
13761371
sig.append("c_compiler_str=" + c_compiler.version_str())

pytensor/link/numba/dispatch/elemwise.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import numba
55
import numpy as np
66
from numba.core.extending import overload
7-
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
87

98
from pytensor.graph.op import Op
109
from pytensor.link.numba.dispatch import basic as numba_basic
@@ -19,6 +18,7 @@
1918
store_core_outputs,
2019
)
2120
from pytensor.link.utils import compile_function_src
21+
from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
2222
from pytensor.scalar.basic import (
2323
AND,
2424
OR,

pytensor/npy_2_compat.py

+275
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
from textwrap import dedent
2+
3+
import numpy as np
4+
5+
6+
# Conditional numpy imports for numpy 1.26 and 2.x compatibility
7+
try:
8+
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
9+
except ModuleNotFoundError:
10+
# numpy < 2.0
11+
from numpy.core.multiarray import normalize_axis_index # type: ignore[no-redef]
12+
from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef]
13+
14+
15+
try:
16+
from numpy._core.einsumfunc import ( # type: ignore[attr-defined]
17+
_find_contraction,
18+
_parse_einsum_input,
19+
)
20+
except ModuleNotFoundError:
21+
from numpy.core.einsumfunc import ( # type: ignore[no-redef]
22+
_find_contraction,
23+
_parse_einsum_input,
24+
)
25+
26+
27+
# suppress linting warning by "using" the imports here:
28+
__all__ = [
29+
"_find_contraction",
30+
"_parse_einsum_input",
31+
"normalize_axis_index",
32+
"normalize_axis_tuple",
33+
]
34+
35+
36+
numpy_version_tuple = tuple(int(n) for n in np.__version__.split(".")[:2])
37+
numpy_version = np.lib.NumpyVersion(
38+
np.__version__
39+
) # used to compare with version strings, e.g. numpy_version < "1.16.0"
40+
using_numpy_2 = numpy_version >= "2.0.0rc1"
41+
42+
43+
if using_numpy_2:
44+
ndarray_c_version = np._core._multiarray_umath._get_ndarray_c_version()
45+
else:
46+
ndarray_c_version = np.core._multiarray_umath._get_ndarray_c_version() # type: ignore[attr-defined]
47+
48+
49+
if using_numpy_2:
50+
UintOverflowError = OverflowError
51+
else:
52+
UintOverflowError = TypeError
53+
54+
55+
def npy_2_compat_header() -> str:
56+
"""Compatibility header that Numpy suggests is vendored with code that uses Numpy < 2.0 and Numpy 2.x"""
57+
return dedent("""
58+
#ifndef NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPAT_H_
59+
#define NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPAT_H_
60+
61+
62+
/*
63+
* This header is meant to be included by downstream directly for 1.x compat.
64+
* In that case we need to ensure that users first included the full headers
65+
* and not just `ndarraytypes.h`.
66+
*/
67+
68+
#ifndef NPY_FEATURE_VERSION
69+
#error "The NumPy 2 compat header requires `import_array()` for which " \\
70+
"the `ndarraytypes.h` header include is not sufficient. Please " \\
71+
"include it after `numpy/ndarrayobject.h` or similar." \\
72+
"" \\
73+
"To simplify inclusion, you may use `PyArray_ImportNumPy()` " \\
74+
"which is defined in the compat header and is lightweight (can be)."
75+
#endif
76+
77+
#if NPY_ABI_VERSION < 0x02000000
78+
/*
79+
* Define 2.0 feature version as it is needed below to decide whether we
80+
* compile for both 1.x and 2.x (defining it gaurantees 1.x only).
81+
*/
82+
#define NPY_2_0_API_VERSION 0x00000012
83+
/*
84+
* If we are compiling with NumPy 1.x, PyArray_RUNTIME_VERSION so we
85+
* pretend the `PyArray_RUNTIME_VERSION` is `NPY_FEATURE_VERSION`.
86+
* This allows downstream to use `PyArray_RUNTIME_VERSION` if they need to.
87+
*/
88+
#define PyArray_RUNTIME_VERSION NPY_FEATURE_VERSION
89+
/* Compiling on NumPy 1.x where these are the same: */
90+
#define PyArray_DescrProto PyArray_Descr
91+
#endif
92+
93+
94+
/*
95+
* Define a better way to call `_import_array()` to simplify backporting as
96+
* we now require imports more often (necessary to make ABI flexible).
97+
*/
98+
#ifdef import_array1
99+
100+
static inline int
101+
PyArray_ImportNumPyAPI()
102+
{
103+
if (NPY_UNLIKELY(PyArray_API == NULL)) {
104+
import_array1(-1);
105+
}
106+
return 0;
107+
}
108+
109+
#endif /* import_array1 */
110+
111+
112+
/*
113+
* NPY_DEFAULT_INT
114+
*
115+
* The default integer has changed, `NPY_DEFAULT_INT` is available at runtime
116+
* for use as type number, e.g. `PyArray_DescrFromType(NPY_DEFAULT_INT)`.
117+
*
118+
* NPY_RAVEL_AXIS
119+
*
120+
* This was introduced in NumPy 2.0 to allow indicating that an axis should be
121+
* raveled in an operation. Before NumPy 2.0, NPY_MAXDIMS was used for this purpose.
122+
*
123+
* NPY_MAXDIMS
124+
*
125+
* A constant indicating the maximum number dimensions allowed when creating
126+
* an ndarray.
127+
*
128+
* NPY_NTYPES_LEGACY
129+
*
130+
* The number of built-in NumPy dtypes.
131+
*/
132+
#if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION
133+
#define NPY_DEFAULT_INT NPY_INTP
134+
#define NPY_RAVEL_AXIS NPY_MIN_INT
135+
#define NPY_MAXARGS 64
136+
137+
#elif NPY_ABI_VERSION < 0x02000000
138+
#define NPY_DEFAULT_INT NPY_LONG
139+
#define NPY_RAVEL_AXIS 32
140+
#define NPY_MAXARGS 32
141+
142+
/* Aliases of 2.x names to 1.x only equivalent names */
143+
#define NPY_NTYPES NPY_NTYPES_LEGACY
144+
#define PyArray_DescrProto PyArray_Descr
145+
#define _PyArray_LegacyDescr PyArray_Descr
146+
/* NumPy 2 definition always works, but add it for 1.x only */
147+
#define PyDataType_ISLEGACY(dtype) (1)
148+
#else
149+
#define NPY_DEFAULT_INT \\
150+
(PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION ? NPY_INTP : NPY_LONG)
151+
#define NPY_RAVEL_AXIS \\
152+
(PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION ? -1 : 32)
153+
#define NPY_MAXARGS \\
154+
(PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION ? 64 : 32)
155+
#endif
156+
157+
158+
/*
159+
* Access inline functions for descriptor fields. Except for the first
160+
* few fields, these needed to be moved (elsize, alignment) for
161+
* additional space. Or they are descriptor specific and are not generally
162+
* available anymore (metadata, c_metadata, subarray, names, fields).
163+
*
164+
* Most of these are defined via the `DESCR_ACCESSOR` macro helper.
165+
*/
166+
#if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION || NPY_ABI_VERSION < 0x02000000
167+
/* Compiling for 1.x or 2.x only, direct field access is OK: */
168+
169+
static inline void
170+
PyDataType_SET_ELSIZE(PyArray_Descr *dtype, npy_intp size)
171+
{
172+
dtype->elsize = size;
173+
}
174+
175+
static inline npy_uint64
176+
PyDataType_FLAGS(const PyArray_Descr *dtype)
177+
{
178+
#if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION
179+
return dtype->flags;
180+
#else
181+
return (unsigned char)dtype->flags; /* Need unsigned cast on 1.x */
182+
#endif
183+
}
184+
185+
#define DESCR_ACCESSOR(FIELD, field, type, legacy_only) \\
186+
static inline type \\
187+
PyDataType_##FIELD(const PyArray_Descr *dtype) { \\
188+
if (legacy_only && !PyDataType_ISLEGACY(dtype)) { \\
189+
return (type)0; \\
190+
} \\
191+
return ((_PyArray_LegacyDescr *)dtype)->field; \\
192+
}
193+
#else /* compiling for both 1.x and 2.x */
194+
195+
static inline void
196+
PyDataType_SET_ELSIZE(PyArray_Descr *dtype, npy_intp size)
197+
{
198+
if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) {
199+
((_PyArray_DescrNumPy2 *)dtype)->elsize = size;
200+
}
201+
else {
202+
((PyArray_DescrProto *)dtype)->elsize = (int)size;
203+
}
204+
}
205+
206+
static inline npy_uint64
207+
PyDataType_FLAGS(const PyArray_Descr *dtype)
208+
{
209+
if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) {
210+
return ((_PyArray_DescrNumPy2 *)dtype)->flags;
211+
}
212+
else {
213+
return (unsigned char)((PyArray_DescrProto *)dtype)->flags;
214+
}
215+
}
216+
217+
/* Cast to LegacyDescr always fine but needed when `legacy_only` */
218+
#define DESCR_ACCESSOR(FIELD, field, type, legacy_only) \\
219+
static inline type \\
220+
PyDataType_##FIELD(const PyArray_Descr *dtype) { \\
221+
if (legacy_only && !PyDataType_ISLEGACY(dtype)) { \\
222+
return (type)0; \\
223+
} \\
224+
if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) { \\
225+
return ((_PyArray_LegacyDescr *)dtype)->field; \\
226+
} \\
227+
else { \\
228+
return ((PyArray_DescrProto *)dtype)->field; \\
229+
} \\
230+
}
231+
#endif
232+
233+
DESCR_ACCESSOR(ELSIZE, elsize, npy_intp, 0)
234+
DESCR_ACCESSOR(ALIGNMENT, alignment, npy_intp, 0)
235+
DESCR_ACCESSOR(METADATA, metadata, PyObject *, 1)
236+
DESCR_ACCESSOR(SUBARRAY, subarray, PyArray_ArrayDescr *, 1)
237+
DESCR_ACCESSOR(NAMES, names, PyObject *, 1)
238+
DESCR_ACCESSOR(FIELDS, fields, PyObject *, 1)
239+
DESCR_ACCESSOR(C_METADATA, c_metadata, NpyAuxData *, 1)
240+
241+
#undef DESCR_ACCESSOR
242+
243+
244+
#if !(defined(NPY_INTERNAL_BUILD) && NPY_INTERNAL_BUILD)
245+
#if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION
246+
static inline PyArray_ArrFuncs *
247+
PyDataType_GetArrFuncs(const PyArray_Descr *descr)
248+
{
249+
return _PyDataType_GetArrFuncs(descr);
250+
}
251+
#elif NPY_ABI_VERSION < 0x02000000
252+
static inline PyArray_ArrFuncs *
253+
PyDataType_GetArrFuncs(const PyArray_Descr *descr)
254+
{
255+
return descr->f;
256+
}
257+
#else
258+
static inline PyArray_ArrFuncs *
259+
PyDataType_GetArrFuncs(const PyArray_Descr *descr)
260+
{
261+
if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) {
262+
return _PyDataType_GetArrFuncs(descr);
263+
}
264+
else {
265+
return ((PyArray_DescrProto *)descr)->f;
266+
}
267+
}
268+
#endif
269+
270+
271+
#endif /* not internal build */
272+
273+
#endif /* NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPAT_H_ */
274+
275+
""")

pytensor/tensor/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
123123

124124
# isort: on
125125
# Allow accessing numpy constants from pytensor.tensor
126-
from numpy import e, euler_gamma, inf, infty, nan, newaxis, pi
126+
from numpy import e, euler_gamma, inf, nan, newaxis, pi
127127

128128
from pytensor.tensor.basic import *
129129
from pytensor.tensor.blas import batched_dot, batched_tensordot

pytensor/tensor/basic.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
from typing import cast as type_cast
1515

1616
import numpy as np
17-
from numpy.core.multiarray import normalize_axis_index
18-
from numpy.core.numeric import normalize_axis_tuple
17+
from numpy.exceptions import AxisError
1918

2019
import pytensor
2120
import pytensor.scalar.sharedvar
@@ -32,6 +31,7 @@
3231
from pytensor.graph.type import HasShape, Type
3332
from pytensor.link.c.op import COp
3433
from pytensor.link.c.params_type import ParamsType
34+
from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
3535
from pytensor.printing import Printer, min_informative_str, pprint, set_precedence
3636
from pytensor.raise_op import CheckAndRaise
3737
from pytensor.scalar import int32
@@ -228,7 +228,7 @@ def constant(x, name=None, ndim=None, dtype=None) -> TensorConstant:
228228
elif x_.ndim > ndim:
229229
try:
230230
x_ = np.squeeze(x_, axis=tuple(range(x_.ndim - ndim)))
231-
except np.AxisError:
231+
except AxisError:
232232
raise ValueError(
233233
f"ndarray could not be cast to constant with {int(ndim)} dimensions"
234234
)
@@ -4405,7 +4405,7 @@ def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVa
44054405
axis = (axis,)
44064406

44074407
out_ndim = len(axis) + a.ndim
4408-
axis = np.core.numeric.normalize_axis_tuple(axis, out_ndim)
4408+
axis = normalize_axis_tuple(axis, out_ndim)
44094409

44104410
if not axis:
44114411
return a

pytensor/tensor/conv/abstract_conv.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from math import gcd
99

1010
import numpy as np
11+
from numpy.exceptions import ComplexWarning
1112

1213

1314
try:
@@ -2338,7 +2339,7 @@ def conv(
23382339
bval = _bvalfromboundary("fill")
23392340

23402341
with warnings.catch_warnings():
2341-
warnings.simplefilter("ignore", np.ComplexWarning)
2342+
warnings.simplefilter("ignore", ComplexWarning)
23422343
for b in range(img.shape[0]):
23432344
for g in range(self.num_groups):
23442345
for n in range(output_channel_offset):

0 commit comments

Comments
 (0)