Skip to content

Commit 9578bd3

Browse files
committed
Allow specializing shape of predefined tensors types
1 parent 3c66aa6 commit 9578bd3

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

pytensor/tensor/type.py

+12
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,18 @@ def parse_bcast_and_shape(s):
123123
self.name = name
124124
self.numpy_dtype = np.dtype(self.dtype)
125125

126+
def __call__(self, *args, shape=None, **kwargs):
127+
if shape is not None:
128+
# Check if shape is compatible with the original type
129+
new_type = self.clone(shape=shape)
130+
if self.is_super(new_type):
131+
return new_type(*args, **kwargs)
132+
else:
133+
raise ValueError(
134+
f"{shape=} is incompatible with original type shape {self.shape=}"
135+
)
136+
return super().__call__(*args, **kwargs)
137+
126138
def clone(
127139
self, dtype=None, shape=None, broadcastable=None, **kwargs
128140
) -> "TensorType":

tests/tensor/test_type.py

+22
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
from pytensor.tensor.type import (
1111
TensorType,
1212
col,
13+
dmatrix,
14+
drow,
15+
fmatrix,
16+
frow,
1317
matrix,
1418
row,
1519
scalar,
@@ -477,3 +481,21 @@ def test_row_matrix_creator_helpers(helper):
477481
match = "The second dimension of a `col` must have shape 1, got 5"
478482
with pytest.raises(ValueError, match=match):
479483
helper(shape=(2, 5))
484+
485+
486+
def test_shape_of_predefined_dtype_tensor():
487+
# Valid: None dimensions can be specialized
488+
assert fmatrix(shape=(1, None)).type == frow
489+
assert drow(shape=(1, 5)).type == dmatrix(shape=(1, 5)).type
490+
491+
# Invalid: Number of dimensions must match
492+
with pytest.raises(ValueError):
493+
fmatrix(shape=(None, None, None))
494+
495+
# Invalid: Fixed shapes must match
496+
with pytest.raises(ValueError):
497+
fmatrix(shape=(3, 5)).type(shape=(4, 5))
498+
499+
# Invalid: Known shapes can't be lost
500+
with pytest.raises(ValueError):
501+
drow(shape=(None, None))

0 commit comments

Comments
 (0)