Skip to content

Commit 51823c9

Browse files
twaclawDiego Sandoval
authored and
Diego Sandoval
committed
Implemented Eye Op in PyTorch
- Added support for diagonal offset (param `k`)
1 parent afc1a6c commit 51823c9

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pytensor.graph.fg import FunctionGraph
77
from pytensor.link.utils import fgraph_to_python
88
from pytensor.raise_op import CheckAndRaise
9-
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Join
9+
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join
1010

1111

1212
@singledispatch
@@ -100,3 +100,19 @@ def join(axis, *tensors):
100100
return torch.cat(tensors, dim=axis)
101101

102102
return join
103+
104+
105+
@pytorch_funcify.register(Eye)
106+
def pytorch_funcify_eye(op, **kwargs):
107+
dtype = getattr(torch, op.dtype)
108+
109+
def eye(N, M, k):
110+
mjr, mnr = (M, N) if k > 0 else (N, M)
111+
k_abs = abs(k)
112+
zeros = torch.zeros(N, M, dtype=dtype)
113+
if k_abs < mjr:
114+
l_ones = min(mjr - k_abs, mnr)
115+
return zeros.diagonal_scatter(torch.ones(l_ones, dtype=dtype), k)
116+
return zeros
117+
118+
return eye

tests/link/pytorch/test_basic.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from pytensor.graph.fg import FunctionGraph
1414
from pytensor.graph.op import Op
1515
from pytensor.raise_op import CheckAndRaise
16-
from pytensor.tensor import alloc, arange, as_tensor, empty
16+
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
1717
from pytensor.tensor.type import matrix, scalar, vector
1818

1919

@@ -275,3 +275,20 @@ def test_pytorch_Join():
275275
np.c_[[5.0, 6.0]].astype(config.floatX),
276276
],
277277
)
278+
279+
280+
def test_eye():
281+
N = scalar("N", dtype="int64")
282+
M = scalar("M", dtype="int64")
283+
k = scalar("k", dtype="int64")
284+
285+
out = eye(N, M, k, dtype="int16")
286+
287+
trange = range(1, 6)
288+
for _N in trange:
289+
for _M in trange:
290+
for _k in list(range(_M + 2)) + [-x for x in range(1, _N + 2)]:
291+
compare_pytorch_and_py(
292+
FunctionGraph([N, M, k], [out]),
293+
[np.array(_N + 1), np.array(_M + 1), np.array(_k)],
294+
)

0 commit comments

Comments
 (0)