1
1
import typing
2
+ from typing import Optional
2
3
3
4
import torch
4
5
@@ -23,6 +24,12 @@ def normalize_array_like(x, name=None):
23
24
return tensor
24
25
25
26
27
+ def normalize_optional_array_like (x , name = None ):
28
+ # This explicit normalizer is needed because otherwise normalize_array_like
29
+ # does not run for a parameter annotated as Optional[ArrayLike]
30
+ return None if x is None else normalize_array_like (x , name )
31
+
32
+
26
33
def normalize_dtype (dtype , name = None ):
27
34
# cf _decorators.dtype_to_torch
28
35
torch_dtype = None
@@ -39,6 +46,7 @@ def normalize_subok_like(arg, name):
39
46
40
47
normalizers = {
41
48
ArrayLike : normalize_array_like ,
49
+ Optional [ArrayLike ]: normalize_optional_array_like ,
42
50
DTypeLike : normalize_dtype ,
43
51
SubokLike : normalize_subok_like ,
44
52
}
@@ -121,12 +129,11 @@ def argwhere(a):
121
129
return _helpers .array_from (result )
122
130
123
131
124
- def clip (a , min = None , max = None , out = None ):
132
+ @normalizer
133
+ def clip (a : ArrayLike , min : Optional [ArrayLike ]= None , max : Optional [ArrayLike ]= None , out = None ):
125
134
# np.clip requires both a_min and a_max not None, while ndarray.clip allows
126
135
# one of them to be None. Follow the more lax version.
127
- # Also min/max as arg names: follow numpy naming.
128
- tensor , t_min , t_max = _helpers .to_tensors_or_none (a , min , max )
129
- result = _impl .clip (tensor , t_min , t_max )
136
+ result = _impl .clip (a , min , max )
130
137
return _helpers .result_or_out (result , out )
131
138
132
139
0 commit comments