17
17
NDArray ,
18
18
SubokLike ,
19
19
UnpackedSeqArrayLike ,
20
+ NDArrayOrSequence ,
20
21
normalizer ,
21
22
)
22
23
@@ -69,30 +70,21 @@ def copy(a: ArrayLike, order="K", subok: SubokLike = False) -> NDArray:
69
70
70
71
71
72
@normalizer
72
- def atleast_1d (* arys : UnpackedSeqArrayLike ):
73
- res = torch .atleast_1d (* arys )
74
- if len (res ) == 1 :
75
- return _helpers .array_from (res [0 ])
76
- else :
77
- return list (_helpers .tuple_arrays_from (res ))
73
+ def atleast_1d (* arys : UnpackedSeqArrayLike ) -> NDArrayOrSequence :
74
+ result = _impl .atleast_1d (* arys )
75
+ return result
78
76
79
77
80
78
@normalizer
81
- def atleast_2d (* arys : UnpackedSeqArrayLike ):
82
- res = torch .atleast_2d (* arys )
83
- if len (res ) == 1 :
84
- return _helpers .array_from (res [0 ])
85
- else :
86
- return list (_helpers .tuple_arrays_from (res ))
79
+ def atleast_2d (* arys : UnpackedSeqArrayLike ) -> NDArrayOrSequence :
80
+ result = _impl .atleast_2d (* arys )
81
+ return result
87
82
88
83
89
84
@normalizer
90
- def atleast_3d (* arys : UnpackedSeqArrayLike ):
91
- res = torch .atleast_3d (* arys )
92
- if len (res ) == 1 :
93
- return _helpers .array_from (res [0 ])
94
- else :
95
- return list (_helpers .tuple_arrays_from (res ))
85
+ def atleast_3d (* arys : UnpackedSeqArrayLike ) -> NDArrayOrSequence :
86
+ result = _impl .atleast_3d (* arys )
87
+ return result
96
88
97
89
98
90
def _concat_check (tup , dtype , out ):
@@ -175,33 +167,33 @@ def stack(
175
167
176
168
177
169
@normalizer
178
- def array_split (ary : ArrayLike , indices_or_sections , axis = 0 ):
170
+ def array_split (ary : ArrayLike , indices_or_sections , axis = 0 ) -> tuple [ NDArray ] :
179
171
result = _impl .split_helper (ary , indices_or_sections , axis )
180
- return _helpers . tuple_arrays_from ( result )
172
+ return result
181
173
182
174
183
175
@normalizer
184
- def split (ary : ArrayLike , indices_or_sections , axis = 0 ):
176
+ def split (ary : ArrayLike , indices_or_sections , axis = 0 ) -> tuple [ NDArray ] :
185
177
result = _impl .split_helper (ary , indices_or_sections , axis , strict = True )
186
- return _helpers . tuple_arrays_from ( result )
178
+ return result
187
179
188
180
189
181
@normalizer
190
- def hsplit (ary : ArrayLike , indices_or_sections ):
182
+ def hsplit (ary : ArrayLike , indices_or_sections ) -> tuple [ NDArray ] :
191
183
result = _impl .hsplit (ary , indices_or_sections )
192
- return _helpers . tuple_arrays_from ( result )
184
+ return result
193
185
194
186
195
187
@normalizer
196
- def vsplit (ary : ArrayLike , indices_or_sections ):
188
+ def vsplit (ary : ArrayLike , indices_or_sections ) -> tuple [ NDArray ] :
197
189
result = _impl .vsplit (ary , indices_or_sections )
198
- return _helpers . tuple_arrays_from ( result )
190
+ return result
199
191
200
192
201
193
@normalizer
202
- def dsplit (ary : ArrayLike , indices_or_sections ):
194
+ def dsplit (ary : ArrayLike , indices_or_sections ) -> tuple [ NDArray ] :
203
195
result = _impl .dsplit (ary , indices_or_sections )
204
- return _helpers . tuple_arrays_from ( result )
196
+ return result
205
197
206
198
207
199
@normalizer
@@ -421,13 +413,9 @@ def where(
421
413
x : Optional [ArrayLike ] = None ,
422
414
y : Optional [ArrayLike ] = None ,
423
415
/ ,
424
- ):
416
+ ) -> NDArrayOrSequence :
425
417
result = _impl .where (condition , x , y )
426
- if isinstance (result , tuple ):
427
- # single-argument where(condition)
428
- return _helpers .tuple_arrays_from (result )
429
- else :
430
- return _helpers .array_from (result )
418
+ return result
431
419
432
420
433
421
###### module-level queries of object properties
@@ -496,10 +484,12 @@ def broadcast_to(array: ArrayLike, shape, subok: SubokLike = False) -> NDArray:
496
484
497
485
# YYY: pattern: tuple of arrays as input, tuple of arrays as output; cf nonzero
498
486
@normalizer
499
- def broadcast_arrays (* args : UnpackedSeqArrayLike , subok : SubokLike = False ):
487
+ def broadcast_arrays (
488
+ * args : UnpackedSeqArrayLike , subok : SubokLike = False
489
+ ) -> tuple [NDArray ]:
500
490
args = args [0 ] # undo the *args wrapping in normalizer
501
491
res = torch .broadcast_tensors (* args )
502
- return _helpers . tuple_arrays_from ( res )
492
+ return res
503
493
504
494
505
495
def unravel_index (indices , shape , order = "C" ):
@@ -524,11 +514,12 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
524
514
525
515
526
516
@normalizer
527
- def meshgrid (* xi : UnpackedSeqArrayLike , copy = True , sparse = False , indexing = "xy" ):
517
+ def meshgrid (
518
+ * xi : UnpackedSeqArrayLike , copy = True , sparse = False , indexing = "xy"
519
+ ) -> list [NDArray ]:
528
520
xi = xi [0 ] # undo the *xi wrapping in normalizer
529
521
output = _impl .meshgrid (* xi , copy = copy , sparse = sparse , indexing = indexing )
530
- outp = _helpers .tuple_arrays_from (output )
531
- return list (outp ) # match numpy, return a list
522
+ return output
532
523
533
524
534
525
@normalizer
@@ -559,26 +550,26 @@ def triu(m: ArrayLike, k=0) -> NDArray:
559
550
return result
560
551
561
552
562
- def tril_indices (n , k = 0 , m = None ):
553
+ def tril_indices (n , k = 0 , m = None ) -> tuple [ NDArray ] :
563
554
result = _impl .tril_indices (n , k , m )
564
- return _helpers . tuple_arrays_from ( result )
555
+ return result
565
556
566
557
567
- def triu_indices (n , k = 0 , m = None ):
558
+ def triu_indices (n , k = 0 , m = None ) -> tuple [ NDArray ] :
568
559
result = _impl .triu_indices (n , k , m )
569
- return _helpers . tuple_arrays_from ( result )
560
+ return result
570
561
571
562
572
563
@normalizer
573
- def tril_indices_from (arr : ArrayLike , k = 0 ):
564
+ def tril_indices_from (arr : ArrayLike , k = 0 ) -> tuple [ NDArray ] :
574
565
result = _impl .tril_indices_from (arr , k )
575
- return _helpers . tuple_arrays_from ( result )
566
+ return result
576
567
577
568
578
569
@normalizer
579
- def triu_indices_from (arr : ArrayLike , k = 0 ):
570
+ def triu_indices_from (arr : ArrayLike , k = 0 ) -> tuple [ NDArray ] :
580
571
result = _impl .triu_indices_from (arr , k )
581
- return _helpers . tuple_arrays_from ( result )
572
+ return result
582
573
583
574
584
575
@normalizer
@@ -859,7 +850,7 @@ def unique(
859
850
axis = None ,
860
851
* ,
861
852
equal_nan = True ,
862
- ):
853
+ ) -> NDArrayOrSequence :
863
854
result = _impl .unique (
864
855
ar ,
865
856
return_index = return_index ,
@@ -868,11 +859,7 @@ def unique(
868
859
axis = axis ,
869
860
equal_nan = equal_nan ,
870
861
)
871
-
872
- if isinstance (result , tuple ):
873
- return _helpers .tuple_arrays_from (result )
874
- else :
875
- return _helpers .array_from (result )
862
+ return result
876
863
877
864
878
865
###### mapping from numpy API objects to wrappers from this module ######
0 commit comments