17
17
DTypeLike ,
18
18
NDArray ,
19
19
SubokLike ,
20
- UnpackedSeqArrayLike ,
21
20
normalizer ,
22
21
)
23
22
@@ -71,30 +70,30 @@ def copy(a: ArrayLike, order="K", subok: SubokLike = False):
71
70
72
71
73
72
@normalizer
74
- def atleast_1d (* arys : UnpackedSeqArrayLike ):
73
+ def atleast_1d (* arys : ArrayLike ):
75
74
res = torch .atleast_1d (* arys )
76
- if len (res ) == 1 :
77
- return _helpers .array_from (res [0 ])
78
- else :
75
+ if isinstance (res , tuple ):
79
76
return list (_helpers .tuple_arrays_from (res ))
77
+ else :
78
+ return _helpers .array_from (res )
80
79
81
80
82
81
@normalizer
83
- def atleast_2d (* arys : UnpackedSeqArrayLike ):
82
+ def atleast_2d (* arys : ArrayLike ):
84
83
res = torch .atleast_2d (* arys )
85
- if len (res ) == 1 :
86
- return _helpers .array_from (res [0 ])
87
- else :
84
+ if isinstance (res , tuple ):
88
85
return list (_helpers .tuple_arrays_from (res ))
86
+ else :
87
+ return _helpers .array_from (res )
89
88
90
89
91
90
@normalizer
92
- def atleast_3d (* arys : UnpackedSeqArrayLike ):
91
+ def atleast_3d (* arys : ArrayLike ):
93
92
res = torch .atleast_3d (* arys )
94
- if len (res ) == 1 :
95
- return _helpers .array_from (res [0 ])
96
- else :
93
+ if isinstance (res , tuple ):
97
94
return list (_helpers .tuple_arrays_from (res ))
95
+ else :
96
+ return _helpers .array_from (res )
98
97
99
98
100
99
def _concat_check (tup , dtype , out ):
@@ -537,8 +536,7 @@ def broadcast_to(array: ArrayLike, shape, subok: SubokLike = False):
537
536
538
537
# YYY: pattern: tuple of arrays as input, tuple of arrays as output; cf nonzero
539
538
@normalizer
540
- def broadcast_arrays (* args : UnpackedSeqArrayLike , subok : SubokLike = False ):
541
- args = args [0 ] # undo the *args wrapping in normalizer
539
+ def broadcast_arrays (* args : ArrayLike , subok : SubokLike = False ):
542
540
res = torch .broadcast_tensors (* args )
543
541
return _helpers .tuple_arrays_from (res )
544
542
@@ -565,8 +563,7 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
565
563
566
564
567
565
@normalizer
568
- def meshgrid (* xi : UnpackedSeqArrayLike , copy = True , sparse = False , indexing = "xy" ):
569
- xi = xi [0 ] # undo the *xi wrapping in normalizer
566
+ def meshgrid (* xi : ArrayLike , copy = True , sparse = False , indexing = "xy" ):
570
567
output = _impl .meshgrid (* xi , copy = copy , sparse = sparse , indexing = indexing )
571
568
outp = _helpers .tuple_arrays_from (output )
572
569
return list (outp ) # match numpy, return a list
0 commit comments