Skip to content

Commit af97984

Browse files
committed
MAINT: split *stack/concat family
1 parent 5ab9035 commit af97984

File tree

2 files changed

+105
-63
lines changed

2 files changed

+105
-63
lines changed

torch_np/_detail/implementations.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,55 @@ def concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
246246
return result
247247

248248

249+
def stack(tensors, axis=0, out=None, *, dtype=None, casting="same_kind"):
    """Join a sequence of tensors along a new axis (np.stack semantics).

    All inputs must share a single common shape; the result gains one
    extra dimension at *axis*.  The actual join is delegated to
    ``concatenate`` after a length-1 axis is inserted into each input.
    """
    if len({t.shape for t in tensors}) != 1:
        raise ValueError("all input arrays must have the same shape")

    # The output has one more dimension than the inputs, so the axis is
    # normalized against ndim + 1.
    out_ndim = tensors[0].ndim + 1
    axis = _util.normalize_axis_index(axis, out_ndim)

    # Indexing with None at position `axis` inserts the new unit axis.
    index = (slice(None),) * axis + (None,)
    with_new_axis = [t[index] for t in tensors]
    return concatenate(with_new_axis, axis=axis, out=out, dtype=dtype, casting=casting)
262+
263+
264+
def column_stack(tensors_, *, dtype=None, casting="same_kind"):
    """Stack 1-D tensors as columns into a 2-D tensor (np.column_stack)."""
    columns = []
    for tensor in tensors_:
        if tensor.ndim < 2:
            # Promote to at least 2-D then transpose the last two dims so the
            # original values end up as a column.
            # NOTE(review): presumably _coerce_to_tensor(..., ndmin=2) prepends
            # leading axes like np.array(..., ndmin=2) — confirm in _util.
            tensor = _util._coerce_to_tensor(tensor, copy=False, ndmin=2).mT
        columns.append(tensor)
    return concatenate(columns, 1, dtype=dtype, casting=casting)
273+
274+
275+
def dstack(tensors, *, dtype=None, casting="same_kind"):
    """Stack tensors along the third axis (np.dstack semantics)."""
    # atleast_3d guarantees every input has ndim >= 3 before joining on axis 2.
    return concatenate(torch.atleast_3d(tensors), 2, dtype=dtype, casting=casting)
279+
280+
281+
def hstack(tensors, *, dtype=None, casting="same_kind"):
    """Stack tensors horizontally (np.hstack semantics)."""
    tensors = torch.atleast_1d(tensors)

    # As a special case, dimension 0 of 1-dimensional arrays is "horizontal".
    join_axis = 0 if tensors and tensors[0].ndim == 1 else 1
    return concatenate(tensors, join_axis, dtype=dtype, casting=casting)
290+
291+
292+
def vstack(tensors, *, dtype=None, casting="same_kind"):
    """Stack tensors vertically, i.e. along axis 0 (np.vstack semantics)."""
    # atleast_2d turns 1-D inputs into rows, matching numpy's vstack.
    return concatenate(torch.atleast_2d(tensors), 0, dtype=dtype, casting=casting)
296+
297+
249298
# #### cov & corrcoef
250299

251300

torch_np/_wrapper.py

Lines changed: 56 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -74,89 +74,103 @@ def copy(a, order="K", subok=False):
7474

7575

7676
def atleast_1d(*arys):
    """View inputs as arrays with at least one dimension (np.atleast_1d)."""
    results = torch.atleast_1d(_helpers.to_tensors(*arys))
    # A single input unwraps to a single array; several inputs give a list.
    if len(results) == 1:
        return asarray(results[0])
    return [asarray(t) for t in results]
8283

8384

8485
def atleast_2d(*arys):
    """View inputs as arrays with at least two dimensions (np.atleast_2d)."""
    results = torch.atleast_2d(_helpers.to_tensors(*arys))
    # A single input unwraps to a single array; several inputs give a list.
    if len(results) == 1:
        return asarray(results[0])
    return [asarray(t) for t in results]
9092

9193

9294
def atleast_3d(*arys):
    """View inputs as arrays with at least three dimensions (np.atleast_3d)."""
    results = torch.atleast_3d(_helpers.to_tensors(*arys))
    # A single input unwraps to a single array; several inputs give a list.
    if len(results) == 1:
        return asarray(results[0])
    return [asarray(t) for t in results]
98101

99102

103+
def _concat_check(tup, dtype, out):
    """Check inputs in concatenate et al.

    Raises
    ------
    ValueError
        If *tup* is empty, or if *out* is given but is not an ndarray.
    TypeError
        If both *out* and *dtype* are provided (mimics numpy).
    """
    # Accept any empty sequence, not only the literal (): callers pass
    # whatever to_tensors returns (tuple or list).
    if len(tup) == 0:
        # XXX: RuntimeError in torch, ValueError in numpy
        raise ValueError("need at least one array to concatenate")

    if out is not None:
        if not isinstance(out, ndarray):
            raise ValueError("'out' must be an array")

        if dtype is not None:
            # mimic numpy
            raise TypeError(
                "concatenate() only takes `out` or `dtype` as an "
                "argument, but both were provided."
            )
119+
120+
121+
@_decorators.dtype_to_torch
def concatenate(ar_tuple, axis=0, out=None, dtype=None, casting="same_kind"):
    """Join arrays along an existing axis; np.concatenate analog."""
    _concat_check(ar_tuple, dtype, out=out)
    tensors = _helpers.to_tensors(*ar_tuple)
    joined = _impl.concatenate(tensors, axis, out, dtype, casting)
    return _helpers.result_or_out(joined, out)
127+
128+
129+
@_decorators.dtype_to_torch
def vstack(tup, *, dtype=None, casting="same_kind"):
    """Stack arrays vertically (row-wise); np.vstack analog."""
    tensors = _helpers.to_tensors(*tup)
    _concat_check(tensors, dtype, out=None)
    return asarray(_impl.vstack(tensors, dtype=dtype, casting=casting))


row_stack = vstack
108138

109139

140+
@_decorators.dtype_to_torch
def hstack(tup, *, dtype=None, casting="same_kind"):
    """Stack arrays horizontally (column-wise); np.hstack analog."""
    tensors = _helpers.to_tensors(*tup)
    _concat_check(tensors, dtype, out=None)
    return asarray(_impl.hstack(tensors, dtype=dtype, casting=casting))
119146

120147

148+
@_decorators.dtype_to_torch
def dstack(tup, *, dtype=None, casting="same_kind"):
    """Stack arrays along the third axis; np.dstack analog."""
    # XXX: in numpy 1.24 dstack does not have dtype and casting keywords
    # but {h,v}stack do. Hence add them here for consistency.
    tensors = _helpers.to_tensors(*tup)
    # Validate like the sibling stackers (vstack/hstack/column_stack) do,
    # so an empty input raises numpy's ValueError, not torch's RuntimeError.
    _concat_check(tensors, dtype, out=None)
    result = _impl.dstack(tensors, dtype=dtype, casting=casting)
    return asarray(result)
128155

129156

157+
@_decorators.dtype_to_torch
def column_stack(tup, *, dtype=None, casting="same_kind"):
    """Stack 1-D arrays as columns into a 2-D array; np.column_stack analog."""
    # XXX: in numpy 1.24 column_stack does not have dtype and casting keywords
    # but row_stack does. (because row_stack is an alias for vstack, really).
    # Hence add these keywords here for consistency.
    tensors = _helpers.to_tensors(*tup)
    _concat_check(tensors, dtype, out=None)
    return asarray(_impl.column_stack(tensors, dtype=dtype, casting=casting))
141166

142167

168+
@_decorators.dtype_to_torch
def stack(arrays, axis=0, out=None, *, dtype=None, casting="same_kind"):
    """Join arrays along a new axis; np.stack analog."""
    tensors = _helpers.to_tensors(*arrays)
    _concat_check(tensors, dtype, out=out)
    stacked = _impl.stack(tensors, axis=axis, out=out, dtype=dtype, casting=casting)
    return _helpers.result_or_out(stacked, out)
160174

161175

162176
def array_split(ary, indices_or_sections, axis=0):
@@ -471,27 +485,6 @@ def cov(
471485
return asarray(result)
472486

473487

474-
@_decorators.dtype_to_torch
475-
def concatenate(ar_tuple, axis=0, out=None, dtype=None, casting="same_kind"):
476-
if ar_tuple == ():
477-
# XXX: RuntimeError in torch, ValueError in numpy
478-
raise ValueError("need at least one array to concatenate")
479-
480-
if out is not None:
481-
if not isinstance(out, ndarray):
482-
raise ValueError("'out' must be an array")
483-
484-
if dtype is not None:
485-
# mimic numpy
486-
raise TypeError(
487-
"concatenate() only takes `out` or `dtype` as an "
488-
"argument, but both were provided."
489-
)
490-
tensors = _helpers.to_tensors(*ar_tuple)
491-
result = _impl.concatenate(tensors, axis, out, dtype, casting)
492-
return _helpers.result_or_out(result, out)
493-
494-
495488
def bincount(x, /, weights=None, minlength=0):
496489
if not isinstance(x, ndarray) and x == []:
497490
# edge case allowed by numpy

0 commit comments

Comments
 (0)