Skip to content

Commit 223ee15

Browse files
Update Reshape C implementation
These changes remove the stride-based manual computation of the new shape, since that computation is potentially unsafe for broadcasted arrays, which can have zero-length strides.
1 parent 79961a6 commit 223ee15

File tree

2 files changed

+59
-60
lines changed

2 files changed

+59
-60
lines changed

aesara/tensor/shape.py

Lines changed: 34 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from aesara.tensor import basic as aet
1515
from aesara.tensor.exceptions import NotScalarConstantError
1616
from aesara.tensor.type import TensorType, int_dtypes, tensor
17-
from aesara.tensor.var import TensorConstant, TensorVariable
17+
from aesara.tensor.var import TensorConstant
1818

1919

2020
def register_shape_c_code(type, code, version=()):
@@ -570,15 +570,11 @@ def perform(self, node, inp, out_, params):
570570
if len(shp) != self.ndim:
571571
raise ValueError(
572572
(
573-
"shape argument to Reshape.perform has incorrect"
574-
f" length {len(shp)}"
575-
f", should be {self.ndim}"
573+
"Shape argument to Reshape has incorrect"
574+
f" length: {len(shp)}, should be {self.ndim}"
576575
)
577576
)
578-
try:
579-
out[0] = np.reshape(x, shp)
580-
except Exception:
581-
raise ValueError(f"Cannot reshape input of shape {x.shape} to shape {shp}")
577+
out[0] = np.reshape(x, shp)
582578

583579
def connection_pattern(self, node):
584580
return [[True], [False]]
@@ -669,44 +665,38 @@ def infer_shape(self, fgraph, node, ishapes):
669665
]
670666

671667
def c_code_cache_version(self):
672-
return (8,)
668+
return (9,)
673669

674670
def c_code(self, node, name, inputs, outputs, sub):
675-
if isinstance(node.inputs[0], TensorVariable):
676-
x, shp = inputs
677-
(z,) = outputs
678-
sdtype = node.inputs[1].type.dtype_specs()[1]
679-
fail = sub["fail"]
680-
params = sub["params"]
681-
return (
682-
"""
683-
assert (PyArray_NDIM(%(shp)s) == 1);
684-
npy_intp new_dims[%(params)s->ndim];
685-
PyArray_Dims newshape;
686-
newshape.ptr = new_dims;
687-
newshape.len = %(params)s->ndim;
688-
for (int ii = 0; ii < %(params)s->ndim; ++ii)
689-
{
690-
// -- We do not want an explicit cast here. the shp can be any
691-
// -- int* dtype. The compiler will explicitly upcast it, but
692-
// -- will err if this will downcast. This could happen if the
693-
// -- user pass an int64 dtype, but npy_intp endup being int32.
694-
new_dims[ii] = ((%(sdtype)s*)(
695-
PyArray_BYTES(%(shp)s) +
696-
ii * PyArray_STRIDES(%(shp)s)[0]))[0];
697-
}
698-
Py_XDECREF(%(z)s);
699-
%(z)s = (PyArrayObject *) PyArray_Newshape(%(x)s, &newshape, NPY_CORDER);
700-
if (!%(z)s)
701-
{
702-
//The error message should have been set by PyArray_Newshape
703-
%(fail)s;
704-
}
705-
"""
706-
% locals()
707-
)
708-
else:
709-
raise NotImplementedError()
671+
x, shp = inputs
672+
(z,) = outputs
673+
fail = sub["fail"]
674+
params = sub["params"]
675+
return f"""
676+
assert (PyArray_NDIM({shp}) == 1);
677+
678+
PyArray_Dims newshape;
679+
680+
if (!PyArray_IntpConverter((PyObject *){shp}, &newshape)) {{
681+
{fail};
682+
}}
683+
684+
if ({params}->ndim != newshape.len) {{
685+
PyErr_SetString(PyExc_ValueError, "Shape argument to Reshape has incorrect length");
686+
PyDimMem_FREE(newshape.ptr);
687+
{fail};
688+
}}
689+
690+
Py_XDECREF({z});
691+
{z} = (PyArrayObject *) PyArray_Newshape({x}, &newshape, NPY_CORDER);
692+
693+
PyDimMem_FREE(newshape.ptr);
694+
695+
if (!{z}) {{
696+
//The error message should have been set by PyArray_Newshape
697+
{fail};
698+
}}
699+
"""
710700

711701

712702
def reshape(x, newshape, ndim=None):

tests/tensor/test_shape.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,10 @@ def test_bad_shape(self):
222222
f(a_val, [7, 5])
223223
with pytest.raises(ValueError):
224224
f(a_val, [-1, -1])
225+
with pytest.raises(
226+
ValueError, match=".*Shape argument to Reshape has incorrect length.*"
227+
):
228+
f(a_val, [3, 4, 1])
225229

226230
def test_0(self):
227231
x = fvector("x")
@@ -267,14 +271,14 @@ def test_more_shapes(self):
267271
[admat], [Reshape(ndim)(admat, [-1, 4])], [admat_val], Reshape
268272
)
269273

270-
# enable when infer_shape is generalized:
271-
# self._compile_and_check([admat, aivec],
272-
# [Reshape(ndim)(admat, aivec)],
273-
# [admat_val, [4, 3]], Reshape)
274-
#
275-
# self._compile_and_check([admat, aivec],
276-
# [Reshape(ndim)(admat, aivec)],
277-
# [admat_val, [4, -1]], Reshape)
274+
aivec = ivector()
275+
self._compile_and_check(
276+
[admat, aivec], [Reshape(ndim)(admat, aivec)], [admat_val, [4, 3]], Reshape
277+
)
278+
279+
self._compile_and_check(
280+
[admat, aivec], [Reshape(ndim)(admat, aivec)], [admat_val, [4, -1]], Reshape
281+
)
278282

279283
adtens4 = dtensor4()
280284
ndim = 4
@@ -287,14 +291,19 @@ def test_more_shapes(self):
287291
[adtens4], [Reshape(ndim)(adtens4, [1, 3, 10, 4])], [adtens4_val], Reshape
288292
)
289293

290-
# enable when infer_shape is generalized:
291-
# self._compile_and_check([adtens4, aivec],
292-
# [Reshape(ndim)(adtens4, aivec)],
293-
# [adtens4_val, [1, -1, 10, 4]], Reshape)
294-
#
295-
# self._compile_and_check([adtens4, aivec],
296-
# [Reshape(ndim)(adtens4, aivec)],
297-
# [adtens4_val, [1, 3, 10, 4]], Reshape)
294+
self._compile_and_check(
295+
[adtens4, aivec],
296+
[Reshape(ndim)(adtens4, aivec)],
297+
[adtens4_val, [1, -1, 10, 4]],
298+
Reshape,
299+
)
300+
301+
self._compile_and_check(
302+
[adtens4, aivec],
303+
[Reshape(ndim)(adtens4, aivec)],
304+
[adtens4_val, [1, 3, 10, 4]],
305+
Reshape,
306+
)
298307

299308

300309
def test_shape_i_hash():

0 commit comments

Comments (0)