Skip to content

Commit 5024b04

Browse files
committed
Renamed vars in implemeantion of Eye Op in PyTorch
1 parent f256146 commit 5024b04

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,15 +104,15 @@ def join(axis, *tensors):
104104

105105
@pytorch_funcify.register(Eye)
106106
def pytorch_funcify_eye(op, **kwargs):
107-
dtype = getattr(torch, op.dtype)
107+
torch_dtype = getattr(torch, op.dtype)
108108

109109
def eye(N, M, k):
110-
mjr, mnr = (M, N) if k > 0 else (N, M)
110+
major, minor = (M, N) if k > 0 else (N, M)
111111
k_abs = torch.abs(k)
112-
zeros = torch.zeros(N, M, dtype=dtype)
113-
if k_abs < mjr:
114-
l_ones = torch.min(mjr - k_abs, mnr)
115-
return zeros.diagonal_scatter(torch.ones(l_ones, dtype=dtype), k)
112+
zeros = torch.zeros(N, M, dtype=torch_dtype)
113+
if k_abs < major:
114+
l_ones = torch.min(major - k_abs, minor)
115+
return zeros.diagonal_scatter(torch.ones(l_ones, dtype=torch_dtype), k)
116116
return zeros
117117

118118
return eye

tests/link/pytorch/test_basic.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -284,11 +284,9 @@ def test_eye():
284284

285285
out = eye(N, M, k, dtype="float32")
286286

287-
trange = range(1, 6)
288-
289287
fn = function([N, M, k], out, mode=pytorch_mode)
290288

291-
for _N in trange:
292-
for _M in trange:
289+
for _N in range(1, 6):
290+
for _M in range(1, 6):
293291
for _k in list(range(_M + 2)) + [-x for x in range(1, _N + 2)]:
294292
np.testing.assert_array_equal(fn(_N, _M, _k), np.eye(_N, _M, _k))

0 commit comments

Comments
 (0)