5
5
cimport numpy as np
6
6
cimport cython
7
7
8
+ from libc.string cimport memmove
9
+
8
10
from numpy cimport *
9
11
10
12
from cpython cimport (PyDict_New, PyDict_GetItem, PyDict_SetItem,
@@ -86,6 +88,23 @@ def take_2d_axis0_%(name)s_%(dest)s(ndarray[%(c_type_in)s, ndim=2] values,
86
88
k = values.shape[1]
87
89
88
90
fv = fill_value
91
+
92
+ IF %(can_copy)s:
93
+ cdef:
94
+ %(c_type_out)s *v, *o
95
+
96
+ if values.flags.c_contiguous and out.flags.c_contiguous:
97
+ for i from 0 <= i < n:
98
+ idx = indexer[i]
99
+ if idx == -1:
100
+ for j from 0 <= j < k:
101
+ outbuf[i, j] = fv
102
+ else:
103
+ v = &values[idx, 0]
104
+ o = &outbuf[i, 0]
105
+ memmove(o, v, <size_t>(sizeof(%(c_type_out)s) * k))
106
+ return
107
+
89
108
for i from 0 <= i < n:
90
109
idx = indexer[i]
91
110
if idx == -1:
@@ -109,8 +128,25 @@ def take_2d_axis1_%(name)s_%(dest)s(ndarray[%(c_type_in)s, ndim=2] values,
109
128
110
129
n = len(values)
111
130
k = len(indexer)
112
-
131
+
113
132
fv = fill_value
133
+
134
+ IF %(can_copy)s:
135
+ cdef:
136
+ %(c_type_out)s *v, *o
137
+
138
+ if values.flags.f_contiguous and out.flags.f_contiguous:
139
+ for j from 0 <= j < k:
140
+ idx = indexer[j]
141
+ if idx == -1:
142
+ for i from 0 <= i < n:
143
+ outbuf[i, j] = fv
144
+ else:
145
+ v = &values[0, idx]
146
+ o = &outbuf[0, j]
147
+ memmove(o, v, <size_t>(sizeof(%(c_type_out)s) * n))
148
+ return
149
+
114
150
for j from 0 <= j < k:
115
151
idx = indexer[j]
116
152
if idx == -1:
@@ -2115,39 +2151,40 @@ def generate_put_template(template, use_ints = True, use_floats = True):
2115
2151
return output .getvalue ()
2116
2152
2117
2153
def generate_take_template (template , exclude = None ):
2118
- # name, dest, ctypein, ctypeout, preval, postval
2154
+ # name, dest, ctypein, ctypeout, preval, postval, cancopy
2119
2155
function_list = [
2120
- ('bool' , 'bool' , 'uint8_t' , 'uint8_t' , '' , '' ),
2156
+ ('bool' , 'bool' , 'uint8_t' , 'uint8_t' , '' , '' , True ),
2121
2157
('bool' , 'object' , 'uint8_t' , 'object' ,
2122
- 'True if ' , ' > 0 else False' ),
2123
- ('int8' , 'int8' , 'int8_t' , 'int8_t' , '' , '' ),
2124
- ('int8' , 'int32' , 'int8_t' , 'int32_t' , '' , '' ),
2125
- ('int8' , 'int64' , 'int8_t' , 'int64_t' , '' , '' ),
2126
- ('int8' , 'float64' , 'int8_t' , 'float64_t' , '' , '' ),
2127
- ('int16' , 'int16' , 'int16_t' , 'int16_t' , '' , '' ),
2128
- ('int16' , 'int32' , 'int16_t' , 'int32_t' , '' , '' ),
2129
- ('int16' , 'int64' , 'int16_t' , 'int64_t' , '' , '' ),
2130
- ('int16' , 'float64' , 'int16_t' , 'float64_t' , '' , '' ),
2131
- ('int32' , 'int32' , 'int32_t' , 'int32_t' , '' , '' ),
2132
- ('int32' , 'int64' , 'int32_t' , 'int64_t' , '' , '' ),
2133
- ('int32' , 'float64' , 'int32_t' , 'float64_t' , '' , '' ),
2134
- ('int64' , 'int64' , 'int64_t' , 'int64_t' , '' , '' ),
2135
- ('int64' , 'float64' , 'int64_t' , 'float64_t' , '' , '' ),
2136
- ('float32' , 'float32' , 'float32_t' , 'float32_t' , '' , '' ),
2137
- ('float32' , 'float64' , 'float32_t' , 'float64_t' , '' , '' ),
2138
- ('float64' , 'float64' , 'float64_t' , 'float64_t' , '' , '' ),
2139
- ('object' , 'object' , 'object' , 'object' , '' , '' )
2158
+ 'True if ' , ' > 0 else False' , False ),
2159
+ ('int8' , 'int8' , 'int8_t' , 'int8_t' , '' , '' , True ),
2160
+ ('int8' , 'int32' , 'int8_t' , 'int32_t' , '' , '' , False ),
2161
+ ('int8' , 'int64' , 'int8_t' , 'int64_t' , '' , '' , False ),
2162
+ ('int8' , 'float64' , 'int8_t' , 'float64_t' , '' , '' , False ),
2163
+ ('int16' , 'int16' , 'int16_t' , 'int16_t' , '' , '' , True ),
2164
+ ('int16' , 'int32' , 'int16_t' , 'int32_t' , '' , '' , False ),
2165
+ ('int16' , 'int64' , 'int16_t' , 'int64_t' , '' , '' , False ),
2166
+ ('int16' , 'float64' , 'int16_t' , 'float64_t' , '' , '' , False ),
2167
+ ('int32' , 'int32' , 'int32_t' , 'int32_t' , '' , '' , True ),
2168
+ ('int32' , 'int64' , 'int32_t' , 'int64_t' , '' , '' , False ),
2169
+ ('int32' , 'float64' , 'int32_t' , 'float64_t' , '' , '' , False ),
2170
+ ('int64' , 'int64' , 'int64_t' , 'int64_t' , '' , '' , True ),
2171
+ ('int64' , 'float64' , 'int64_t' , 'float64_t' , '' , '' , False ),
2172
+ ('float32' , 'float32' , 'float32_t' , 'float32_t' , '' , '' , True ),
2173
+ ('float32' , 'float64' , 'float32_t' , 'float64_t' , '' , '' , False ),
2174
+ ('float64' , 'float64' , 'float64_t' , 'float64_t' , '' , '' , True ),
2175
+ ('object' , 'object' , 'object' , 'object' , '' , '' , False )
2140
2176
]
2141
2177
2142
2178
output = StringIO ()
2143
2179
for (name , dest , c_type_in , c_type_out ,
2144
- preval , postval ) in function_list :
2180
+ preval , postval , can_copy ) in function_list :
2145
2181
if exclude is not None and name in exclude :
2146
2182
continue
2147
2183
2148
2184
func = template % {'name' : name , 'dest' : dest ,
2149
2185
'c_type_in' : c_type_in , 'c_type_out' : c_type_out ,
2150
- 'preval' : preval , 'postval' : postval }
2186
+ 'preval' : preval , 'postval' : postval ,
2187
+ 'can_copy' : 'True' if can_copy else 'False' }
2151
2188
output .write (func )
2152
2189
return output .getvalue ()
2153
2190
0 commit comments