Skip to content

Commit 963fb8c

Browse files
committed
MAINT: split tri* family
1 parent a429352 commit 963fb8c

File tree

2 files changed

+52
-24
lines changed

2 files changed

+52
-24
lines changed

torch_np/_detail/implementations.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def tensor_angle(z, deg=False):
8282

8383
# ### sorting ###
8484

85+
8586
def tensor_argsort(tensor, axis=-1, kind=None, order=None):
8687
if order is not None:
8788
raise NotImplementedError
@@ -91,6 +92,45 @@ def tensor_argsort(tensor, axis=-1, kind=None, order=None):
9192
return torch.argsort(tensor, stable=stable, dim=axis, descending=False)
9293

9394

95+
# ### tri*-something ###
96+
97+
98+
def tri(N, M, k, dtype):
99+
if M is None:
100+
M = N
101+
tensor = torch.ones((N, M), dtype=dtype)
102+
tensor = torch.tril(tensor, diagonal=k)
103+
return tensor
104+
105+
106+
def triu_indices_from(tensor, k):
107+
if tensor.ndim != 2:
108+
raise ValueError("input array must be 2-d")
109+
result = torch.triu_indices(tensor.shape[0], arr.shape[1], offset=k)
110+
return result
111+
112+
113+
def tril_indices_from(tensor, k=0):
114+
if tensor.ndim != 2:
115+
raise ValueError("input array must be 2-d")
116+
result = torch.tril_indices(tensor.shape[0], tensor.shape[1], offset=k)
117+
return result
118+
119+
120+
def tril_indices(n, k=0, m=None):
121+
if m is None:
122+
m = n
123+
result = torch.tril_indices(n, m, offset=k)
124+
return result
125+
126+
127+
def triu_indices(n, k=0, m=None):
128+
if m is None:
129+
m = n
130+
result = torch.triu_indices(n, m, offset=k)
131+
return result
132+
133+
94134
# ### splits ###
95135

96136

torch_np/_wrapper.py

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -706,44 +706,32 @@ def triu(m, k=0):
706706

707707

708708
def tril_indices(n, k=0, m=None):
709-
if m is None:
710-
m = n
711-
tensor_2 = torch.tril_indices(n, m, offset=k)
712-
return tuple(asarray(_) for _ in tensor_2)
709+
result = _impl.tril_indices(n, k, m)
710+
return tuple(asarray(t) for t in result)
713711

714712

715713
def triu_indices(n, k=0, m=None):
716-
if m is None:
717-
m = n
718-
tensor_2 = torch.tril_indices(n, m, offset=k)
719-
return tuple(asarray(_) for _ in tensor_2)
714+
result = _impl.triu_indices(n, k, m)
715+
return tuple(asarray(t) for t in result)
720716

721717

722-
# YYY: pattern: array in, sequence of arrays out
723718
def tril_indices_from(arr, k=0):
724-
arr = asarray(arr).get()
725-
if arr.ndim != 2:
726-
raise ValueError("input array must be 2-d")
727-
tensor_2 = torch.tril_indices(arr.shape[0], arr.shape[1], offset=k)
728-
return tuple(asarray(_) for _ in tensor_2)
719+
tensor = asarray(arr).get()
720+
result = _impl.tril_indices_from(tensor, k)
721+
return tuple(asarray(t) for t in result)
729722

730723

731724
def triu_indices_from(arr, k=0):
732-
arr = asarray(arr).get()
733-
if arr.ndim != 2:
734-
raise ValueError("input array must be 2-d")
735-
tensor_2 = torch.tril_indices(arr.shape[0], arr.shape[1], offset=k)
736-
return tuple(asarray(_) for _ in tensor_2)
725+
tensor = asarray(arr).get()
726+
result = _impl.triu_indices_from(tensor, k)
727+
return tuple(asarray(t) for t in result)
737728

738729

739730
@_decorators.dtype_to_torch
740731
def tri(N, M=None, k=0, dtype=float, *, like=None):
741732
_util.subok_not_ok(like)
742-
if M is None:
743-
M = N
744-
tensor = torch.ones((N, M), dtype=dtype)
745-
tensor = torch.tril(tensor, diagonal=k)
746-
return asarray(tensor)
733+
result = _impl.tri(N, M, k, dtype)
734+
return asarray(result)
747735

748736

749737
###### reductions

0 commit comments

Comments
 (0)