Skip to content

Commit 4ea96b2

Browse files
authored
Implemented Eye Op in PyTorch (#877)
1 parent ca10298 commit 4ea96b2

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

pytensor/link/pytorch/dispatch/basic.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pytensor.graph.fg import FunctionGraph
77
from pytensor.link.utils import fgraph_to_python
88
from pytensor.raise_op import CheckAndRaise
9-
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Join
9+
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join
1010

1111

1212
@singledispatch
@@ -100,3 +100,19 @@ def join(axis, *tensors):
100100
return torch.cat(tensors, dim=axis)
101101

102102
return join
103+
104+
105+
@pytorch_funcify.register(Eye)
106+
def pytorch_funcify_eye(op, **kwargs):
107+
torch_dtype = getattr(torch, op.dtype)
108+
109+
def eye(N, M, k):
110+
major, minor = (M, N) if k > 0 else (N, M)
111+
k_abs = torch.abs(k)
112+
zeros = torch.zeros(N, M, dtype=torch_dtype)
113+
if k_abs < major:
114+
l_ones = torch.min(major - k_abs, minor)
115+
return zeros.diagonal_scatter(torch.ones(l_ones, dtype=torch_dtype), k)
116+
return zeros
117+
118+
return eye

tests/link/pytorch/test_basic.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from pytensor.graph.fg import FunctionGraph
1414
from pytensor.graph.op import Op
1515
from pytensor.raise_op import CheckAndRaise
16-
from pytensor.tensor import alloc, arange, as_tensor, empty
16+
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
1717
from pytensor.tensor.type import matrix, scalar, vector
1818

1919

@@ -275,3 +275,22 @@ def test_pytorch_Join():
275275
np.c_[[5.0, 6.0]].astype(config.floatX),
276276
],
277277
)
278+
279+
280+
@pytest.mark.parametrize(
281+
"dtype",
282+
["int64", config.floatX],
283+
)
284+
def test_eye(dtype):
285+
N = scalar("N", dtype="int64")
286+
M = scalar("M", dtype="int64")
287+
k = scalar("k", dtype="int64")
288+
289+
out = eye(N, M, k, dtype=dtype)
290+
291+
fn = function([N, M, k], out, mode=pytorch_mode)
292+
293+
for _N in range(1, 6):
294+
for _M in range(1, 6):
295+
for _k in list(range(_M + 2)) + [-x for x in range(1, _N + 2)]:
296+
np.testing.assert_array_equal(fn(_N, _M, _k), np.eye(_N, _M, _k))

0 commit comments

Comments
 (0)