@@ -74,89 +74,103 @@ def copy(a, order="K", subok=False):
74
74
75
75
76
76
def atleast_1d (* arys ):
77
- res = torch .atleast_1d ([asarray (a ).get () for a in arys ])
77
+ tensors = _helpers .to_tensors (* arys )
78
+ res = torch .atleast_1d (tensors )
78
79
if len (res ) == 1 :
79
80
return asarray (res [0 ])
80
81
else :
81
82
return list (asarray (_ ) for _ in res )
82
83
83
84
84
85
def atleast_2d (* arys ):
85
- res = torch .atleast_2d ([asarray (a ).get () for a in arys ])
86
+ tensors = _helpers .to_tensors (* arys )
87
+ res = torch .atleast_2d (tensors )
86
88
if len (res ) == 1 :
87
89
return asarray (res [0 ])
88
90
else :
89
91
return list (asarray (_ ) for _ in res )
90
92
91
93
92
94
def atleast_3d (* arys ):
93
- res = torch .atleast_3d ([asarray (a ).get () for a in arys ])
95
+ tensors = _helpers .to_tensors (* arys )
96
+ res = torch .atleast_3d (tensors )
94
97
if len (res ) == 1 :
95
98
return asarray (res [0 ])
96
99
else :
97
100
return list (asarray (_ ) for _ in res )
98
101
99
102
103
+ def _concat_check (tup , dtype , out ):
104
+ """Check inputs in concatenate et al."""
105
+ if tup == ():
106
+ # XXX: RuntimeError in torch, ValueError in numpy
107
+ raise ValueError ("need at least one array to concatenate" )
108
+
109
+ if out is not None :
110
+ if not isinstance (out , ndarray ):
111
+ raise ValueError ("'out' must be an array" )
112
+
113
+ if dtype is not None :
114
+ # mimic numpy
115
+ raise TypeError (
116
+ "concatenate() only takes `out` or `dtype` as an "
117
+ "argument, but both were provided."
118
+ )
119
+
120
+
121
+ @_decorators .dtype_to_torch
122
+ def concatenate (ar_tuple , axis = 0 , out = None , dtype = None , casting = "same_kind" ):
123
+ _concat_check (ar_tuple , dtype , out = out )
124
+ tensors = _helpers .to_tensors (* ar_tuple )
125
+ result = _impl .concatenate (tensors , axis , out , dtype , casting )
126
+ return _helpers .result_or_out (result , out )
127
+
128
+
129
+ @_decorators .dtype_to_torch
100
130
def vstack (tup , * , dtype = None , casting = "same_kind" ):
101
- arrs = atleast_2d (* tup )
102
- if not isinstance ( arrs , list ):
103
- arrs = [ arrs ]
104
- return concatenate ( arrs , 0 , dtype = dtype , casting = casting )
131
+ tensors = _helpers . to_tensors (* tup )
132
+ _concat_check ( tensors , dtype , out = None )
133
+ result = _impl . vstack ( tensors , dtype = dtype , casting = casting )
134
+ return asarray ( result )
105
135
106
136
107
137
row_stack = vstack
108
138
109
139
140
+ @_decorators .dtype_to_torch
110
141
def hstack (tup , * , dtype = None , casting = "same_kind" ):
111
- arrs = atleast_1d (* tup )
112
- if not isinstance (arrs , list ):
113
- arrs = [arrs ]
114
- # As a special case, dimension 0 of 1-dimensional arrays is "horizontal"
115
- if arrs and arrs [0 ].ndim == 1 :
116
- return concatenate (arrs , 0 , dtype = dtype , casting = casting )
117
- else :
118
- return concatenate (arrs , 1 , dtype = dtype , casting = casting )
142
+ tensors = _helpers .to_tensors (* tup )
143
+ _concat_check (tensors , dtype , out = None )
144
+ result = _impl .hstack (tensors , dtype = dtype , casting = casting )
145
+ return asarray (result )
119
146
120
147
148
+ @_decorators .dtype_to_torch
121
149
def dstack (tup , * , dtype = None , casting = "same_kind" ):
122
150
# XXX: in numpy 1.24 dstack does not have dtype and casting keywords
123
151
# but {h,v}stack do. Hence add them here for consistency.
124
- arrs = atleast_3d (* tup )
125
- if not isinstance (arrs , list ):
126
- arrs = [arrs ]
127
- return concatenate (arrs , 2 , dtype = dtype , casting = casting )
152
+ tensors = _helpers .to_tensors (* tup )
153
+ result = _impl .dstack (tensors , dtype = dtype , casting = casting )
154
+ return asarray (result )
128
155
129
156
157
+ @_decorators .dtype_to_torch
130
158
def column_stack (tup , * , dtype = None , casting = "same_kind" ):
131
159
# XXX: in numpy 1.24 column_stack does not have dtype and casting keywords
132
160
# but row_stack does. (because row_stack is an alias for vstack, really).
133
161
# Hence add these keywords here for consistency.
134
- arrays = []
135
- for v in tup :
136
- arr = asarray (v )
137
- if arr .ndim < 2 :
138
- arr = array (arr , copy = False , ndmin = 2 ).T
139
- arrays .append (arr )
140
- return concatenate (arrays , 1 , dtype = dtype , casting = casting )
162
+ tensors = _helpers .to_tensors (* tup )
163
+ _concat_check (tensors , dtype , out = None )
164
+ result = _impl .column_stack (tensors , dtype = dtype , casting = casting )
165
+ return asarray (result )
141
166
142
167
168
+ @_decorators .dtype_to_torch
143
169
def stack (arrays , axis = 0 , out = None , * , dtype = None , casting = "same_kind" ):
144
- arrays = [asarray (arr ) for arr in arrays ]
145
- if not arrays :
146
- raise ValueError ("need at least one array to stack" )
147
-
148
- shapes = {arr .shape for arr in arrays }
149
- if len (shapes ) != 1 :
150
- raise ValueError ("all input arrays must have the same shape" )
151
-
152
- result_ndim = arrays [0 ].ndim + 1
153
- axis = _util .normalize_axis_index (axis , result_ndim )
154
-
155
- sl = (slice (None ),) * axis + (newaxis ,)
156
- expanded_arrays = [arr [sl ] for arr in arrays ]
157
- return concatenate (
158
- expanded_arrays , axis = axis , out = out , dtype = dtype , casting = casting
159
- )
170
+ tensors = _helpers .to_tensors (* arrays )
171
+ _concat_check (tensors , dtype , out = out )
172
+ result = _impl .stack (tensors , axis = axis , out = out , dtype = dtype , casting = casting )
173
+ return _helpers .result_or_out (result , out )
160
174
161
175
162
176
def array_split (ary , indices_or_sections , axis = 0 ):
@@ -471,27 +485,6 @@ def cov(
471
485
return asarray (result )
472
486
473
487
474
- @_decorators .dtype_to_torch
475
- def concatenate (ar_tuple , axis = 0 , out = None , dtype = None , casting = "same_kind" ):
476
- if ar_tuple == ():
477
- # XXX: RuntimeError in torch, ValueError in numpy
478
- raise ValueError ("need at least one array to concatenate" )
479
-
480
- if out is not None :
481
- if not isinstance (out , ndarray ):
482
- raise ValueError ("'out' must be an array" )
483
-
484
- if dtype is not None :
485
- # mimic numpy
486
- raise TypeError (
487
- "concatenate() only takes `out` or `dtype` as an "
488
- "argument, but both were provided."
489
- )
490
- tensors = _helpers .to_tensors (* ar_tuple )
491
- result = _impl .concatenate (tensors , axis , out , dtype , casting )
492
- return _helpers .result_or_out (result , out )
493
-
494
-
495
488
def bincount (x , / , weights = None , minlength = 0 ):
496
489
if not isinstance (x , ndarray ) and x == []:
497
490
# edge case allowed by numpy
0 commit comments