Skip to content

Commit e593b0a

Browse files
Use NumPy C API to perform DimShuffle steps in its C implementation
1 parent 223ee15 commit e593b0a

File tree

9 files changed

+126
-162
lines changed

9 files changed

+126
-162
lines changed

aesara/gpuarray/elemwise.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -468,7 +468,7 @@ def perform(self, node, inp, out, params):
468468

469469
res = input
470470

471-
res = res.transpose(self.shuffle + self.drop)
471+
res = res.transpose(self.transposition)
472472

473473
shape = list(res.shape[: len(self.shuffle)])
474474
for augm in self.augment:

aesara/link/jax/dispatch.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -710,7 +710,7 @@ def reshape(x, shape):
710710
def jax_funcify_DimShuffle(op, **kwargs):
711711
def dimshuffle(x):
712712

713-
res = jnp.transpose(x, op.shuffle + op.drop)
713+
res = jnp.transpose(x, op.transposition)
714714

715715
shape = list(res.shape[: len(op.shuffle)])
716716

aesara/link/numba/dispatch/elemwise.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -319,7 +319,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
319319
@numba_funcify.register(DimShuffle)
320320
def numba_funcify_DimShuffle(op, **kwargs):
321321
shuffle = tuple(op.shuffle)
322-
drop = tuple(op.drop)
322+
transposition = tuple(op.transposition)
323323
augment = tuple(op.augment)
324324
inplace = op.inplace
325325

@@ -352,7 +352,7 @@ def populate_new_shape(i, j, new_shape, shuffle_shape):
352352

353353
@numba.njit
354354
def dimshuffle_inner(x, shuffle):
355-
res = np.transpose(x, shuffle + drop)
355+
res = np.transpose(x, transposition)
356356
shuffle_shape = res.shape[: len(shuffle)]
357357

358358
new_shape = create_zeros_tuple()

aesara/tensor/c_code/dimshuffle.c

Lines changed: 68 additions & 91 deletions
Original file line number | Diff line number | Diff line change
@@ -1,104 +1,81 @@
11
#section support_code_apply
22

3-
int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject* input, PyArrayObject** res, PARAMS_TYPE* params) {
4-
npy_bool* input_broadcastable;
5-
npy_int64* new_order;
6-
npy_intp nd_in;
7-
npy_intp nd_out;
8-
PyArrayObject* basename;
9-
npy_intp* dimensions;
10-
npy_intp* strides;
11-
12-
if (!PyArray_IS_C_CONTIGUOUS(params->input_broadcastable)) {
13-
PyErr_SetString(PyExc_RuntimeError, "DimShuffle: param input_broadcastable must be C-contiguous.");
14-
return 1;
15-
}
16-
if (!PyArray_IS_C_CONTIGUOUS(params->_new_order)) {
17-
PyErr_SetString(PyExc_RuntimeError, "DimShuffle: param _new_order must be C-contiguous.");
18-
return 1;
19-
}
20-
input_broadcastable = (npy_bool*) PyArray_DATA(params->input_broadcastable);
21-
new_order = (npy_int64*) PyArray_DATA(params->_new_order);
22-
nd_in = PyArray_SIZE(params->input_broadcastable);
23-
nd_out = PyArray_SIZE(params->_new_order);
24-
25-
/* check_input_nd */
26-
if (PyArray_NDIM(input) != nd_in) {
27-
PyErr_SetString(PyExc_NotImplementedError, "input nd");
28-
return 1;
29-
}
3+
int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res,
4+
PARAMS_TYPE *params) {
305

31-
/* clear_output */
32-
if (*res)
33-
Py_XDECREF(*res);
6+
// This points to either the original input or a copy we create below.
7+
// Either way, this is what we should be working on/with.
8+
PyArrayObject *_input;
349

35-
/* get_base */
36-
if (params->inplace) {
37-
basename = input;
38-
Py_INCREF((PyObject*)basename);
39-
} else {
40-
basename =
41-
(PyArrayObject*)PyArray_FromAny((PyObject*)input,
42-
NULL, 0, 0, NPY_ARRAY_ALIGNED|NPY_ARRAY_ENSURECOPY, NULL);
43-
}
10+
if (*res)
11+
Py_XDECREF(*res);
4412

45-
/* shape_statements and strides_statements */
46-
dimensions = (npy_intp*) malloc(nd_out * sizeof(npy_intp));
47-
strides = (npy_intp*) malloc(nd_out * sizeof(npy_intp));
48-
if (dimensions == NULL || strides == NULL) {
49-
PyErr_NoMemory();
50-
free(dimensions);
51-
free(strides);
52-
return 1;
53-
};
54-
55-
for (npy_intp i = 0; i < nd_out; ++i) {
56-
if (new_order[i] != -1) {
57-
dimensions[i] = PyArray_DIMS(basename)[new_order[i]];
58-
strides[i] = PyArray_DIMS(basename)[new_order[i]] == 1 ?
59-
0 : PyArray_STRIDES(basename)[new_order[i]];
60-
} else {
61-
dimensions[i] = 1;
62-
strides[i] = 0;
63-
}
64-
}
13+
if (params->inplace) {
14+
_input = input;
15+
Py_INCREF((PyObject *)_input);
16+
} else {
17+
_input = (PyArrayObject *)PyArray_FromAny(
18+
(PyObject *)input, NULL, 0, 0, NPY_ARRAY_ALIGNED | NPY_ARRAY_ENSURECOPY,
19+
NULL);
20+
}
6521

66-
/* set the strides of the broadcasted dimensions.
67-
* This algorithm is from numpy: PyArray_Newshape() in
68-
* cvs/numpy/numpy/core/src/multiarraymodule.c */
69-
if (nd_out > 0) {
70-
if (strides[nd_out - 1] == 0)
71-
strides[nd_out - 1] = PyArray_DESCR(basename)->elsize;
72-
for (npy_intp i = nd_out - 2; i > -1; --i) {
73-
if (strides[i] == 0)
74-
strides[i] = strides[i + 1] * dimensions[i + 1];
75-
}
76-
}
22+
PyArray_Dims permute;
23+
24+
if (!PyArray_IntpConverter((PyObject *)params->transposition, &permute)) {
25+
return 1;
26+
}
7727

78-
/* close_bracket */
79-
// create a new array.
80-
*res = (PyArrayObject*)PyArray_New(&PyArray_Type, nd_out, dimensions,
81-
PyArray_TYPE(basename), strides,
82-
PyArray_DATA(basename), PyArray_ITEMSIZE(basename),
83-
// borrow only the writable flag from the base
84-
// the NPY_OWNDATA flag will default to 0.
85-
(NPY_ARRAY_WRITEABLE * PyArray_ISWRITEABLE(basename)),
86-
NULL);
87-
88-
if (*res == NULL) {
89-
free(dimensions);
90-
free(strides);
91-
return 1;
28+
/*
29+
res = res.transpose(self.transposition)
30+
*/
31+
PyArrayObject *transposed_input =
32+
(PyArrayObject *)PyArray_Transpose(_input, &permute);
33+
34+
PyDimMem_FREE(permute.ptr);
35+
36+
npy_intp *res_shape = PyArray_DIMS(transposed_input);
37+
npy_intp N_shuffle = PyArray_SIZE(params->shuffle);
38+
npy_intp N_augment = PyArray_SIZE(params->augment);
39+
npy_intp N = N_augment + N_shuffle;
40+
npy_intp *_reshape_shape = (npy_intp *)malloc(N * sizeof(npy_intp));
41+
42+
if (_reshape_shape == NULL) {
43+
PyErr_NoMemory();
44+
free(_reshape_shape);
45+
return 1;
46+
}
47+
48+
/*
49+
shape = list(res.shape[: len(self.shuffle)])
50+
for augm in self.augment:
51+
shape.insert(augm, 1)
52+
*/
53+
npy_intp aug_idx = 0;
54+
int res_idx = 0;
55+
for (npy_intp i = 0; i < N; i++) {
56+
if (aug_idx < N_augment &&
57+
i == *((npy_intp *)PyArray_GetPtr(params->augment, &aug_idx))) {
58+
_reshape_shape[i] = 1;
59+
aug_idx++;
60+
} else {
61+
_reshape_shape[i] = res_shape[res_idx];
62+
res_idx++;
9263
}
64+
}
65+
66+
PyArray_Dims reshape_shape = {.ptr = _reshape_shape, .len = (int)N};
67+
68+
/* res = res.reshape(shape) */
69+
*res = (PyArrayObject *)PyArray_Newshape(transposed_input, &reshape_shape,
70+
NPY_CORDER);
9371

94-
// recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED
95-
PyArray_UpdateFlags(*res, NPY_ARRAY_UPDATE_ALL);
72+
/* Py_XDECREF(transposed_input); */
9673

97-
// we are making a view in both inplace and non-inplace cases
98-
PyArray_SetBaseObject(*res, (PyObject*)basename);
74+
PyDimMem_FREE(reshape_shape.ptr);
9975

100-
free(strides);
101-
free(dimensions);
76+
if (!*res) {
77+
return 1;
78+
}
10279

103-
return 0;
80+
return 0;
10481
}

aesara/tensor/elemwise.py

Lines changed: 27 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -119,47 +119,27 @@ class DimShuffle(ExternalCOp):
119119

120120
@property
121121
def params_type(self):
122-
# We can't directly create `params_type` as class attribute
123-
# because of importation issues related to TensorType.
124122
return ParamsType(
125-
input_broadcastable=TensorType(dtype="bool", broadcastable=(False,)),
126-
_new_order=lvector,
127-
transposition=TensorType(dtype="uint32", broadcastable=(False,)),
123+
shuffle=lvector,
124+
augment=lvector,
125+
transposition=lvector,
128126
inplace=scalar_bool,
129127
)
130128

131-
@property
132-
def _new_order(self):
133-
# Param for C code.
134-
# self.new_order may contain 'x', which is not a valid integer value.
135-
# We replace it with -1.
136-
return [(-1 if x == "x" else x) for x in self.new_order]
137-
138-
@property
139-
def transposition(self):
140-
return self.shuffle + self.drop
141-
142-
def __init__(self, input_broadcastable, new_order, inplace=True):
129+
def __init__(self, input_broadcastable, new_order):
143130
super().__init__([self.c_func_file], self.c_func_name)
131+
144132
self.input_broadcastable = tuple(input_broadcastable)
145133
self.new_order = tuple(new_order)
146-
if inplace is True:
147-
self.inplace = inplace
148-
else:
149-
raise ValueError(
150-
"DimShuffle is inplace by default and hence the inplace for DimShuffle must be true"
151-
)
134+
135+
self.inplace = True
152136

153137
for i, j in enumerate(new_order):
154138
if j != "x":
155-
# There is a bug in numpy that results in
156-
# isinstance(x, integer_types) returning False for
157-
# numpy integers. See
158-
# <http://projects.scipy.org/numpy/ticket/2235>.
159139
if not isinstance(j, (int, np.integer)):
160140
raise TypeError(
161-
"DimShuffle indices must be python ints. "
162-
f"Got: '{j}' of type '{type(j)}'."
141+
"DimShuffle indices must be Python ints; got "
142+
f"{j} of type {type(j)}."
163143
)
164144
if j >= len(input_broadcastable):
165145
raise ValueError(
@@ -169,31 +149,30 @@ def __init__(self, input_broadcastable, new_order, inplace=True):
169149
if j in new_order[(i + 1) :]:
170150
raise ValueError(
171151
"The same input dimension may not appear "
172-
"twice in the list of output dimensions",
173-
new_order,
152+
f"twice in the list of output dimensions: {new_order}"
174153
)
175154

176-
# list of dimensions of the input to drop
177-
self.drop = []
155+
# List of input dimensions to drop
156+
drop = []
178157
for i, b in enumerate(input_broadcastable):
179158
if i not in new_order:
180-
# we want to drop this dimension because it's not a value in
181-
# new_order
182-
if b == 1: # 1 aka True
183-
self.drop.append(i)
159+
# We want to drop this dimension because it's not a value in
160+
# `new_order`
161+
if b == 1:
162+
drop.append(i)
184163
else:
185-
# we cannot drop non-broadcastable dimensions
164+
# We cannot drop non-broadcastable dimensions
186165
raise ValueError(
187-
"You cannot drop a non-broadcastable dimension:",
188-
f" {input_broadcastable}, {new_order}",
166+
"Cannot drop a non-broadcastable dimension: "
167+
f"{input_broadcastable}, {new_order}"
189168
)
190169

191-
# this is the list of the original dimensions that we keep
170+
# This is the list of the original dimensions that we keep
192171
self.shuffle = [x for x in new_order if x != "x"]
193-
194-
# list of dimensions of the output that are broadcastable and were not
172+
self.transposition = self.shuffle + drop
173+
# List of dimensions of the output that are broadcastable and were not
195174
# in the original input
196-
self.augment = [i for i, x in enumerate(new_order) if x == "x"]
175+
self.augment = sorted([i for i, x in enumerate(new_order) if x == "x"])
197176

198177
if self.inplace:
199178
self.view_map = {0: [0]}
@@ -241,27 +220,23 @@ def __str__(self):
241220
return "DimShuffle{%s}" % ",".join(str(x) for x in self.new_order)
242221

243222
def perform(self, node, inp, out, params):
244-
(input,) = inp
223+
(res,) = inp
245224
(storage,) = out
246-
# drop
247-
res = input
225+
248226
if type(res) != np.ndarray and type(res) != np.memmap:
249227
raise TypeError(res)
250228

251-
# transpose
252-
res = res.transpose(self.shuffle + self.drop)
229+
res = res.transpose(self.transposition)
253230

254-
# augment
255231
shape = list(res.shape[: len(self.shuffle)])
256232
for augm in self.augment:
257233
shape.insert(augm, 1)
258234
res = res.reshape(shape)
259235

260-
# copy (if not inplace)
261236
if not self.inplace:
262237
res = np.copy(res)
263238

264-
storage[0] = np.asarray(res) # asarray puts scalars back into array
239+
storage[0] = np.asarray(res)
265240

266241
def infer_shape(self, fgraph, node, shapes):
267242
(ishp,) = shapes

aesara/tensor/inplace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,4 +399,4 @@ def conj_inplace(a):
399399
def transpose_inplace(x, **kwargs):
400400
"Perform a transpose on a tensor without copying the underlying storage"
401401
dims = list(range(x.ndim - 1, -1, -1))
402-
return DimShuffle(x.broadcastable, dims, inplace=True)(x)
402+
return DimShuffle(x.broadcastable, dims)(x)

tests/link/test_jax.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -856,7 +856,7 @@ def test_jax_Dimshuffle():
856856
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
857857

858858
a_aet = tensor(dtype=config.floatX, broadcastable=[False, True])
859-
x = aet_elemwise.DimShuffle([False, True], (0,), inplace=True)(a_aet)
859+
x = aet_elemwise.DimShuffle([False, True], (0,))(a_aet)
860860
x_fg = FunctionGraph([a_aet], [x])
861861
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
862862

0 commit comments

Comments
 (0)