Skip to content

Commit d3bd1f1

Browse files
committed
Add more specialized static output shape to Eye
Importantly, it now provides broadcastability information which is needed elsewhere
1 parent 28d9d4d commit d3bd1f1

File tree

2 files changed

+44
-32
lines changed

2 files changed

+44
-32
lines changed

pytensor/tensor/basic.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1273,6 +1273,7 @@ def triu_indices_from(
12731273

12741274

12751275
class Eye(Op):
1276+
_output_type_depends_on_input_value = True
12761277
__props__ = ("dtype",)
12771278

12781279
def __init__(self, dtype=None):
@@ -1287,10 +1288,13 @@ def make_node(self, n, m, k):
12871288
assert n.ndim == 0
12881289
assert m.ndim == 0
12891290
assert k.ndim == 0
1291+
1292+
_, static_shape = infer_static_shape((n, m))
1293+
12901294
return Apply(
12911295
self,
12921296
[n, m, k],
1293-
[TensorType(dtype=self.dtype, shape=(None, None))()],
1297+
[TensorType(dtype=self.dtype, shape=static_shape)()],
12941298
)
12951299

12961300
def perform(self, node, inp, out_):

tests/tensor/test_basic.py

+39-31
Original file line numberDiff line numberDiff line change
@@ -937,38 +937,46 @@ def test_infer_static_shape():
937937
assert static_shape == (1,)
938938

939939

940-
# This is slow for the ('int8', 3) version.
941-
def test_eye():
942-
def check(dtype, N, M_=None, k=0):
943-
# PyTensor does not accept None as a tensor.
944-
# So we must use a real value.
945-
M = M_
946-
# Currently DebugMode does not support None as inputs even if this is
947-
# allowed.
948-
if M is None and config.mode in ["DebugMode", "DEBUG_MODE"]:
949-
M = N
950-
N_symb = iscalar()
951-
M_symb = iscalar()
952-
k_symb = iscalar()
953-
f = function([N_symb, M_symb, k_symb], eye(N_symb, M_symb, k_symb, dtype=dtype))
954-
result = f(N, M, k)
955-
assert np.allclose(result, np.eye(N, M_, k, dtype=dtype))
956-
assert result.dtype == np.dtype(dtype)
940+
class TestEye:
941+
# This is slow for the ('int8', 3) version.
942+
def test_basic(self):
943+
def check(dtype, N, M_=None, k=0):
944+
# PyTensor does not accept None as a tensor.
945+
# So we must use a real value.
946+
M = M_
947+
# Currently DebugMode does not support None as inputs even if this is
948+
# allowed.
949+
if M is None and config.mode in ["DebugMode", "DEBUG_MODE"]:
950+
M = N
951+
N_symb = iscalar()
952+
M_symb = iscalar()
953+
k_symb = iscalar()
954+
f = function(
955+
[N_symb, M_symb, k_symb], eye(N_symb, M_symb, k_symb, dtype=dtype)
956+
)
957+
result = f(N, M, k)
958+
assert np.allclose(result, np.eye(N, M_, k, dtype=dtype))
959+
assert result.dtype == np.dtype(dtype)
957960

958-
for dtype in ALL_DTYPES:
959-
check(dtype, 3)
960-
# M != N, k = 0
961-
check(dtype, 3, 5)
962-
check(dtype, 5, 3)
963-
# N == M, k != 0
964-
check(dtype, 3, 3, 1)
965-
check(dtype, 3, 3, -1)
966-
# N < M, k != 0
967-
check(dtype, 3, 5, 1)
968-
check(dtype, 3, 5, -1)
969-
# N > M, k != 0
970-
check(dtype, 5, 3, 1)
971-
check(dtype, 5, 3, -1)
961+
for dtype in ALL_DTYPES:
962+
check(dtype, 3)
963+
# M != N, k = 0
964+
check(dtype, 3, 5)
965+
check(dtype, 5, 3)
966+
# N == M, k != 0
967+
check(dtype, 3, 3, 1)
968+
check(dtype, 3, 3, -1)
969+
# N < M, k != 0
970+
check(dtype, 3, 5, 1)
971+
check(dtype, 3, 5, -1)
972+
# N > M, k != 0
973+
check(dtype, 5, 3, 1)
974+
check(dtype, 5, 3, -1)
975+
976+
def test_static_output_type(self):
977+
l = lscalar("l")
978+
assert eye(5, 3, l).type.shape == (5, 3)
979+
assert eye(1, l, 3).type.shape == (1, None)
972980

973981

974982
class TestTriangle:

0 commit comments

Comments
 (0)