56
56
take_1d_template = """@cython.wraparound(False)
57
57
def take_1d_%(name)s_%(dest)s(ndarray[%(c_type_in)s] values,
58
58
ndarray[int64_t] indexer,
59
- out, fill_value=np.nan):
59
+ ndarray[%(c_type_out)s] out,
60
+ fill_value=np.nan):
60
61
cdef:
61
62
Py_ssize_t i, n, idx
62
- ndarray[%(c_type_out)s] outbuf = out
63
63
%(c_type_out)s fv
64
64
65
65
n = len(indexer)
@@ -68,20 +68,20 @@ def take_1d_%(name)s_%(dest)s(ndarray[%(c_type_in)s] values,
68
68
for i from 0 <= i < n:
69
69
idx = indexer[i]
70
70
if idx == -1:
71
- outbuf [i] = fv
71
+ out [i] = fv
72
72
else:
73
- outbuf [i] = %(preval)svalues[idx]%(postval)s
73
+ out [i] = %(preval)svalues[idx]%(postval)s
74
74
75
75
"""
76
76
77
77
take_2d_axis0_template = """@cython.wraparound(False)
78
78
@cython.boundscheck(False)
79
79
def take_2d_axis0_%(name)s_%(dest)s(ndarray[%(c_type_in)s, ndim=2] values,
80
80
ndarray[int64_t] indexer,
81
- out, fill_value=np.nan):
81
+ ndarray[%(c_type_out)s, ndim=2] out,
82
+ fill_value=np.nan):
82
83
cdef:
83
84
Py_ssize_t i, j, k, n, idx
84
- ndarray[%(c_type_out)s, ndim=2] outbuf = out
85
85
%(c_type_out)s fv
86
86
87
87
n = len(indexer)
@@ -93,81 +93,90 @@ def take_2d_axis0_%(name)s_%(dest)s(ndarray[%(c_type_in)s, ndim=2] values,
93
93
cdef:
94
94
%(c_type_out)s *v, *o
95
95
96
- if values.flags.c_contiguous and out.flags.c_contiguous:
96
+ if (values.strides[1] == out.strides[1] and
97
+ values.strides[1] == sizeof(%(c_type_out)s) and
98
+ sizeof(%(c_type_out)s) * n >= 256):
99
+
97
100
for i from 0 <= i < n:
98
101
idx = indexer[i]
99
102
if idx == -1:
100
103
for j from 0 <= j < k:
101
- outbuf [i, j] = fv
104
+ out [i, j] = fv
102
105
else:
103
106
v = &values[idx, 0]
104
- o = &outbuf [i, 0]
107
+ o = &out [i, 0]
105
108
memmove(o, v, <size_t>(sizeof(%(c_type_out)s) * k))
106
109
return
107
110
108
111
for i from 0 <= i < n:
109
112
idx = indexer[i]
110
113
if idx == -1:
111
114
for j from 0 <= j < k:
112
- outbuf [i, j] = fv
115
+ out [i, j] = fv
113
116
else:
114
117
for j from 0 <= j < k:
115
- outbuf [i, j] = %(preval)svalues[idx, j]%(postval)s
118
+ out [i, j] = %(preval)svalues[idx, j]%(postval)s
116
119
117
120
"""
118
121
119
122
take_2d_axis1_template = """@cython.wraparound(False)
120
123
@cython.boundscheck(False)
121
124
def take_2d_axis1_%(name)s_%(dest)s(ndarray[%(c_type_in)s, ndim=2] values,
122
125
ndarray[int64_t] indexer,
123
- out, fill_value=np.nan):
126
+ ndarray[%(c_type_out)s, ndim=2] out,
127
+ fill_value=np.nan):
124
128
cdef:
125
129
Py_ssize_t i, j, k, n, idx
126
- ndarray[%(c_type_out)s, ndim=2] outbuf = out
127
130
%(c_type_out)s fv
128
131
129
132
n = len(values)
130
133
k = len(indexer)
131
-
134
+
135
+ if n == 0 or k == 0:
136
+ return
137
+
132
138
fv = fill_value
133
139
134
140
IF %(can_copy)s:
135
141
cdef:
136
142
%(c_type_out)s *v, *o
137
143
138
- if values.flags.f_contiguous and out.flags.f_contiguous:
144
+ if (values.strides[0] == out.strides[0] and
145
+ values.strides[0] == sizeof(%(c_type_out)s) and
146
+ sizeof(%(c_type_out)s) * n >= 256):
147
+
139
148
for j from 0 <= j < k:
140
149
idx = indexer[j]
141
150
if idx == -1:
142
151
for i from 0 <= i < n:
143
- outbuf [i, j] = fv
152
+ out [i, j] = fv
144
153
else:
145
154
v = &values[0, idx]
146
- o = &outbuf [0, j]
155
+ o = &out [0, j]
147
156
memmove(o, v, <size_t>(sizeof(%(c_type_out)s) * n))
148
157
return
149
158
150
159
for j from 0 <= j < k:
151
160
idx = indexer[j]
152
161
if idx == -1:
153
162
for i from 0 <= i < n:
154
- outbuf [i, j] = fv
163
+ out [i, j] = fv
155
164
else:
156
165
for i from 0 <= i < n:
157
- outbuf [i, j] = %(preval)svalues[i, idx]%(postval)s
166
+ out [i, j] = %(preval)svalues[i, idx]%(postval)s
158
167
159
168
"""
160
169
161
170
take_2d_multi_template = """@cython.wraparound(False)
162
171
@cython.boundscheck(False)
163
172
def take_2d_multi_%(name)s_%(dest)s(ndarray[%(c_type_in)s, ndim=2] values,
164
173
indexer,
165
- out, fill_value=np.nan):
174
+ ndarray[%(c_type_out)s, ndim=2] out,
175
+ fill_value=np.nan):
166
176
cdef:
167
177
Py_ssize_t i, j, k, n, idx
168
178
ndarray[int64_t] idx0 = indexer[0]
169
179
ndarray[int64_t] idx1 = indexer[1]
170
- ndarray[%(c_type_out)s, ndim=2] outbuf = out
171
180
%(c_type_out)s fv
172
181
173
182
n = len(idx0)
@@ -178,13 +187,13 @@ def take_2d_multi_%(name)s_%(dest)s(ndarray[%(c_type_in)s, ndim=2] values,
178
187
idx = idx0[i]
179
188
if idx == -1:
180
189
for j from 0 <= j < k:
181
- outbuf [i, j] = fv
190
+ out [i, j] = fv
182
191
else:
183
192
for j from 0 <= j < k:
184
193
if idx1[j] == -1:
185
- outbuf [i, j] = fv
194
+ out [i, j] = fv
186
195
else:
187
- outbuf [i, j] = %(preval)svalues[idx, idx1[j]]%(postval)s
196
+ out [i, j] = %(preval)svalues[idx, idx1[j]]%(postval)s
188
197
189
198
"""
190
199
@@ -2169,7 +2178,7 @@ def generate_put_template(template, use_ints = True, use_floats = True):
2169
2178
2170
2179
output = StringIO ()
2171
2180
for name , c_type , dest_type , dest_dtype in function_list :
2172
- func = template % {'name' : name ,
2181
+ func = template % {'name' : name ,
2173
2182
'c_type' : c_type ,
2174
2183
'dest_type' : dest_type .replace ('_t' , '' ),
2175
2184
'dest_type2' : dest_type ,
@@ -2203,7 +2212,7 @@ def generate_take_template(template, exclude=None):
2203
2212
]
2204
2213
2205
2214
output = StringIO ()
2206
- for (name , dest , c_type_in , c_type_out ,
2215
+ for (name , dest , c_type_in , c_type_out ,
2207
2216
preval , postval , can_copy ) in function_list :
2208
2217
if exclude is not None and name in exclude :
2209
2218
continue
0 commit comments