Skip to content

Commit c75a8d5

Browse files
committed
ENH: normalize tuples of array_likes
1 parent bc9448e commit c75a8d5

File tree

2 files changed

+25
-25
lines changed

2 files changed

+25
-25
lines changed

torch_np/_funcs.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import typing
2-
from typing import Optional
2+
from typing import Optional, Sequence
33

44
import torch
55

@@ -30,6 +30,11 @@ def normalize_optional_array_like(x, name=None):
3030
return None if x is None else normalize_array_like(x, name)
3131

3232

33+
def normalize_seq_array_like(x, name=None):
34+
tensors = _helpers.to_tensors(*x)
35+
return tensors
36+
37+
3338
def normalize_dtype(dtype, name=None):
3439
# cf _decorators.dtype_to_torch
3540
torch_dtype = None
@@ -47,6 +52,7 @@ def normalize_subok_like(arg, name):
4752
normalizers = {
4853
ArrayLike: normalize_array_like,
4954
Optional[ArrayLike]: normalize_optional_array_like,
55+
Sequence[ArrayLike]: normalize_seq_array_like,
5056
DTypeLike: normalize_dtype,
5157
SubokLike: normalize_subok_like,
5258
}

torch_np/_wrapper.py

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -110,60 +110,54 @@ def _concat_check(tup, dtype, out):
110110

111111
### XXX: order the imports DAG
112112
from . _funcs import normalizer, DTypeLike, ArrayLike
113-
from typing import Optional
113+
from typing import Optional, Sequence
114114

115115
@normalizer
116-
def concatenate(ar_tuple, axis=0, out=None, dtype: DTypeLike=None, casting="same_kind"):
117-
tensors = _helpers.to_tensors(*ar_tuple)
118-
_concat_check(tensors, dtype, out=out)
119-
result = _impl.concatenate(tensors, axis, out, dtype, casting)
116+
def concatenate(ar_tuple : Sequence[ArrayLike], axis=0, out=None, dtype: DTypeLike=None, casting="same_kind"):
117+
_concat_check(ar_tuple, dtype, out=out)
118+
result = _impl.concatenate(ar_tuple, axis, out, dtype, casting)
120119
return _helpers.result_or_out(result, out)
121120

122121

123122
@normalizer
124-
def vstack(tup, *, dtype : DTypeLike=None, casting="same_kind"):
125-
tensors = _helpers.to_tensors(*tup)
126-
_concat_check(tensors, dtype, out=None)
127-
result = _impl.vstack(tensors, dtype=dtype, casting=casting)
123+
def vstack(tup : Sequence[ArrayLike], *, dtype : DTypeLike=None, casting="same_kind"):
124+
_concat_check(tup, dtype, out=None)
125+
result = _impl.vstack(tup, dtype=dtype, casting=casting)
128126
return asarray(result)
129127

130128

131129
row_stack = vstack
132130

133131

134132
@normalizer
135-
def hstack(tup, *, dtype : DTypeLike=None, casting="same_kind"):
136-
tensors = _helpers.to_tensors(*tup)
137-
_concat_check(tensors, dtype, out=None)
138-
result = _impl.hstack(tensors, dtype=dtype, casting=casting)
133+
def hstack(tup: Sequence[ArrayLike], *, dtype : DTypeLike=None, casting="same_kind"):
134+
_concat_check(tup, dtype, out=None)
135+
result = _impl.hstack(tup, dtype=dtype, casting=casting)
139136
return asarray(result)
140137

141138

142139
@normalizer
143-
def dstack(tup, *, dtype : DTypeLike=None, casting="same_kind"):
140+
def dstack(tup: Sequence[ArrayLike], *, dtype : DTypeLike=None, casting="same_kind"):
144141
# XXX: in numpy 1.24 dstack does not have dtype and casting keywords
145142
# but {h,v}stack do. Hence add them here for consistency.
146-
tensors = _helpers.to_tensors(*tup)
147-
result = _impl.dstack(tensors, dtype=dtype, casting=casting)
143+
result = _impl.dstack(tup, dtype=dtype, casting=casting)
148144
return asarray(result)
149145

150146

151147
@normalizer
152-
def column_stack(tup, *, dtype : DTypeLike=None, casting="same_kind"):
148+
def column_stack(tup : Sequence[ArrayLike], *, dtype : DTypeLike=None, casting="same_kind"):
153149
# XXX: in numpy 1.24 column_stack does not have dtype and casting keywords
154150
# but row_stack does. (because row_stack is an alias for vstack, really).
155151
# Hence add these keywords here for consistency.
156-
tensors = _helpers.to_tensors(*tup)
157-
_concat_check(tensors, dtype, out=None)
158-
result = _impl.column_stack(tensors, dtype=dtype, casting=casting)
152+
_concat_check(tup, dtype, out=None)
153+
result = _impl.column_stack(tup, dtype=dtype, casting=casting)
159154
return asarray(result)
160155

161156

162157
@normalizer
163-
def stack(arrays, axis=0, out=None, *, dtype : DTypeLike=None, casting="same_kind"):
164-
tensors = _helpers.to_tensors(*arrays)
165-
_concat_check(tensors, dtype, out=out)
166-
result = _impl.stack(tensors, axis=axis, out=out, dtype=dtype, casting=casting)
158+
def stack(arrays : Sequence[ArrayLike], axis=0, out=None, *, dtype : DTypeLike=None, casting="same_kind"):
159+
_concat_check(arrays, dtype, out=out)
160+
result = _impl.stack(arrays, axis=axis, out=out, dtype=dtype, casting=casting)
167161
return _helpers.result_or_out(result, out)
168162

169163

0 commit comments

Comments
 (0)