@@ -4,32 +4,34 @@ Template for each `dtype` helper function using 1-d template
4
4
WARNING: DO NOT edit .pxi FILE directly, .pxi is generated from .pxi.in
5
5
"""
6
6
7
- {{py:
8
-
9
- # name, c_type, dest_type
10
- dtypes = [('float64', 'float64_t', 'float64_t'),
11
- ('float32', 'float32_t', 'float32_t'),
12
- ('int8', 'int8_t', 'float32_t'),
13
- ('int16', 'int16_t', 'float32_t'),
14
- ('int32', 'int32_t', 'float64_t'),
15
- ('int64', 'int64_t', 'float64_t')]
16
-
17
- def get_dispatch(dtypes):
18
-
19
- for name, c_type, dest_type, in dtypes:
20
- yield name, c_type, dest_type
7
+ ctypedef fused diff_t:
8
+ float64_t
9
+ float32_t
10
+ int8_t
11
+ int16_t
12
+ int32_t
13
+ int64_t
21
14
22
- }}
23
-
24
- {{for name, c_type, dest_type
25
- in get_dispatch(dtypes)}}
15
+ ctypedef fused out_t:
16
+ float32_t
17
+ float64_t
26
18
27
19
28
20
@cython.boundscheck(False)
29
21
@cython.wraparound(False)
30
- def diff_2d_{{name}}(ndarray[{{c_type}}, ndim=2] arr,
31
- ndarray[{{dest_type}}, ndim=2] out,
32
- Py_ssize_t periods, int axis):
22
+ def diff_2d(ndarray[diff_t, ndim=2] arr,
23
+ ndarray[out_t, ndim=2] out,
24
+ Py_ssize_t periods, int axis):
25
+
26
+ # Disable for unsupported dtype combinations,
27
+ # see https://github.com/cython/cython/issues/2646
28
+ if out_t is float32_t:
29
+ if not (diff_t is float32_t or diff_t is int8_t or diff_t is int16_t):
30
+ raise NotImplementedError
31
+ else:
32
+ if (diff_t is float32_t or diff_t is int8_t or diff_t is int16_t):
33
+ raise NotImplementedError
34
+
33
35
cdef:
34
36
Py_ssize_t i, j, sx, sy
35
37
@@ -69,7 +71,6 @@ def diff_2d_{{name}}(ndarray[{{c_type}}, ndim=2] arr,
69
71
for j in range(start, stop):
70
72
out[i, j] = arr[i, j] - arr[i, j - periods]
71
73
72
- {{endfor}}
73
74
74
75
# ----------------------------------------------------------------------
75
76
# ensure_dtype
0 commit comments