Skip to content

Commit ca09cfe

Browse files
committed
Remove unused inplace option in DimShuffle
1 parent 94c84f8 commit ca09cfe

File tree

5 files changed

+13
-47
lines changed

5 files changed

+13
-47
lines changed

pytensor/link/jax/dispatch/elemwise.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,7 @@ def dimshuffle(x):
7979
for augm in op.augment:
8080
shape.insert(augm, 1)
8181

82-
res = jnp.reshape(res, shape)
83-
84-
if not op.inplace:
85-
res = jnp.copy(res)
86-
87-
return res
82+
return jnp.reshape(res, shape)
8883

8984
return dimshuffle
9085

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,6 @@ def numba_funcify_DimShuffle(op, node, **kwargs):
414414
shuffle = tuple(op.shuffle)
415415
transposition = tuple(op.transposition)
416416
augment = tuple(op.augment)
417-
inplace = op.inplace
418417

419418
ndim_new_shape = len(shuffle) + len(augment)
420419

@@ -474,12 +473,7 @@ def dimshuffle_inner(x, shuffle):
474473
new_shape = find_shape(shuffle_shape)
475474

476475
# FIXME: Numba's `array.reshape` only accepts C arrays.
477-
res_reshape = np.reshape(np.ascontiguousarray(x), new_shape)
478-
479-
if not inplace:
480-
return res_reshape.copy()
481-
else:
482-
return res_reshape
476+
return np.reshape(np.ascontiguousarray(x), new_shape)
483477

484478
else:
485479

pytensor/link/pytorch/dispatch/elemwise.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,7 @@ def dimshuffle(x):
6161
for augm in op.augment:
6262
shape.insert(augm, 1)
6363

64-
res = torch.reshape(res, shape)
65-
66-
if not op.inplace:
67-
res = res.clone()
68-
69-
return res
64+
return torch.reshape(res, shape)
7065

7166
return dimshuffle
7267

pytensor/tensor/c_code/dimshuffle.c

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
77
npy_intp* dimensions;
88
npy_intp* strides;
99

10-
// This points to either the original input or a copy we create below.
11-
// Either way, this is what we should be working on/with.
12-
PyArrayObject *_input;
1310

1411
if (!PyArray_IS_C_CONTIGUOUS(params->_new_order)) {
1512
PyErr_SetString(PyExc_RuntimeError, "DimShuffle: param _new_order must be C-contiguous.");
@@ -27,14 +24,7 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
2724
if (*res)
2825
Py_XDECREF(*res);
2926

30-
if (params->inplace) {
31-
_input = input;
32-
Py_INCREF((PyObject*)_input);
33-
} else {
34-
_input = (PyArrayObject *)PyArray_FromAny(
35-
(PyObject *)input, NULL, 0, 0, NPY_ARRAY_ALIGNED | NPY_ARRAY_ENSURECOPY,
36-
NULL);
37-
}
27+
Py_INCREF((PyObject*)input);
3828

3929
// Compute new dimensions and strides
4030
dimensions = (npy_intp*) malloc(nd_out * sizeof(npy_intp));
@@ -46,12 +36,12 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
4636
return 1;
4737
};
4838

49-
npy_intp original_size = PyArray_SIZE(_input);
39+
npy_intp original_size = PyArray_SIZE(input);
5040
npy_intp new_size = 1;
5141
for (npy_intp i = 0; i < nd_out; ++i) {
5242
if (new_order[i] != -1) {
53-
dimensions[i] = PyArray_DIMS(_input)[new_order[i]];
54-
strides[i] = PyArray_DIMS(_input)[new_order[i]] == 1 ? 0 : PyArray_STRIDES(_input)[new_order[i]];
43+
dimensions[i] = PyArray_DIMS(input)[new_order[i]];
44+
strides[i] = PyArray_DIMS(input)[new_order[i]] == 1 ? 0 : PyArray_STRIDES(input)[new_order[i]];
5545
} else {
5646
dimensions[i] = 1;
5747
strides[i] = 0;
@@ -68,11 +58,11 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
6858

6959
// Create the new array.
7060
*res = (PyArrayObject*)PyArray_New(&PyArray_Type, nd_out, dimensions,
71-
PyArray_TYPE(_input), strides,
72-
PyArray_DATA(_input), PyArray_ITEMSIZE(_input),
61+
PyArray_TYPE(input), strides,
62+
PyArray_DATA(input), PyArray_ITEMSIZE(input),
7363
// borrow only the writable flag from the base
7464
// the NPY_OWNDATA flag will default to 0.
75-
(NPY_ARRAY_WRITEABLE * PyArray_ISWRITEABLE(_input)),
65+
(NPY_ARRAY_WRITEABLE * PyArray_ISWRITEABLE(input)),
7666
NULL);
7767

7868
if (*res == NULL) {
@@ -85,7 +75,7 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
8575
PyArray_UpdateFlags(*res, NPY_ARRAY_UPDATE_ALL);
8676

8777
// we are making a view in both inplace and non-inplace cases
88-
PyArray_SetBaseObject(*res, (PyObject*)_input);
78+
PyArray_SetBaseObject(*res, (PyObject*)input);
8979

9080
free(strides);
9181
free(dimensions);

pytensor/tensor/elemwise.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from pytensor.npy_2_compat import normalize_axis_tuple
2020
from pytensor.printing import Printer, pprint
2121
from pytensor.scalar import get_scalar_type
22-
from pytensor.scalar.basic import bool as scalar_bool
2322
from pytensor.scalar.basic import identity as scalar_identity
2423
from pytensor.scalar.basic import int64, transfer_type, upcast
2524
from pytensor.tensor import elemwise_cgen as cgen
@@ -114,15 +113,15 @@ class DimShuffle(ExternalCOp):
114113

115114
_f16_ok = True
116115
check_input = False
117-
__props__ = ("input_ndim", "new_order", "inplace")
116+
__props__ = ("input_ndim", "new_order")
118117
c_func_file = "c_code/dimshuffle.c"
119118
c_func_name = "APPLY_SPECIFIC(cpu_dimshuffle)"
119+
view_map = {0: [0]}
120120

121121
@property
122122
def params_type(self):
123123
return ParamsType(
124124
_new_order=lvector,
125-
inplace=scalar_bool,
126125
input_ndim=int64,
127126
)
128127

@@ -141,7 +140,6 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
141140

142141
self.input_ndim = input_ndim
143142
self.new_order = tuple(new_order)
144-
self.inplace = True
145143

146144
for i, j in enumerate(new_order):
147145
if j != "x":
@@ -184,9 +182,6 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
184182
:input_ndim
185183
] == list(range(input_ndim))
186184

187-
if self.inplace:
188-
self.view_map = {0: [0]}
189-
190185
def __setstate__(self, state):
191186
self.__dict__.update(state)
192187
if not hasattr(self, "func_files"):
@@ -251,9 +246,6 @@ def perform(self, node, inp, out):
251246
new_shape.insert(augm, 1)
252247
res = res.reshape(new_shape)
253248

254-
if not self.inplace:
255-
res = np.copy(res)
256-
257249
storage[0] = np.asarray(res)
258250

259251
def infer_shape(self, fgraph, node, shapes):

0 commit comments

Comments
 (0)