Skip to content

Commit 34e3ab6

Browse files
jbrockmendel authored and AlexKirko committed
REF: algos_take_helper de-nest templating (pandas-dev#30413)
1 parent 2b453f5 commit 34e3ab6

File tree

1 file changed

+117
-170
lines changed

1 file changed

+117
-170
lines changed

pandas/_libs/algos_take_helper.pxi.in

+117 -170
Original file line number | Diff line number | Diff line change
@@ -10,69 +10,119 @@ WARNING: DO NOT edit .pxi FILE directly, .pxi is generated from .pxi.in
1010

1111
{{py:
1212

13-
# c_type_in, c_type_out, preval, postval
13+
# c_type_in, c_type_out
1414
dtypes = [
15-
('uint8_t', 'uint8_t', '', ''),
16-
('uint8_t', 'object', 'True if ', ' > 0 else False'),
17-
('int8_t', 'int8_t', '', ''),
18-
('int8_t', 'int32_t', '', ''),
19-
('int8_t', 'int64_t', '', ''),
20-
('int8_t', 'float64_t', '', ''),
21-
('int16_t', 'int16_t', '', ''),
22-
('int16_t', 'int32_t', '', ''),
23-
('int16_t', 'int64_t', '', ''),
24-
('int16_t', 'float64_t', '', ''),
25-
('int32_t', 'int32_t', '', ''),
26-
('int32_t', 'int64_t', '', ''),
27-
('int32_t', 'float64_t', '', ''),
28-
('int64_t', 'int64_t', '', ''),
29-
('int64_t', 'float64_t', '', ''),
30-
('float32_t', 'float32_t', '', ''),
31-
('float32_t', 'float64_t', '', ''),
32-
('float64_t', 'float64_t', '', ''),
33-
('object', 'object', '', ''),
15+
('uint8_t', 'uint8_t'),
16+
('uint8_t', 'object'),
17+
('int8_t', 'int8_t'),
18+
('int8_t', 'int32_t'),
19+
('int8_t', 'int64_t'),
20+
('int8_t', 'float64_t'),
21+
('int16_t', 'int16_t'),
22+
('int16_t', 'int32_t'),
23+
('int16_t', 'int64_t'),
24+
('int16_t', 'float64_t'),
25+
('int32_t', 'int32_t'),
26+
('int32_t', 'int64_t'),
27+
('int32_t', 'float64_t'),
28+
('int64_t', 'int64_t'),
29+
('int64_t', 'float64_t'),
30+
('float32_t', 'float32_t'),
31+
('float32_t', 'float64_t'),
32+
('float64_t', 'float64_t'),
33+
('object', 'object'),
3434
]
3535

3636

3737
def get_dispatch(dtypes):
3838

39-
inner_take_1d_template = """
39+
for (c_type_in, c_type_out) in dtypes:
40+
41+
def get_name(dtype_name):
42+
if dtype_name == "object":
43+
return "object"
44+
if dtype_name == "uint8_t":
45+
return "bool"
46+
return dtype_name[:-2]
47+
48+
name = get_name(c_type_in)
49+
dest = get_name(c_type_out)
50+
51+
args = dict(name=name, dest=dest, c_type_in=c_type_in,
52+
c_type_out=c_type_out)
53+
54+
yield (name, dest, c_type_in, c_type_out)
55+
56+
}}
57+
58+
59+
{{for name, dest, c_type_in, c_type_out in get_dispatch(dtypes)}}
60+
61+
62+
@cython.wraparound(False)
63+
@cython.boundscheck(False)
64+
{{if c_type_in != "object"}}
65+
def take_1d_{{name}}_{{dest}}(const {{c_type_in}}[:] values,
66+
{{else}}
67+
def take_1d_{{name}}_{{dest}}(ndarray[{{c_type_in}}, ndim=1] values,
68+
{{endif}}
69+
const int64_t[:] indexer,
70+
{{c_type_out}}[:] out,
71+
fill_value=np.nan):
72+
4073
cdef:
4174
Py_ssize_t i, n, idx
42-
%(c_type_out)s fv
75+
{{c_type_out}} fv
4376

4477
n = indexer.shape[0]
4578

4679
fv = fill_value
4780

48-
%(nogil_str)s
49-
%(tab)sfor i in range(n):
50-
%(tab)s idx = indexer[i]
51-
%(tab)s if idx == -1:
52-
%(tab)s out[i] = fv
53-
%(tab)s else:
54-
%(tab)s out[i] = %(preval)svalues[idx]%(postval)s
55-
"""
81+
{{if c_type_out != "object"}}
82+
with nogil:
83+
{{else}}
84+
if True:
85+
{{endif}}
86+
for i in range(n):
87+
idx = indexer[i]
88+
if idx == -1:
89+
out[i] = fv
90+
else:
91+
{{if c_type_in == "uint8_t" and c_type_out == "object"}}
92+
out[i] = True if values[idx] > 0 else False
93+
{{else}}
94+
out[i] = values[idx]
95+
{{endif}}
96+
5697

57-
inner_take_2d_axis0_template = """\
98+
@cython.wraparound(False)
99+
@cython.boundscheck(False)
100+
{{if c_type_in != "object"}}
101+
def take_2d_axis0_{{name}}_{{dest}}(const {{c_type_in}}[:, :] values,
102+
{{else}}
103+
def take_2d_axis0_{{name}}_{{dest}}(ndarray[{{c_type_in}}, ndim=2] values,
104+
{{endif}}
105+
ndarray[int64_t] indexer,
106+
{{c_type_out}}[:, :] out,
107+
fill_value=np.nan):
58108
cdef:
59109
Py_ssize_t i, j, k, n, idx
60-
%(c_type_out)s fv
110+
{{c_type_out}} fv
61111

62112
n = len(indexer)
63113
k = values.shape[1]
64114

65115
fv = fill_value
66116

67-
IF %(can_copy)s:
117+
IF {{True if c_type_in == c_type_out != "object" else False}}:
68118
cdef:
69-
%(c_type_out)s *v
70-
%(c_type_out)s *o
119+
{{c_type_out}} *v
120+
{{c_type_out}} *o
71121

72-
#GH3130
122+
# GH#3130
73123
if (values.strides[1] == out.strides[1] and
74-
values.strides[1] == sizeof(%(c_type_out)s) and
75-
sizeof(%(c_type_out)s) * n >= 256):
124+
values.strides[1] == sizeof({{c_type_out}}) and
125+
sizeof({{c_type_out}}) * n >= 256):
76126

77127
for i in range(n):
78128
idx = indexer[i]
@@ -82,7 +132,7 @@ def get_dispatch(dtypes):
82132
else:
83133
v = &values[idx, 0]
84134
o = &out[i, 0]
85-
memmove(o, v, <size_t>(sizeof(%(c_type_out)s) * k))
135+
memmove(o, v, <size_t>(sizeof({{c_type_out}}) * k))
86136
return
87137

88138
for i in range(n):
@@ -92,13 +142,27 @@ def get_dispatch(dtypes):
92142
out[i, j] = fv
93143
else:
94144
for j in range(k):
95-
out[i, j] = %(preval)svalues[idx, j]%(postval)s
96-
"""
145+
{{if c_type_in == "uint8_t" and c_type_out == "object"}}
146+
out[i, j] = True if values[idx, j] > 0 else False
147+
{{else}}
148+
out[i, j] = values[idx, j]
149+
{{endif}}
150+
151+
152+
@cython.wraparound(False)
153+
@cython.boundscheck(False)
154+
{{if c_type_in != "object"}}
155+
def take_2d_axis1_{{name}}_{{dest}}(const {{c_type_in}}[:, :] values,
156+
{{else}}
157+
def take_2d_axis1_{{name}}_{{dest}}(ndarray[{{c_type_in}}, ndim=2] values,
158+
{{endif}}
159+
ndarray[int64_t] indexer,
160+
{{c_type_out}}[:, :] out,
161+
fill_value=np.nan):
97162

98-
inner_take_2d_axis1_template = """\
99163
cdef:
100164
Py_ssize_t i, j, k, n, idx
101-
%(c_type_out)s fv
165+
{{c_type_out}} fv
102166

103167
n = len(values)
104168
k = len(indexer)
@@ -114,132 +178,11 @@ def get_dispatch(dtypes):
114178
if idx == -1:
115179
out[i, j] = fv
116180
else:
117-
out[i, j] = %(preval)svalues[i, idx]%(postval)s
118-
"""
119-
120-
for (c_type_in, c_type_out, preval, postval) in dtypes:
121-
122-
can_copy = c_type_in == c_type_out != "object"
123-
nogil = c_type_out != "object"
124-
if nogil:
125-
nogil_str = "with nogil:"
126-
tab = ' '
127-
else:
128-
nogil_str = ''
129-
tab = ''
130-
131-
def get_name(dtype_name):
132-
if dtype_name == "object":
133-
return "object"
134-
if dtype_name == "uint8_t":
135-
return "bool"
136-
return dtype_name[:-2]
137-
138-
name = get_name(c_type_in)
139-
dest = get_name(c_type_out)
140-
141-
args = dict(name=name, dest=dest, c_type_in=c_type_in,
142-
c_type_out=c_type_out, preval=preval, postval=postval,
143-
can_copy=can_copy, nogil_str=nogil_str, tab=tab)
144-
145-
inner_take_1d = inner_take_1d_template % args
146-
inner_take_2d_axis0 = inner_take_2d_axis0_template % args
147-
inner_take_2d_axis1 = inner_take_2d_axis1_template % args
148-
149-
yield (name, dest, c_type_in, c_type_out, preval, postval,
150-
inner_take_1d, inner_take_2d_axis0, inner_take_2d_axis1)
151-
152-
}}
153-
154-
155-
{{for name, dest, c_type_in, c_type_out, preval, postval,
156-
inner_take_1d, inner_take_2d_axis0, inner_take_2d_axis1
157-
in get_dispatch(dtypes)}}
158-
159-
160-
@cython.wraparound(False)
161-
@cython.boundscheck(False)
162-
cdef inline take_1d_{{name}}_{{dest}}_memview({{c_type_in}}[:] values,
163-
const int64_t[:] indexer,
164-
{{c_type_out}}[:] out,
165-
fill_value=np.nan):
166-
167-
168-
{{inner_take_1d}}
169-
170-
171-
@cython.wraparound(False)
172-
@cython.boundscheck(False)
173-
def take_1d_{{name}}_{{dest}}(ndarray[{{c_type_in}}, ndim=1] values,
174-
const int64_t[:] indexer,
175-
{{c_type_out}}[:] out,
176-
fill_value=np.nan):
177-
178-
if values.flags.writeable:
179-
# We can call the memoryview version of the code
180-
take_1d_{{name}}_{{dest}}_memview(values, indexer, out,
181-
fill_value=fill_value)
182-
return
183-
184-
# We cannot use the memoryview version on readonly-buffers due to
185-
# a limitation of Cython's typed memoryviews. Instead we can use
186-
# the slightly slower Cython ndarray type directly.
187-
{{inner_take_1d}}
188-
189-
190-
@cython.wraparound(False)
191-
@cython.boundscheck(False)
192-
cdef inline take_2d_axis0_{{name}}_{{dest}}_memview({{c_type_in}}[:, :] values,
193-
const int64_t[:] indexer,
194-
{{c_type_out}}[:, :] out,
195-
fill_value=np.nan):
196-
{{inner_take_2d_axis0}}
197-
198-
199-
@cython.wraparound(False)
200-
@cython.boundscheck(False)
201-
def take_2d_axis0_{{name}}_{{dest}}(ndarray[{{c_type_in}}, ndim=2] values,
202-
ndarray[int64_t] indexer,
203-
{{c_type_out}}[:, :] out,
204-
fill_value=np.nan):
205-
if values.flags.writeable:
206-
# We can call the memoryview version of the code
207-
take_2d_axis0_{{name}}_{{dest}}_memview(values, indexer, out,
208-
fill_value=fill_value)
209-
return
210-
211-
# We cannot use the memoryview version on readonly-buffers due to
212-
# a limitation of Cython's typed memoryviews. Instead we can use
213-
# the slightly slower Cython ndarray type directly.
214-
{{inner_take_2d_axis0}}
215-
216-
217-
@cython.wraparound(False)
218-
@cython.boundscheck(False)
219-
cdef inline take_2d_axis1_{{name}}_{{dest}}_memview({{c_type_in}}[:, :] values,
220-
const int64_t[:] indexer,
221-
{{c_type_out}}[:, :] out,
222-
fill_value=np.nan):
223-
{{inner_take_2d_axis1}}
224-
225-
226-
@cython.wraparound(False)
227-
@cython.boundscheck(False)
228-
def take_2d_axis1_{{name}}_{{dest}}(ndarray[{{c_type_in}}, ndim=2] values,
229-
ndarray[int64_t] indexer,
230-
{{c_type_out}}[:, :] out,
231-
fill_value=np.nan):
232-
233-
if values.flags.writeable:
234-
# We can call the memoryview version of the code
235-
take_2d_axis1_{{name}}_{{dest}}_memview(values, indexer, out,
236-
fill_value=fill_value)
237-
return
238-
239-
# We cannot use the memoryview version on readonly-buffers due to
240-
# a limitation of Cython's typed memoryviews. Instead we can use
241-
# the slightly slower Cython ndarray type directly.
242-
{{inner_take_2d_axis1}}
181+
{{if c_type_in == "uint8_t" and c_type_out == "object"}}
182+
out[i, j] = True if values[i, idx] > 0 else False
183+
{{else}}
184+
out[i, j] = values[i, idx]
185+
{{endif}}
243186

244187

245188
@cython.wraparound(False)
@@ -268,7 +211,11 @@ def take_2d_multi_{{name}}_{{dest}}(ndarray[{{c_type_in}}, ndim=2] values,
268211
if idx1[j] == -1:
269212
out[i, j] = fv
270213
else:
271-
out[i, j] = {{preval}}values[idx, idx1[j]]{{postval}}
214+
{{if c_type_in == "uint8_t" and c_type_out == "object"}}
215+
out[i, j] = True if values[idx, idx1[j]] > 0 else False
216+
{{else}}
217+
out[i, j] = values[idx, idx1[j]]
218+
{{endif}}
272219

273220
{{endfor}}
274221

0 commit comments

Comments
 (0)