@@ -522,10 +522,43 @@ def c_support_code(self, **kwargs):
522
522
# In that case we add the 'int' type to the real types.
523
523
real_types .append ("int" )
524
524
525
+ # Macros for backwards compatibility with numpy < 2.0
526
+ #
527
+ # In numpy 2.0+, these are defined in npy_math.h, but
528
+ # for early versions, they must be vendored by users (e.g. PyTensor)
529
+ backwards_compat_macros = """
530
+ #ifndef NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPLEXCOMPAT_H_
531
+ #define NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPLEXCOMPAT_H_
532
+
533
+ #include <numpy/npy_math.h>
534
+
535
+ #ifndef NPY_CSETREALF
536
+ #define NPY_CSETREALF(c, r) (c)->real = (r)
537
+ #endif
538
+ #ifndef NPY_CSETIMAGF
539
+ #define NPY_CSETIMAGF(c, i) (c)->imag = (i)
540
+ #endif
541
+ #ifndef NPY_CSETREAL
542
+ #define NPY_CSETREAL(c, r) (c)->real = (r)
543
+ #endif
544
+ #ifndef NPY_CSETIMAG
545
+ #define NPY_CSETIMAG(c, i) (c)->imag = (i)
546
+ #endif
547
+ #ifndef NPY_CSETREALL
548
+ #define NPY_CSETREALL(c, r) (c)->real = (r)
549
+ #endif
550
+ #ifndef NPY_CSETIMAGL
551
+ #define NPY_CSETIMAGL(c, i) (c)->imag = (i)
552
+ #endif
553
+
554
+ #endif
555
+ """
556
+
525
557
def _make_get_set_real_imag (scalar_type : str ) -> str :
526
558
"""Make overloaded getter/setter functions for real/imag parts of numpy complex types.
527
559
528
- The functions called by these getter/setter functions are defining in npy_math.h
560
+ The functions called by these getter/setter functions are defining in npy_math.h, or
561
+ in the `backward_compat_macros` defined above.
529
562
530
563
Args:
531
564
scalar_type: float, double, or longdouble
@@ -536,11 +569,11 @@ def _make_get_set_real_imag(scalar_type: str) -> str:
536
569
"""
537
570
complex_type = "npy_c" + scalar_type
538
571
suffix = "" if scalar_type == "double" else scalar_type [0 ]
539
- return_type = scalar_type
540
572
541
573
if scalar_type == "longdouble" :
542
- scalar_type += "_t"
543
- return_type = "npy_" + return_type
574
+ scalar_type = "npy_" + scalar_type
575
+
576
+ return_type = scalar_type
544
577
545
578
template = f"""
546
579
static inline { return_type } get_real(const { complex_type } z)
@@ -550,7 +583,7 @@ def _make_get_set_real_imag(scalar_type: str) -> str:
550
583
551
584
static inline void set_real({ complex_type } *z, const { scalar_type } r)
552
585
{{
553
- npy_csetreal { suffix } (z, r);
586
+ NPY_CSETREAL { suffix . upper () } (z, r);
554
587
}}
555
588
556
589
static inline { return_type } get_imag(const { complex_type } z)
@@ -560,17 +593,28 @@ def _make_get_set_real_imag(scalar_type: str) -> str:
560
593
561
594
static inline void set_imag({ complex_type } *z, const { scalar_type } i)
562
595
{{
563
- npy_csetimag { suffix } (z, i);
596
+ NPY_CSETIMAG { suffix . upper () } (z, i);
564
597
}}
565
598
"""
566
599
return template
567
600
568
- # TODO: add guard code to prevent this from being defining twice, in case we need to add it somewhere else
569
601
get_set_aliases = "\n " .join (
570
602
_make_get_set_real_imag (stype )
571
603
for stype in ["float" , "double" , "longdouble" ]
572
604
)
573
605
606
+ get_set_aliases = backwards_compat_macros + "\n " + get_set_aliases
607
+
608
+ # Template for defining pytensor_complex64 and pytensor_complex128 structs/classes
609
+ #
610
+ # The npy_complex64, npy_complex128 types are aliases defined at run time based on
611
+ # the size of floats and doubles on the machine. This means that both types are
612
+ # not necessarily defined on every machine, but a machine with 32-bit floats and
613
+ # 64-bit doubles will have npy_complex64 as an alias of npy_cfloat and npy_complex128
614
+ # as an alias of npy_complex128.
615
+ #
616
+ # In any case, the get/set real/imag functions defined above will always work for
617
+ # npy_complex64 and npy_complex128.
574
618
template = """
575
619
struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s {
576
620
typedef pytensor_complex%(nbits)s complex_type;
@@ -719,7 +763,7 @@ def c_init_code(self, **kwargs):
719
763
return ["import_array();" ]
720
764
721
765
def c_code_cache_version (self ):
722
- return (15 , np .version .git_revision )
766
+ return (18 , np .version .git_revision )
723
767
724
768
def get_shape_info (self , obj ):
725
769
return obj .itemsize
0 commit comments