@@ -129,62 +129,6 @@ def triu_indices(n, k=0, m=None):
129
129
return result
130
130
131
131
132
- def diag_indices (n , ndim = 2 ):
133
- idx = torch .arange (n )
134
- return (idx ,) * ndim
135
-
136
-
137
- def diag_indices_from (tensor ):
138
- if not tensor .ndim >= 2 :
139
- raise ValueError ("input array must be at least 2-d" )
140
- # For more than d=2, the strided formula is only valid for arrays with
141
- # all dimensions equal, so we check first.
142
- s = tensor .shape
143
- if s [1 :] != s [:- 1 ]:
144
- raise ValueError ("All dimensions of input must be of equal length" )
145
- return diag_indices (s [0 ], tensor .ndim )
146
-
147
-
148
- def fill_diagonal (tensor , t_val , wrap ):
149
- # torch.Tensor.fill_diagonal_ only accepts scalars. Thus vendor the numpy source,
150
- # https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/index_tricks.py#L786-L917
151
-
152
- if tensor .ndim < 2 :
153
- raise ValueError ("array must be at least 2-d" )
154
- end = None
155
- if tensor .ndim == 2 :
156
- # Explicit, fast formula for the common case. For 2-d arrays, we
157
- # accept rectangular ones.
158
- step = tensor .shape [1 ] + 1
159
- # This is needed to don't have tall matrix have the diagonal wrap.
160
- if not wrap :
161
- end = tensor .shape [1 ] * tensor .shape [1 ]
162
- else :
163
- # For more than d=2, the strided formula is only valid for arrays with
164
- # all dimensions equal, so we check first.
165
- s = tensor .shape
166
- if s [1 :] != s [:- 1 ]:
167
- raise ValueError ("All dimensions of input must be of equal length" )
168
- sz = torch .as_tensor (tensor .shape [:- 1 ])
169
- step = 1 + (torch .cumprod (sz , 0 )).sum ()
170
-
171
- # Write the value out into the diagonal.
172
- tensor .ravel ()[:end :step ] = t_val
173
- return tensor
174
-
175
-
176
- def trace (tensor , offset = 0 , axis1 = 0 , axis2 = 1 , dtype = None , out = None ):
177
- result = torch .diagonal (tensor , offset , dim1 = axis1 , dim2 = axis2 ).sum (- 1 , dtype = dtype )
178
- return result
179
-
180
-
181
- def diagonal (tensor , offset = 0 , axis1 = 0 , axis2 = 1 ):
182
- axis1 = _util .normalize_axis_index (axis1 , tensor .ndim )
183
- axis2 = _util .normalize_axis_index (axis2 , tensor .ndim )
184
- result = torch .diagonal (tensor , offset , axis1 , axis2 )
185
- return result
186
-
187
-
188
132
# ### splits ###
189
133
190
134
@@ -263,14 +207,6 @@ def dsplit(tensor, indices_or_sections):
263
207
return split_helper (tensor , indices_or_sections , 2 , strict = True )
264
208
265
209
266
- def clip (tensor , t_min , t_max ):
267
- if t_min is None and t_max is None :
268
- raise ValueError ("One of max or min must be given" )
269
-
270
- result = tensor .clamp (t_min , t_max )
271
- return result
272
-
273
-
274
210
def diff (a_tensor , n = 1 , axis = - 1 , prepend_tensor = None , append_tensor = None ):
275
211
axis = _util .normalize_axis_index (axis , a_tensor .ndim )
276
212
@@ -364,14 +300,6 @@ def vstack(tensors, *, dtype=None, casting="same_kind"):
364
300
return result
365
301
366
302
367
- def tile (tensor , reps ):
368
- if isinstance (reps , int ):
369
- reps = (reps ,)
370
-
371
- result = torch .tile (tensor , reps )
372
- return result
373
-
374
-
375
303
# #### cov & corrcoef
376
304
377
305
@@ -549,14 +477,6 @@ def arange(start=None, stop=None, step=1, dtype=None):
549
477
# ### empty/full et al ###
550
478
551
479
552
- def eye (N , M = None , k = 0 , dtype = float ):
553
- if M is None :
554
- M = N
555
- z = torch .zeros (N , M , dtype = dtype )
556
- z .diagonal (k ).fill_ (1 )
557
- return z
558
-
559
-
560
480
def zeros (shape , dtype = None , order = "C" ):
561
481
if order != "C" :
562
482
raise NotImplementedError
@@ -637,102 +557,12 @@ def full(shape, fill_value, dtype=None, order="C"):
637
557
# ### shape manipulations ###
638
558
639
559
640
- def roll (tensor , shift , axis = None ):
641
- if axis is not None :
642
- axis = _util .normalize_axis_tuple (axis , tensor .ndim , allow_duplicate = True )
643
- if not isinstance (shift , tuple ):
644
- shift = (shift ,) * len (axis )
645
- result = tensor .roll (shift , axis )
646
- return result
647
-
648
-
649
- def squeeze (tensor , axis = None ):
650
- if axis == ():
651
- result = tensor
652
- elif axis is None :
653
- result = tensor .squeeze ()
654
- else :
655
- if isinstance (axis , tuple ):
656
- result = tensor
657
- for ax in axis :
658
- result = result .squeeze (ax )
659
- else :
660
- result = tensor .squeeze (axis )
661
- return result
662
-
663
-
664
- def reshape (tensor , shape , order = "C" ):
665
- if order != "C" :
666
- raise NotImplementedError
667
- # if sh = (1, 2, 3), numpy allows both .reshape(sh) and .reshape(*sh)
668
- newshape = shape [0 ] if len (shape ) == 1 else shape
669
- result = tensor .reshape (newshape )
670
- return result
671
-
672
-
673
- def transpose (tensor , axes = None ):
674
- # numpy allows both .tranpose(sh) and .transpose(*sh)
675
- if axes in [(), None , (None ,)]:
676
- axes = tuple (range (tensor .ndim ))[::- 1 ]
677
- try :
678
- result = tensor .permute (axes )
679
- except RuntimeError :
680
- raise ValueError ("axes don't match array" )
681
- return result
682
-
683
-
684
- def ravel (tensor , order = "C" ):
685
- if order != "C" :
686
- raise NotImplementedError
687
- result = tensor .ravel ()
688
- return result
689
-
690
-
691
- # leading underscore since arr.flatten exists but np.flatten does not
692
- def _flatten (tensor , order = "C" ):
693
- if order != "C" :
694
- raise NotImplementedError
695
- # return a copy
696
- result = tensor .flatten ()
697
- return result
698
-
699
-
700
560
# ### swap/move/roll axis ###
701
561
702
562
703
- def moveaxis (tensor , source , destination ):
704
- source = _util .normalize_axis_tuple (source , tensor .ndim , "source" )
705
- destination = _util .normalize_axis_tuple (destination , tensor .ndim , "destination" )
706
- result = torch .moveaxis (tensor , source , destination )
707
- return result
708
-
709
-
710
563
# ### Numeric ###
711
564
712
565
713
- def round (tensor , decimals = 0 ):
714
- if tensor .is_floating_point ():
715
- result = torch .round (tensor , decimals = decimals )
716
- elif tensor .is_complex ():
717
- # RuntimeError: "round_cpu" not implemented for 'ComplexFloat'
718
- result = (
719
- torch .round (tensor .real , decimals = decimals )
720
- + torch .round (tensor .imag , decimals = decimals ) * 1j
721
- )
722
- else :
723
- # RuntimeError: "round_cpu" not implemented for 'int'
724
- result = tensor
725
- return result
726
-
727
-
728
- def imag (tensor ):
729
- if tensor .is_complex ():
730
- result = tensor .imag
731
- else :
732
- result = torch .zeros_like (tensor )
733
- return result
734
-
735
-
736
566
# ### put/take along axis ###
737
567
738
568
@@ -753,36 +583,6 @@ def put_along_dim(tensor, t_indices, t_values, axis):
753
583
return result
754
584
755
585
756
- # ### sort and partition ###
757
-
758
-
759
- def _sort_helper (tensor , axis , kind , order ):
760
- if order is not None :
761
- # only relevant for structured dtypes; not supported
762
- raise NotImplementedError (
763
- "'order' keyword is only relevant for structured dtypes"
764
- )
765
-
766
- (tensor ,), axis = _util .axis_none_ravel (tensor , axis = axis )
767
- axis = _util .normalize_axis_index (axis , tensor .ndim )
768
-
769
- stable = kind == "stable"
770
-
771
- return tensor , axis , stable
772
-
773
-
774
- def sort (tensor , axis = - 1 , kind = None , order = None ):
775
- tensor , axis , stable = _sort_helper (tensor , axis , kind , order )
776
- result = torch .sort (tensor , dim = axis , stable = stable )
777
- return result .values
778
-
779
-
780
- def argsort (tensor , axis = - 1 , kind = None , order = None ):
781
- tensor , axis , stable = _sort_helper (tensor , axis , kind , order )
782
- result = torch .argsort (tensor , dim = axis , stable = stable )
783
- return result
784
-
785
-
786
586
# ### logic and selection ###
787
587
788
588
@@ -831,56 +631,6 @@ def inner(t_a, t_b):
831
631
return result
832
632
833
633
834
- def vdot (t_a , t_b , / ):
835
- # 1. torch only accepts 1D arrays, numpy ravels
836
- # 2. torch requires matching dtype, while numpy casts (?)
837
- t_a , t_b = torch .atleast_1d (t_a , t_b )
838
- if t_a .ndim > 1 :
839
- t_a = t_a .ravel ()
840
- if t_b .ndim > 1 :
841
- t_b = t_b .ravel ()
842
-
843
- dtype = _dtypes_impl .result_type_impl ((t_a .dtype , t_b .dtype ))
844
- is_half = dtype == torch .float16
845
- is_bool = dtype == torch .bool
846
-
847
- # work around torch's "dot" not implemented for 'Half', 'Bool'
848
- if is_half :
849
- dtype = torch .float32
850
- if is_bool :
851
- dtype = torch .uint8
852
-
853
- t_a = _util .cast_if_needed (t_a , dtype )
854
- t_b = _util .cast_if_needed (t_b , dtype )
855
-
856
- result = torch .vdot (t_a , t_b )
857
-
858
- if is_half :
859
- result = result .to (torch .float16 )
860
- if is_bool :
861
- result = result .to (torch .bool )
862
-
863
- return result
864
-
865
-
866
- def dot (t_a , t_b ):
867
- dtype = _dtypes_impl .result_type_impl ((t_a .dtype , t_b .dtype ))
868
- t_a = _util .cast_if_needed (t_a , dtype )
869
- t_b = _util .cast_if_needed (t_b , dtype )
870
-
871
- if t_a .ndim == 0 or t_b .ndim == 0 :
872
- result = t_a * t_b
873
- elif t_a .ndim == 1 and t_b .ndim == 1 :
874
- result = torch .dot (t_a , t_b )
875
- elif t_a .ndim == 1 :
876
- result = torch .mv (t_b .T , t_a ).T
877
- elif t_b .ndim == 1 :
878
- result = torch .mv (t_a , t_b )
879
- else :
880
- result = torch .matmul (t_a , t_b )
881
- return result
882
-
883
-
884
634
# ### unique et al ###
885
635
886
636
0 commit comments