Skip to content

Commit 6744787

Browse files
committed
normalize Optional[ArrayLike] via annotations
1 parent 532c1d7 commit 6744787

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

torch_np/_detail/implementations.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -261,12 +261,6 @@ def dsplit(tensor, indices_or_sections):
261261

262262

263263
def clip(tensor, t_min, t_max):
264-
if t_min is not None:
265-
t_min = torch.broadcast_to(t_min, tensor.shape)
266-
267-
if t_max is not None:
268-
t_max = torch.broadcast_to(t_max, tensor.shape)
269-
270264
if t_min is None and t_max is None:
271265
raise ValueError("One of max or min must be given")
272266

torch_np/_funcs.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import typing
2+
from typing import Optional
23

34
import torch
45

@@ -23,6 +24,12 @@ def normalize_array_like(x, name=None):
2324
return tensor
2425

2526

27+
def normalize_optional_array_like(x, name=None):
28+
# This explicit normalizer is needed because otherwise normalize_array_like
29+
# does not run for a parameter annotated as Optional[ArrayLike]
30+
return None if x is None else normalize_array_like(x, name)
31+
32+
2633
def normalize_dtype(dtype, name=None):
2734
# cf _decorators.dtype_to_torch
2835
torch_dtype = None
@@ -39,6 +46,7 @@ def normalize_subok_like(arg, name):
3946

4047
normalizers = {
4148
ArrayLike: normalize_array_like,
49+
Optional[ArrayLike]: normalize_optional_array_like,
4250
DTypeLike: normalize_dtype,
4351
SubokLike: normalize_subok_like,
4452
}
@@ -121,12 +129,11 @@ def argwhere(a):
121129
return _helpers.array_from(result)
122130

123131

124-
def clip(a, min=None, max=None, out=None):
132+
@normalizer
133+
def clip(a : ArrayLike, min : Optional[ArrayLike]=None, max : Optional[ArrayLike]=None, out=None):
125134
# np.clip requires both a_min and a_max not None, while ndarray.clip allows
126135
# one of them to be None. Follow the more lax version.
127-
# Also min/max as arg names: follow numpy naming.
128-
tensor, t_min, t_max = _helpers.to_tensors_or_none(a, min, max)
129-
result = _impl.clip(tensor, t_min, t_max)
136+
result = _impl.clip(a, min, max)
130137
return _helpers.result_or_out(result, out)
131138

132139

0 commit comments

Comments
 (0)