From bf5d29ae845bd718bd70babba19c06721d59c152 Mon Sep 17 00:00:00 2001 From: Diego Sandoval Date: Tue, 2 Jul 2024 19:46:09 +0000 Subject: [PATCH 1/6] Implemented Eye Op in PyTorch - Added support for diagonal offset (param `k`) --- pytensor/link/pytorch/dispatch/basic.py | 18 +++++++++++++++++- tests/link/pytorch/test_basic.py | 19 ++++++++++++++++++- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 0f5c1b2fe0..1524d966d2 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -6,7 +6,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.link.utils import fgraph_to_python from pytensor.raise_op import CheckAndRaise -from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Join +from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join @singledispatch @@ -100,3 +100,19 @@ def join(axis, *tensors): return torch.cat(tensors, dim=axis) return join + + +@pytorch_funcify.register(Eye) +def pytorch_funcify_eye(op, **kwargs): + dtype = getattr(torch, op.dtype) + + def eye(N, M, k): + mjr, mnr = (M, N) if k > 0 else (N, M) + k_abs = abs(k) + zeros = torch.zeros(N, M, dtype=dtype) + if k_abs < mjr: + l_ones = min(mjr - k_abs, mnr) + return zeros.diagonal_scatter(torch.ones(l_ones, dtype=dtype), k) + return zeros + + return eye diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index c6750361a7..bd5c8615a3 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -13,7 +13,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op from pytensor.raise_op import CheckAndRaise -from pytensor.tensor import alloc, arange, as_tensor, empty +from pytensor.tensor import alloc, arange, as_tensor, empty, eye from pytensor.tensor.type import matrix, scalar, vector @@ -275,3 +275,20 @@ def test_pytorch_Join(): np.c_[[5.0, 6.0]].astype(config.floatX), ], ) + + +def test_eye(): + N = scalar("N", dtype="int64") + M = scalar("M", dtype="int64") + k = scalar("k", dtype="int64") + + out = eye(N, M, k, dtype="int16") + + trange = range(1, 6) + for _N in trange: + for _M in trange: + for _k in list(range(_M + 2)) + [-x for x in range(1, _N + 2)]: + compare_pytorch_and_py( + FunctionGraph([N, M, k], [out]), + [np.array(_N + 1), np.array(_M + 1), np.array(_k)], + ) From 04d4ac718fd9fc76fcf7ead8e51d608eaccbac56 Mon Sep 17 00:00:00 2001 From: Diego Sandoval Date: Tue, 2 Jul 2024 20:06:44 +0000 Subject: [PATCH 2/6] Updated PyTorch Eye Op tests --- tests/link/pytorch/test_basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index bd5c8615a3..bf553c7816 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -290,5 +290,5 @@ def test_eye(): for _k in list(range(_M + 2)) + [-x for x in range(1, _N + 2)]: compare_pytorch_and_py( FunctionGraph([N, M, k], [out]), - [np.array(_N + 1), np.array(_M + 1), np.array(_k)], + [np.array(_N), np.array(_M), np.array(_k)], ) From 8b166f3113aeb5faceea53f786ff9f4b66a8758e Mon Sep 17 00:00:00 2001 From: Diego Sandoval Date: Wed, 3 Jul 2024 11:28:54 +0000 Subject: [PATCH 3/6] Replaced Abs and Min functions in Eye Op in PyTorch --- pytensor/link/pytorch/dispatch/basic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 1524d966d2..250e8fe53c 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -108,10 +108,10 @@ def pytorch_funcify_eye(op, **kwargs): def eye(N, M, k): mjr, mnr = (M, N) if k > 0 else (N, M) - k_abs = abs(k) + k_abs = torch.abs(k) zeros = torch.zeros(N, M, dtype=dtype) if k_abs < mjr: - l_ones = min(mjr - k_abs, mnr) + l_ones = torch.min(mjr - k_abs, mnr) return zeros.diagonal_scatter(torch.ones(l_ones, dtype=dtype), k) return zeros From 7628331dd0e3f2c5bff280977bcc2b34b7e4461a Mon Sep 17 00:00:00 2001 From: Diego Sandoval Date: Thu, 4 Jul 2024 10:54:52 +0000 Subject: [PATCH 4/6] Refactor tests for Eye Op in PyTorch --- tests/link/pytorch/test_basic.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index bf553c7816..7ae9ad0d11 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -282,13 +282,13 @@ def test_eye(): M = scalar("M", dtype="int64") k = scalar("k", dtype="int64") - out = eye(N, M, k, dtype="int16") + out = eye(N, M, k, dtype="float32") trange = range(1, 6) + + fn = function([N, M, k], out, mode=pytorch_mode) + for _N in trange: for _M in trange: for _k in list(range(_M + 2)) + [-x for x in range(1, _N + 2)]: - compare_pytorch_and_py( - FunctionGraph([N, M, k], [out]), - [np.array(_N), np.array(_M), np.array(_k)], - ) + np.testing.assert_array_equal(fn(_N, _M, _k), np.eye(_N, _M, _k)) From daa86c4cde380e4e31eda9d043c80bee1d32ce95 Mon Sep 17 00:00:00 2001 From: Diego Sandoval Date: Thu, 4 Jul 2024 19:01:44 +0000 Subject: [PATCH 5/6] Renamed vars in implemeantion of Eye Op in PyTorch --- pytensor/link/pytorch/dispatch/basic.py | 12 ++++++------ tests/link/pytorch/test_basic.py | 6 ++---- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 250e8fe53c..37622a8294 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -104,15 +104,15 @@ def join(axis, *tensors): @pytorch_funcify.register(Eye) def pytorch_funcify_eye(op, **kwargs): - dtype = getattr(torch, op.dtype) + torch_dtype = getattr(torch, op.dtype) def eye(N, M, k): - mjr, mnr = (M, N) if k > 0 else (N, M) + major, minor = (M, N) if k > 0 else (N, M) k_abs = torch.abs(k) - zeros = torch.zeros(N, M, dtype=dtype) - if k_abs < mjr: - l_ones = torch.min(mjr - k_abs, mnr) - return zeros.diagonal_scatter(torch.ones(l_ones, dtype=dtype), k) + zeros = torch.zeros(N, M, dtype=torch_dtype) + if k_abs < major: + l_ones = torch.min(major - k_abs, minor) + return zeros.diagonal_scatter(torch.ones(l_ones, dtype=torch_dtype), k) return zeros return eye diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 7ae9ad0d11..91dd3cc350 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -284,11 +284,9 @@ def test_eye(): out = eye(N, M, k, dtype="float32") - trange = range(1, 6) - fn = function([N, M, k], out, mode=pytorch_mode) - for _N in trange: - for _M in trange: + for _N in range(1, 6): + for _M in range(1, 6): for _k in list(range(_M + 2)) + [-x for x in range(1, _N + 2)]: np.testing.assert_array_equal(fn(_N, _M, _k), np.eye(_N, _M, _k)) From 08d23ab8e0b435b4e82d9404a04a4cc98449f0b2 Mon Sep 17 00:00:00 2001 From: Diego Sandoval Date: Sun, 7 Jul 2024 09:58:04 +0200 Subject: [PATCH 6/6] Parametrized dtype in tests for Eye Op in PyTorch --- tests/link/pytorch/test_basic.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 91dd3cc350..0ccb1c454f 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -277,12 +277,16 @@ def test_pytorch_Join(): ) -def test_eye(): +@pytest.mark.parametrize( + "dtype", + ["int64", config.floatX], +) +def test_eye(dtype): N = scalar("N", dtype="int64") M = scalar("M", dtype="int64") k = scalar("k", dtype="int64") - out = eye(N, M, k, dtype="float32") + out = eye(N, M, k, dtype=dtype) fn = function([N, M, k], out, mode=pytorch_mode)