Skip to content

Commit b9a8ab5

Browse files
committed
MAINT: split tri* family
1 parent 5031d40 commit b9a8ab5

File tree

2 files changed

+54
-25
lines changed

2 files changed

+54
-25
lines changed

torch_np/_detail/implementations.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def tensor_angle(z, deg=False):
8282

8383
# ### sorting ###
8484

85+
8586
def tensor_argsort(tensor, axis=-1, kind=None, order=None):
8687
if order is not None:
8788
raise NotImplementedError
@@ -91,6 +92,45 @@ def tensor_argsort(tensor, axis=-1, kind=None, order=None):
9192
return torch.argsort(tensor, stable=stable, dim=axis, descending=False)
9293

9394

95+
# ### tri*-something ###
96+
97+
98+
def tri(N, M, k, dtype):
99+
if M is None:
100+
M = N
101+
tensor = torch.ones((N, M), dtype=dtype)
102+
tensor = torch.tril(tensor, diagonal=k)
103+
return tensor
104+
105+
106+
def triu_indices_from(tensor, k):
107+
if tensor.ndim != 2:
108+
raise ValueError("input array must be 2-d")
109+
result = torch.triu_indices(tensor.shape[0], tensor.shape[1], offset=k)
110+
return result
111+
112+
113+
def tril_indices_from(tensor, k=0):
114+
if tensor.ndim != 2:
115+
raise ValueError("input array must be 2-d")
116+
result = torch.tril_indices(tensor.shape[0], tensor.shape[1], offset=k)
117+
return result
118+
119+
120+
def tril_indices(n, k=0, m=None):
121+
if m is None:
122+
m = n
123+
result = torch.tril_indices(n, m, offset=k)
124+
return result
125+
126+
127+
def triu_indices(n, k=0, m=None):
128+
if m is None:
129+
m = n
130+
result = torch.triu_indices(n, m, offset=k)
131+
return result
132+
133+
94134
# ### splits ###
95135

96136

torch_np/_wrapper.py

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,8 @@ def broadcast_to(array, shape, subok=False):
601601
# YYY: pattern: tuple of arrays as input, tuple of arrays as output; cf nonzero
602602
def broadcast_arrays(*args, subok=False):
603603
_util.subok_not_ok(subok=subok)
604-
res = torch.broadcast_tensors(*[asarray(a).get() for a in args])
604+
tensors = _helpers.to_tensors(*args)
605+
res = torch.broadcast_tensors(*tensors)
605606
return tuple(asarray(_) for _ in res)
606607

607608

@@ -706,44 +707,32 @@ def triu(m, k=0):
706707

707708

708709
def tril_indices(n, k=0, m=None):
709-
if m is None:
710-
m = n
711-
tensor_2 = torch.tril_indices(n, m, offset=k)
712-
return tuple(asarray(_) for _ in tensor_2)
710+
result = _impl.tril_indices(n, k, m)
711+
return tuple(asarray(t) for t in result)
713712

714713

715714
def triu_indices(n, k=0, m=None):
716-
if m is None:
717-
m = n
718-
tensor_2 = torch.tril_indices(n, m, offset=k)
719-
return tuple(asarray(_) for _ in tensor_2)
715+
result = _impl.triu_indices(n, k, m)
716+
return tuple(asarray(t) for t in result)
720717

721718

722-
# YYY: pattern: array in, sequence of arrays out
723719
def tril_indices_from(arr, k=0):
724-
arr = asarray(arr).get()
725-
if arr.ndim != 2:
726-
raise ValueError("input array must be 2-d")
727-
tensor_2 = torch.tril_indices(arr.shape[0], arr.shape[1], offset=k)
728-
return tuple(asarray(_) for _ in tensor_2)
720+
tensor = asarray(arr).get()
721+
result = _impl.tril_indices_from(tensor, k)
722+
return tuple(asarray(t) for t in result)
729723

730724

731725
def triu_indices_from(arr, k=0):
732-
arr = asarray(arr).get()
733-
if arr.ndim != 2:
734-
raise ValueError("input array must be 2-d")
735-
tensor_2 = torch.tril_indices(arr.shape[0], arr.shape[1], offset=k)
736-
return tuple(asarray(_) for _ in tensor_2)
726+
tensor = asarray(arr).get()
727+
result = _impl.triu_indices_from(tensor, k)
728+
return tuple(asarray(t) for t in result)
737729

738730

739731
@_decorators.dtype_to_torch
740732
def tri(N, M=None, k=0, dtype=float, *, like=None):
741733
_util.subok_not_ok(like)
742-
if M is None:
743-
M = N
744-
tensor = torch.ones((N, M), dtype=dtype)
745-
tensor = torch.tril(tensor, diagonal=k)
746-
return asarray(tensor)
734+
result = _impl.tri(N, M, k, dtype)
735+
return asarray(result)
747736

748737

749738
###### reductions

0 commit comments

Comments
 (0)