Skip to content

Commit 28d9d4d

Browse files
committed
Improve static output shape of Reshape
1 parent 734009a commit 28d9d4d

File tree

2 files changed

+49
-1
lines changed

2 files changed

+49
-1
lines changed

pytensor/tensor/shape.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,7 @@ def make_node(self, x, shp):
669669
assert shp.ndim == 1
670670

671671
if isinstance(shp, TensorConstant):
672-
out_shape = tuple(int(s) if s >= 0 else None for s in shp.data)
672+
out_shape = [int(s) if s >= 0 else None for s in shp.data]
673673
else:
674674
out_shape = [None] * self.ndim
675675
shp_list = shp_orig
@@ -685,6 +685,29 @@ def make_node(self, x, shp):
685685
except NotScalarConstantError:
686686
pass
687687

688+
# If we only don't know the size of one output dimension,
689+
# but we know all the input dimensions we can deduce it
690+
# This happens often when there is -1 as an input of Reshape
691+
if None not in x.type.shape and out_shape.count(None) == 1:
692+
full_size = np.prod(x.type.shape)
693+
known_size = np.prod([s for s in out_shape if s is not None])
694+
out_shape[out_shape.index(None)] = int(full_size // known_size)
695+
696+
out_shape = tuple(out_shape)
697+
698+
# Run some eager error checks
699+
if len(out_shape) != self.ndim:
700+
raise ValueError(
701+
"Shape argument to Reshape has incorrect length:"
702+
f" {len(out_shape)}, should be {self.ndim}"
703+
)
704+
705+
if None not in x.type.shape and None not in out_shape:
706+
if np.prod(x.type.shape) != np.prod(out_shape):
707+
raise ValueError(
708+
f"Reshape: Input shape {x.type.shape} is incompatible with new shape {out_shape}"
709+
)
710+
688711
return Apply(self, [x, shp], [tensor(dtype=x.type.dtype, shape=out_shape)])
689712

690713
def perform(self, node, inp, out_):

tests/tensor/test_shape.py

+25
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import re
2+
13
import numpy as np
24
import pytest
35

@@ -353,6 +355,29 @@ def test_rebuild(self):
353355
assert tuple(y_new.shape.eval({i: i_test})) == (4, 25)
354356
assert y_new.eval({i: i_test}).shape == (4, 25)
355357

358+
def test_static_shape(self):
359+
dim = lscalar("dim")
360+
x1 = tensor(shape=(2, 2, None))
361+
x2 = specify_shape(x1, (2, 2, 6))
362+
363+
assert reshape(x1, (6, 2)).type.shape == (6, 2)
364+
assert reshape(x1, (6, -1)).type.shape == (6, None)
365+
assert reshape(x1, (6, dim)).type.shape == (6, None)
366+
assert reshape(x1, (6, dim, 2)).type.shape == (6, None, 2)
367+
assert reshape(x1, (6, 3, 99)).type.shape == (6, 3, 99)
368+
369+
assert reshape(x2, (6, 4)).type.shape == (6, 4)
370+
assert reshape(x2, (6, -1)).type.shape == (6, 4)
371+
assert reshape(x2, (6, dim)).type.shape == (6, 4)
372+
assert reshape(x2, (6, dim, 2)).type.shape == (6, 2, 2)
373+
with pytest.raises(
374+
ValueError,
375+
match=re.escape(
376+
"Reshape: Input shape (2, 2, 6) is incompatible with new shape (6, 3, 99)"
377+
),
378+
):
379+
reshape(x2, (6, 3, 99))
380+
356381

357382
def test_shape_i_hash():
358383
assert isinstance(Shape_i(np.int64(1)).__hash__(), int)

0 commit comments

Comments
 (0)