@@ -601,7 +601,8 @@ def broadcast_to(array, shape, subok=False):
601
601
# YYY: pattern: tuple of arrays as input, tuple of arrays as output; cf nonzero
602
602
def broadcast_arrays (* args , subok = False ):
603
603
_util .subok_not_ok (subok = subok )
604
- res = torch .broadcast_tensors (* [asarray (a ).get () for a in args ])
604
+ tensors = _helpers .to_tensors (* args )
605
+ res = torch .broadcast_tensors (* tensors )
605
606
return tuple (asarray (_ ) for _ in res )
606
607
607
608
@@ -706,44 +707,32 @@ def triu(m, k=0):
706
707
707
708
708
709
def tril_indices (n , k = 0 , m = None ):
709
- if m is None :
710
- m = n
711
- tensor_2 = torch .tril_indices (n , m , offset = k )
712
- return tuple (asarray (_ ) for _ in tensor_2 )
710
+ result = _impl .tril_indices (n , k , m )
711
+ return tuple (asarray (t ) for t in result )
713
712
714
713
715
714
def triu_indices (n , k = 0 , m = None ):
716
- if m is None :
717
- m = n
718
- tensor_2 = torch .tril_indices (n , m , offset = k )
719
- return tuple (asarray (_ ) for _ in tensor_2 )
715
+ result = _impl .triu_indices (n , k , m )
716
+ return tuple (asarray (t ) for t in result )
720
717
721
718
722
- # YYY: pattern: array in, sequence of arrays out
723
719
def tril_indices_from (arr , k = 0 ):
724
- arr = asarray (arr ).get ()
725
- if arr .ndim != 2 :
726
- raise ValueError ("input array must be 2-d" )
727
- tensor_2 = torch .tril_indices (arr .shape [0 ], arr .shape [1 ], offset = k )
728
- return tuple (asarray (_ ) for _ in tensor_2 )
720
+ tensor = asarray (arr ).get ()
721
+ result = _impl .tril_indices_from (tensor , k )
722
+ return tuple (asarray (t ) for t in result )
729
723
730
724
731
725
def triu_indices_from (arr , k = 0 ):
732
- arr = asarray (arr ).get ()
733
- if arr .ndim != 2 :
734
- raise ValueError ("input array must be 2-d" )
735
- tensor_2 = torch .tril_indices (arr .shape [0 ], arr .shape [1 ], offset = k )
736
- return tuple (asarray (_ ) for _ in tensor_2 )
726
+ tensor = asarray (arr ).get ()
727
+ result = _impl .triu_indices_from (tensor , k )
728
+ return tuple (asarray (t ) for t in result )
737
729
738
730
739
731
@_decorators .dtype_to_torch
740
732
def tri (N , M = None , k = 0 , dtype = float , * , like = None ):
741
733
_util .subok_not_ok (like )
742
- if M is None :
743
- M = N
744
- tensor = torch .ones ((N , M ), dtype = dtype )
745
- tensor = torch .tril (tensor , diagonal = k )
746
- return asarray (tensor )
734
+ result = _impl .tri (N , M , k , dtype )
735
+ return asarray (result )
747
736
748
737
749
738
###### reductions
0 commit comments