@@ -229,73 +229,66 @@ def diff(a_tensor, n=1, axis=-1, prepend_tensor=None, append_tensor=None):
 # #### concatenate and relatives
 
 
-def concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
-    # np.concatenate ravels if axis=None
-    tensors, axis = _util.axis_none_ravel(*tensors, axis=axis)
+def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):
+    """Figure out dtypes, cast if necessary."""
 
     if out is not None or dtype is not None:
         # figure out the type of the inputs and outputs
         out_dtype = out.dtype.torch_dtype if dtype is None else dtype
+    else:
+        out_dtype = _dtypes_impl.result_type_impl([t.dtype for t in tensors])
+
+    # cast input arrays if necessary; do not broadcast them against `out`
+    tensors = _util.cast_dont_broadcast(tensors, out_dtype, casting)
 
-        # cast input arrays if necessary; do not broadcast them against `out`
-        tensors = _util.cast_dont_broadcast(tensors, out_dtype, casting)
+    return tensors
+
+
+def concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
+    # np.concatenate ravels if axis=None
+    tensors, axis = _util.axis_none_ravel(*tensors, axis=axis)
+    tensors = _concat_cast_helper(tensors, out, dtype, casting)
 
     try:
         result = torch.cat(tensors, axis)
-    except (IndexError, RuntimeError):
-        raise _util.AxisError
+    except (IndexError, RuntimeError) as e:
+        raise _util.AxisError(*e.args)
 
     return result
 
 
 def stack(tensors, axis=0, out=None, *, dtype=None, casting="same_kind"):
-    shapes = {t.shape for t in tensors}
-    if len(shapes) != 1:
-        raise ValueError("all input arrays must have the same shape")
-
+    tensors = _concat_cast_helper(tensors, dtype=dtype, casting=casting)
     result_ndim = tensors[0].ndim + 1
     axis = _util.normalize_axis_index(axis, result_ndim)
-
-    sl = (slice(None),) * axis + (None,)
-    expanded_tensors = [tensor[sl] for tensor in tensors]
-    result = concatenate(
-        expanded_tensors, axis=axis, out=out, dtype=dtype, casting=casting
-    )
-
+    try:
+        result = torch.stack(tensors, axis=axis)
+    except RuntimeError as e:
+        raise ValueError(*e.args)
     return result
 
 
-def column_stack(tensors_, *, dtype=None, casting="same_kind"):
-    tensors = []
-    for t in tensors_:
-        if t.ndim < 2:
-            t = _util._coerce_to_tensor(t, copy=False, ndmin=2).mT
-        tensors.append(t)
-
-    result = concatenate(tensors, 1, dtype=dtype, casting=casting)
+def column_stack(tensors, *, dtype=None, casting="same_kind"):
+    tensors = _concat_cast_helper(tensors, dtype=dtype, casting=casting)
+    result = torch.column_stack(tensors)
     return result
 
 
 def dstack(tensors, *, dtype=None, casting="same_kind"):
-    tensors = torch.atleast_3d(tensors)
-    result = concatenate(tensors, 2, dtype=dtype, casting=casting)
+    tensors = _concat_cast_helper(tensors, dtype=dtype, casting=casting)
+    result = torch.dstack(tensors)
     return result
 
 
 def hstack(tensors, *, dtype=None, casting="same_kind"):
-    tensors = torch.atleast_1d(tensors)
-
-    # As a special case, dimension 0 of 1-dimensional arrays is "horizontal"
-    if tensors and tensors[0].ndim == 1:
-        result = concatenate(tensors, 0, dtype=dtype, casting=casting)
-    else:
-        result = concatenate(tensors, 1, dtype=dtype, casting=casting)
+    tensors = _concat_cast_helper(tensors, dtype=dtype, casting=casting)
+    result = torch.hstack(tensors)
     return result
 
 
 def vstack(tensors, *, dtype=None, casting="same_kind"):
-    tensors = torch.atleast_2d(tensors)
-    result = concatenate(tensors, 0, dtype=dtype, casting=casting)
+    tensors = _concat_cast_helper(tensors, dtype=dtype, casting=casting)
+    result = torch.vstack(tensors)
     return result
 
 
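Taken together, the refactor gives every function in the concatenate family the same shape: run the inputs through _concat_cast_helper (resolve a common dtype, cast), delegate the joining and the shape checking to the matching torch primitive, and translate torch's exceptions into the NumPy-style ones callers expect. Below is a minimal, self-contained sketch of that pattern using only public torch APIs; stack_sketch and _cast_to_common_dtype are hypothetical stand-ins, and torch.promote_types approximates the project's _dtypes_impl.result_type_impl plus _util.cast_dont_broadcast, so dtype corner cases may differ from NumPy's rules.

import functools

import torch


def _cast_to_common_dtype(tensors):
    # Resolve a single result dtype across all inputs, then cast each input to it.
    out_dtype = functools.reduce(torch.promote_types, (t.dtype for t in tensors))
    return [t.to(out_dtype) for t in tensors]


def stack_sketch(tensors, axis=0):
    tensors = _cast_to_common_dtype(tensors)
    try:
        # torch.stack itself verifies that all inputs have the same shape,
        # which is why the diff can drop stack()'s explicit shape pre-check...
        return torch.stack(tensors, dim=axis)
    except RuntimeError as e:
        # ...and re-raise torch's RuntimeError as the ValueError that
        # np.stack raises on mismatched shapes.
        raise ValueError(*e.args)


print(stack_sketch([torch.tensor([1, 2]), torch.tensor([3.0, 4.0])]))
# tensor([[1., 2.], [3., 4.]])  (int64 is promoted to float32 by torch's rules)

Delegating to torch.column_stack, torch.dstack, torch.hstack and torch.vstack has the same effect on the other functions: the atleast_*d reshaping and the 1-D special case in hstack no longer need to be reimplemented by hand, since torch's primitives already follow NumPy's semantics for those cases.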