12
12
from ._detail import _dtypes_impl , _flips , _reductions , _util
13
13
from ._detail import implementations as _impl
14
14
from ._ndarray import array , asarray , maybe_set_base , ndarray
15
- from ._normalizations import (
16
- ArrayLike ,
17
- DTypeLike ,
18
- NDArray ,
19
- SubokLike ,
20
- UnpackedSeqArrayLike ,
21
- normalizer ,
22
- )
15
+ from ._normalizations import ArrayLike , DTypeLike , NDArray , SubokLike , normalizer
23
16
24
17
# Things to decide on (punt for now)
25
18
#
@@ -71,30 +64,30 @@ def copy(a: ArrayLike, order="K", subok: SubokLike = False):
71
64
72
65
73
66
@normalizer
74
- def atleast_1d (* arys : UnpackedSeqArrayLike ):
67
+ def atleast_1d (* arys : ArrayLike ):
75
68
res = torch .atleast_1d (* arys )
76
- if len (res ) == 1 :
77
- return _helpers .array_from (res [0 ])
78
- else :
69
+ if isinstance (res , tuple ):
79
70
return list (_helpers .tuple_arrays_from (res ))
71
+ else :
72
+ return _helpers .array_from (res )
80
73
81
74
82
75
@normalizer
83
- def atleast_2d (* arys : UnpackedSeqArrayLike ):
76
+ def atleast_2d (* arys : ArrayLike ):
84
77
res = torch .atleast_2d (* arys )
85
- if len (res ) == 1 :
86
- return _helpers .array_from (res [0 ])
87
- else :
78
+ if isinstance (res , tuple ):
88
79
return list (_helpers .tuple_arrays_from (res ))
80
+ else :
81
+ return _helpers .array_from (res )
89
82
90
83
91
84
@normalizer
92
- def atleast_3d (* arys : UnpackedSeqArrayLike ):
85
+ def atleast_3d (* arys : ArrayLike ):
93
86
res = torch .atleast_3d (* arys )
94
- if len (res ) == 1 :
95
- return _helpers .array_from (res [0 ])
96
- else :
87
+ if isinstance (res , tuple ):
97
88
return list (_helpers .tuple_arrays_from (res ))
89
+ else :
90
+ return _helpers .array_from (res )
98
91
99
92
100
93
def _concat_check (tup , dtype , out ):
@@ -537,8 +530,7 @@ def broadcast_to(array: ArrayLike, shape, subok: SubokLike = False):
537
530
538
531
# YYY: pattern: tuple of arrays as input, tuple of arrays as output; cf nonzero
539
532
@normalizer
540
- def broadcast_arrays (* args : UnpackedSeqArrayLike , subok : SubokLike = False ):
541
- args = args [0 ] # undo the *args wrapping in normalizer
533
+ def broadcast_arrays (* args : ArrayLike , subok : SubokLike = False ):
542
534
res = torch .broadcast_tensors (* args )
543
535
return _helpers .tuple_arrays_from (res )
544
536
@@ -565,8 +557,7 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
565
557
566
558
567
559
@normalizer
568
- def meshgrid (* xi : UnpackedSeqArrayLike , copy = True , sparse = False , indexing = "xy" ):
569
- xi = xi [0 ] # undo the *xi wrapping in normalizer
560
+ def meshgrid (* xi : ArrayLike , copy = True , sparse = False , indexing = "xy" ):
570
561
output = _impl .meshgrid (* xi , copy = copy , sparse = sparse , indexing = indexing )
571
562
outp = _helpers .tuple_arrays_from (output )
572
563
return list (outp ) # match numpy, return a list
0 commit comments