Skip to content

Commit e852b36

Browse files
committed
PERF: Limit memmove to >= 256 bytes, relax contiguity requirements (only the stride in the dimension of the copy matters)
1 parent 4d3e34e commit e852b36

File tree

2 files changed

+643
-463
lines changed

2 files changed

+643
-463
lines changed

pandas/src/generate_code.py

+35-26
Original file line number | Diff line number | Diff line change
@@ -56,10 +56,10 @@
5656
take_1d_template = """@cython.wraparound(False)
5757
def take_1d_%(name)s_%(dest)s(ndarray[%(c_type_in)s] values,
5858
ndarray[int64_t] indexer,
59-
out, fill_value=np.nan):
59+
ndarray[%(c_type_out)s] out,
60+
fill_value=np.nan):
6061
cdef:
6162
Py_ssize_t i, n, idx
62-
ndarray[%(c_type_out)s] outbuf = out
6363
%(c_type_out)s fv
6464
6565
n = len(indexer)
@@ -68,20 +68,20 @@ def take_1d_%(name)s_%(dest)s(ndarray[%(c_type_in)s] values,
6868
for i from 0 <= i < n:
6969
idx = indexer[i]
7070
if idx == -1:
71-
outbuf[i] = fv
71+
out[i] = fv
7272
else:
73-
outbuf[i] = %(preval)svalues[idx]%(postval)s
73+
out[i] = %(preval)svalues[idx]%(postval)s
7474
7575
"""
7676

7777
take_2d_axis0_template = """@cython.wraparound(False)
7878
@cython.boundscheck(False)
7979
def take_2d_axis0_%(name)s_%(dest)s(ndarray[%(c_type_in)s, ndim=2] values,
8080
ndarray[int64_t] indexer,
81-
out, fill_value=np.nan):
81+
ndarray[%(c_type_out)s, ndim=2] out,
82+
fill_value=np.nan):
8283
cdef:
8384
Py_ssize_t i, j, k, n, idx
84-
ndarray[%(c_type_out)s, ndim=2] outbuf = out
8585
%(c_type_out)s fv
8686
8787
n = len(indexer)
@@ -93,81 +93,90 @@ def take_2d_axis0_%(name)s_%(dest)s(ndarray[%(c_type_in)s, ndim=2] values,
9393
cdef:
9494
%(c_type_out)s *v, *o
9595
96-
if values.flags.c_contiguous and out.flags.c_contiguous:
96+
if (values.strides[1] == out.strides[1] and
97+
values.strides[1] == sizeof(%(c_type_out)s) and
98+
sizeof(%(c_type_out)s) * n >= 256):
99+
97100
for i from 0 <= i < n:
98101
idx = indexer[i]
99102
if idx == -1:
100103
for j from 0 <= j < k:
101-
outbuf[i, j] = fv
104+
out[i, j] = fv
102105
else:
103106
v = &values[idx, 0]
104-
o = &outbuf[i, 0]
107+
o = &out[i, 0]
105108
memmove(o, v, <size_t>(sizeof(%(c_type_out)s) * k))
106109
return
107110
108111
for i from 0 <= i < n:
109112
idx = indexer[i]
110113
if idx == -1:
111114
for j from 0 <= j < k:
112-
outbuf[i, j] = fv
115+
out[i, j] = fv
113116
else:
114117
for j from 0 <= j < k:
115-
outbuf[i, j] = %(preval)svalues[idx, j]%(postval)s
118+
out[i, j] = %(preval)svalues[idx, j]%(postval)s
116119
117120
"""
118121

119122
take_2d_axis1_template = """@cython.wraparound(False)
120123
@cython.boundscheck(False)
121124
def take_2d_axis1_%(name)s_%(dest)s(ndarray[%(c_type_in)s, ndim=2] values,
122125
ndarray[int64_t] indexer,
123-
out, fill_value=np.nan):
126+
ndarray[%(c_type_out)s, ndim=2] out,
127+
fill_value=np.nan):
124128
cdef:
125129
Py_ssize_t i, j, k, n, idx
126-
ndarray[%(c_type_out)s, ndim=2] outbuf = out
127130
%(c_type_out)s fv
128131
129132
n = len(values)
130133
k = len(indexer)
131-
134+
135+
if n == 0 or k == 0:
136+
return
137+
132138
fv = fill_value
133139
134140
IF %(can_copy)s:
135141
cdef:
136142
%(c_type_out)s *v, *o
137143
138-
if values.flags.f_contiguous and out.flags.f_contiguous:
144+
if (values.strides[0] == out.strides[0] and
145+
values.strides[0] == sizeof(%(c_type_out)s) and
146+
sizeof(%(c_type_out)s) * n >= 256):
147+
139148
for j from 0 <= j < k:
140149
idx = indexer[j]
141150
if idx == -1:
142151
for i from 0 <= i < n:
143-
outbuf[i, j] = fv
152+
out[i, j] = fv
144153
else:
145154
v = &values[0, idx]
146-
o = &outbuf[0, j]
155+
o = &out[0, j]
147156
memmove(o, v, <size_t>(sizeof(%(c_type_out)s) * n))
148157
return
149158
150159
for j from 0 <= j < k:
151160
idx = indexer[j]
152161
if idx == -1:
153162
for i from 0 <= i < n:
154-
outbuf[i, j] = fv
163+
out[i, j] = fv
155164
else:
156165
for i from 0 <= i < n:
157-
outbuf[i, j] = %(preval)svalues[i, idx]%(postval)s
166+
out[i, j] = %(preval)svalues[i, idx]%(postval)s
158167
159168
"""
160169

161170
take_2d_multi_template = """@cython.wraparound(False)
162171
@cython.boundscheck(False)
163172
def take_2d_multi_%(name)s_%(dest)s(ndarray[%(c_type_in)s, ndim=2] values,
164173
indexer,
165-
out, fill_value=np.nan):
174+
ndarray[%(c_type_out)s, ndim=2] out,
175+
fill_value=np.nan):
166176
cdef:
167177
Py_ssize_t i, j, k, n, idx
168178
ndarray[int64_t] idx0 = indexer[0]
169179
ndarray[int64_t] idx1 = indexer[1]
170-
ndarray[%(c_type_out)s, ndim=2] outbuf = out
171180
%(c_type_out)s fv
172181
173182
n = len(idx0)
@@ -178,13 +187,13 @@ def take_2d_multi_%(name)s_%(dest)s(ndarray[%(c_type_in)s, ndim=2] values,
178187
idx = idx0[i]
179188
if idx == -1:
180189
for j from 0 <= j < k:
181-
outbuf[i, j] = fv
190+
out[i, j] = fv
182191
else:
183192
for j from 0 <= j < k:
184193
if idx1[j] == -1:
185-
outbuf[i, j] = fv
194+
out[i, j] = fv
186195
else:
187-
outbuf[i, j] = %(preval)svalues[idx, idx1[j]]%(postval)s
196+
out[i, j] = %(preval)svalues[idx, idx1[j]]%(postval)s
188197
189198
"""
190199

@@ -2169,7 +2178,7 @@ def generate_put_template(template, use_ints = True, use_floats = True):
21692178

21702179
output = StringIO()
21712180
for name, c_type, dest_type, dest_dtype in function_list:
2172-
func = template % {'name' : name,
2181+
func = template % {'name' : name,
21732182
'c_type' : c_type,
21742183
'dest_type' : dest_type.replace('_t', ''),
21752184
'dest_type2' : dest_type,
@@ -2203,7 +2212,7 @@ def generate_take_template(template, exclude=None):
22032212
]
22042213

22052214
output = StringIO()
2206-
for (name, dest, c_type_in, c_type_out,
2215+
for (name, dest, c_type_in, c_type_out,
22072216
preval, postval, can_copy) in function_list:
22082217
if exclude is not None and name in exclude:
22092218
continue

0 commit comments

Comments
 (0)