Skip to content

Commit 3f99b9f

Browse files
Backwards compat for complex types C code
1 parent d0bbc9c commit 3f99b9f

File tree

1 file changed

+52
-8
lines changed

1 file changed

+52
-8
lines changed

pytensor/scalar/basic.py

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -522,10 +522,43 @@ def c_support_code(self, **kwargs):
522522
# In that case we add the 'int' type to the real types.
523523
real_types.append("int")
524524

525+
# Macros for backwards compatibility with numpy < 2.0
526+
#
527+
# In numpy 2.0+, these are defined in npy_math.h, but
528+
# for early versions, they must be vendored by users (e.g. PyTensor)
529+
backwards_compat_macros = """
530+
#ifndef NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPLEXCOMPAT_H_
531+
#define NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPLEXCOMPAT_H_
532+
533+
#include <numpy/npy_math.h>
534+
535+
#ifndef NPY_CSETREALF
536+
#define NPY_CSETREALF(c, r) (c)->real = (r)
537+
#endif
538+
#ifndef NPY_CSETIMAGF
539+
#define NPY_CSETIMAGF(c, i) (c)->imag = (i)
540+
#endif
541+
#ifndef NPY_CSETREAL
542+
#define NPY_CSETREAL(c, r) (c)->real = (r)
543+
#endif
544+
#ifndef NPY_CSETIMAG
545+
#define NPY_CSETIMAG(c, i) (c)->imag = (i)
546+
#endif
547+
#ifndef NPY_CSETREALL
548+
#define NPY_CSETREALL(c, r) (c)->real = (r)
549+
#endif
550+
#ifndef NPY_CSETIMAGL
551+
#define NPY_CSETIMAGL(c, i) (c)->imag = (i)
552+
#endif
553+
554+
#endif
555+
"""
556+
525557
def _make_get_set_real_imag(scalar_type: str) -> str:
526558
"""Make overloaded getter/setter functions for real/imag parts of numpy complex types.
527559
528-
The functions called by these getter/setter functions are defining in npy_math.h
560+
The functions called by these getter/setter functions are defining in npy_math.h, or
561+
in the `backward_compat_macros` defined above.
529562
530563
Args:
531564
scalar_type: float, double, or longdouble
@@ -536,11 +569,11 @@ def _make_get_set_real_imag(scalar_type: str) -> str:
536569
"""
537570
complex_type = "npy_c" + scalar_type
538571
suffix = "" if scalar_type == "double" else scalar_type[0]
539-
return_type = scalar_type
540572

541573
if scalar_type == "longdouble":
542-
scalar_type += "_t"
543-
return_type = "npy_" + return_type
574+
scalar_type = "npy_" + scalar_type
575+
576+
return_type = scalar_type
544577

545578
template = f"""
546579
static inline {return_type} get_real(const {complex_type} z)
@@ -550,7 +583,7 @@ def _make_get_set_real_imag(scalar_type: str) -> str:
550583
551584
static inline void set_real({complex_type} *z, const {scalar_type} r)
552585
{{
553-
npy_csetreal{suffix}(z, r);
586+
NPY_CSETREAL{suffix.upper()}(z, r);
554587
}}
555588
556589
static inline {return_type} get_imag(const {complex_type} z)
@@ -560,17 +593,28 @@ def _make_get_set_real_imag(scalar_type: str) -> str:
560593
561594
static inline void set_imag({complex_type} *z, const {scalar_type} i)
562595
{{
563-
npy_csetimag{suffix}(z, i);
596+
NPY_CSETIMAG{suffix.upper()}(z, i);
564597
}}
565598
"""
566599
return template
567600

568-
# TODO: add guard code to prevent this from being defining twice, in case we need to add it somewhere else
569601
get_set_aliases = "\n".join(
570602
_make_get_set_real_imag(stype)
571603
for stype in ["float", "double", "longdouble"]
572604
)
573605

606+
get_set_aliases = backwards_compat_macros + "\n" + get_set_aliases
607+
608+
# Template for defining pytensor_complex64 and pytensor_complex128 structs/classes
609+
#
610+
# The npy_complex64, npy_complex128 types are aliases defined at run time based on
611+
# the size of floats and doubles on the machine. This means that both types are
612+
# not necessarily defined on every machine, but a machine with 32-bit floats and
613+
# 64-bit doubles will have npy_complex64 as an alias of npy_cfloat and npy_complex128
614+
# as an alias of npy_complex128.
615+
#
616+
# In any case, the get/set real/imag functions defined above will always work for
617+
# npy_complex64 and npy_complex128.
574618
template = """
575619
struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s {
576620
typedef pytensor_complex%(nbits)s complex_type;
@@ -719,7 +763,7 @@ def c_init_code(self, **kwargs):
719763
return ["import_array();"]
720764

721765
def c_code_cache_version(self):
722-
return (15, np.version.git_revision)
766+
return (18, np.version.git_revision)
723767

724768
def get_shape_info(self, obj):
725769
return obj.itemsize

0 commit comments

Comments
 (0)