Skip to content

Commit 02545ed

Browse files
committed
Specify reshape shape length if unknown
1 parent 141307f commit 02545ed

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

pytensor/tensor/shape.py

+2
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,8 @@ def make_node(self, x, shp):
644644
x = ptb.as_tensor_variable(x)
645645
shp_orig = shp
646646
shp = ptb.as_tensor_variable(shp, ndim=1)
647+
if shp.type.shape == (None,):
648+
shp = specify_shape(shp, self.ndim)
647649
if not (
648650
shp.dtype in int_dtypes
649651
or (isinstance(shp, TensorConstant) and shp.data.size == 0)

tests/tensor/test_shape.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def setup_method(self):
9898
Shape_i,
9999
DimShuffle,
100100
Elemwise,
101+
SpecifyShape,
101102
)
102103
super().setup_method()
103104

@@ -253,9 +254,7 @@ def test_bad_shape(self):
253254
f(a_val, [7, 5])
254255
with pytest.raises(ValueError):
255256
f(a_val, [-1, -1])
256-
with pytest.raises(
257-
ValueError, match=".*Shape argument to Reshape has incorrect length.*"
258-
):
257+
with pytest.raises(AssertionError):
259258
f(a_val, [3, 4, 1])
260259

261260
def test_0(self):

0 commit comments

Comments
 (0)