
Commit 6970fd3

MAINT: reimplement {v,d,h,column_}stack through their pytorch equivalents
1 parent 9f11675

File tree

1 file changed (+30 −37 lines)


torch_np/_detail/implementations.py

Lines changed: 30 additions & 37 deletions
@@ -229,73 +229,66 @@ def diff(a_tensor, n=1, axis=-1, prepend_tensor=None, append_tensor=None):
 # #### concatenate and relatives
 
 
-def concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
-    # np.concatenate ravels if axis=None
-    tensors, axis = _util.axis_none_ravel(*tensors, axis=axis)
+def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):
+    """Figure out dtypes, cast if necessary."""
 
     if out is not None or dtype is not None:
         # figure out the type of the inputs and outputs
         out_dtype = out.dtype.torch_dtype if dtype is None else dtype
+    else:
+        out_dtype = _dtypes_impl.result_type_impl([t.dtype for t in tensors])
+
+    # cast input arrays if necessary; do not broadcast them against `out`
+    tensors = _util.cast_dont_broadcast(tensors, out_dtype, casting)
 
-    # cast input arrays if necessary; do not broadcast them against `out`
-    tensors = _util.cast_dont_broadcast(tensors, out_dtype, casting)
+    return tensors
+
+
+def concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
+    # np.concatenate ravels if axis=None
+    tensors, axis = _util.axis_none_ravel(*tensors, axis=axis)
+    tensors = _concat_cast_helper(tensors, out, dtype, casting)
 
     try:
         result = torch.cat(tensors, axis)
-    except (IndexError, RuntimeError):
-        raise _util.AxisError
+    except (IndexError, RuntimeError) as e:
+        raise _util.AxisError(*e.args)
 
     return result
 
 
 def stack(tensors, axis=0, out=None, *, dtype=None, casting="same_kind"):
-    shapes = {t.shape for t in tensors}
-    if len(shapes) != 1:
-        raise ValueError("all input arrays must have the same shape")
-
+    tensors = _concat_cast_helper(tensors, dtype=dtype, casting=casting)
     result_ndim = tensors[0].ndim + 1
     axis = _util.normalize_axis_index(axis, result_ndim)
-
-    sl = (slice(None),) * axis + (None,)
-    expanded_tensors = [tensor[sl] for tensor in tensors]
-    result = concatenate(
-        expanded_tensors, axis=axis, out=out, dtype=dtype, casting=casting
-    )
-
+    try:
+        result = torch.stack(tensors, axis=axis)
+    except RuntimeError as e:
+        raise ValueError(*e.args)
     return result
 
 
-def column_stack(tensors_, *, dtype=None, casting="same_kind"):
-    tensors = []
-    for t in tensors_:
-        if t.ndim < 2:
-            t = _util._coerce_to_tensor(t, copy=False, ndmin=2).mT
-        tensors.append(t)
-
-    result = concatenate(tensors, 1, dtype=dtype, casting=casting)
+def column_stack(tensors, *, dtype=None, casting="same_kind"):
+    tensors = _concat_cast_helper(tensors, dtype=dtype, casting=casting)
+    result = torch.column_stack(tensors)
     return result
 
 
 def dstack(tensors, *, dtype=None, casting="same_kind"):
-    tensors = torch.atleast_3d(tensors)
-    result = concatenate(tensors, 2, dtype=dtype, casting=casting)
+    tensors = _concat_cast_helper(tensors, dtype=dtype, casting=casting)
+    result = torch.dstack(tensors)
     return result
 
 
 def hstack(tensors, *, dtype=None, casting="same_kind"):
-    tensors = torch.atleast_1d(tensors)
-
-    # As a special case, dimension 0 of 1-dimensional arrays is "horizontal"
-    if tensors and tensors[0].ndim == 1:
-        result = concatenate(tensors, 0, dtype=dtype, casting=casting)
-    else:
-        result = concatenate(tensors, 1, dtype=dtype, casting=casting)
+    tensors = _concat_cast_helper(tensors, dtype=dtype, casting=casting)
+    result = torch.hstack(tensors)
     return result
 
 
 def vstack(tensors, *, dtype=None, casting="same_kind"):
-    tensors = torch.atleast_2d(tensors)
-    result = concatenate(tensors, 0, dtype=dtype, casting=casting)
+    tensors = _concat_cast_helper(tensors, dtype=dtype, casting=casting)
+    result = torch.vstack(tensors)
     return result
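The shared _concat_cast_helper is what keeps NumPy's dtype semantics on top of the torch functions. A minimal sketch of the behaviour it provides, assuming a torch_np checkout is importable as shown (the `impl` alias is local to the example, not part of the package):

import torch
from torch_np._detail import implementations as impl

a = torch.arange(3)              # integer tensor
b = torch.linspace(0.0, 1.0, 3)  # floating-point tensor

# _concat_cast_helper picks a common dtype via result_type_impl and casts
# both inputs before handing them to torch.cat.
result = impl.concatenate((a, b))
print(result.dtype)  # a floating dtype, per NumPy-style promotion rules

# A stricter casting rule should reject the int -> float cast, as it does
# in np.concatenate (exact exception type not asserted here):
try:
    impl.concatenate((a, b), casting="no")
except Exception as e:
    print("cast rejected:", e)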
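Note the error translation in the new stack: the removed explicit shape check is delegated to torch.stack, whose RuntimeError is re-raised as ValueError so callers still see NumPy's exception type. A sketch, under the same import assumption as above:

import torch
from torch_np._detail import implementations as impl

x = torch.zeros(2, 3)
y = torch.zeros(2, 4)  # shape mismatch

try:
    impl.stack((x, y))
except ValueError as e:
    # torch.stack's RuntimeError surfaces as ValueError, matching np.stack
    print("ValueError:", e)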
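The {v,d,h}stack special cases could be dropped because the torch equivalents already implement them: torch.hstack concatenates 1-D inputs along dim 0 and higher-dimensional inputs along dim 1, while torch.vstack and torch.dstack promote inputs to at least 2-D and 3-D internally. That is exactly what the removed torch.atleast_{1,2,3}d calls and the ndim == 1 branch were doing by hand:

import torch

print(torch.hstack((torch.tensor([1, 2]), torch.tensor([3, 4]))))
# tensor([1, 2, 3, 4]) -- 1-D inputs joined along dim 0

print(torch.hstack((torch.ones(2, 1), torch.zeros(2, 1))).shape)
# torch.Size([2, 2]) -- 2-D inputs joined along dim 1

print(torch.vstack((torch.tensor([1, 2]), torch.tensor([3, 4]))).shape)
# torch.Size([2, 2]) -- each 1-D input promoted to (1, 2) first

print(torch.dstack((torch.tensor([1, 2]), torch.tensor([3, 4]))).shape)
# torch.Size([1, 2, 2]) -- each 1-D input promoted to (1, 2, 1) first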