@@ -19,15 +19,38 @@ def wrapped(x1, x2, /, out=None, *, where=True,
19
19
# XXX: dtype=... parameter is silently ignored
20
20
21
21
arrays = (x1 , x2 )
22
- x1_tensor , x2_tensor = _helpers .cast_and_broadcast (arrays , out , casting )
22
+ tensors = _helpers .cast_and_broadcast (arrays , out , casting )
23
23
24
- result = torch_func (x1_tensor , x2_tensor )
24
+ result = torch_func (* tensors )
25
25
26
26
return _helpers .result_or_out (result , out )
27
27
return wrapped
28
28
29
29
30
- # the list is autogenerated, cf autogen/gen_ufunc_2.py
30
+ def deco_unary_ufunc (torch_func ):
31
+ # TODO: deduplicate with `deco_binary_ufunc` above. Need to figure out the
32
+ # effect of the `dtype` parameter, does it differ between unary and binary ufuncs.
33
+ def wrapped (x1 , / , out = None , * , where = True ,
34
+ casting = 'same_kind' , order = 'K' , dtype = None , subok = False , ** kwds ):
35
+ _util .subok_not_ok (subok = subok )
36
+ if order != 'K' or not where :
37
+ raise NotImplementedError
38
+
39
+ # XXX: dtype=... parameter is silently ignored
40
+
41
+ arrays = (x1 , )
42
+ tensors = _helpers .cast_and_broadcast (arrays , out , casting )
43
+
44
+ result = torch_func (* tensors )
45
+
46
+ return _helpers .result_or_out (result , out )
47
+ return wrapped
48
+
49
+
50
+
51
+
52
+ # binary ufuncs: the list is autogenerated, cf autogen/gen_ufunc_2.py
53
+ # And edited manually! np.equal <--> torch.eq, not torch.equal
31
54
add = deco_binary_ufunc (torch .add )
32
55
arctan2 = deco_binary_ufunc (torch .arctan2 )
33
56
bitwise_and = deco_binary_ufunc (torch .bitwise_and )
@@ -69,3 +92,59 @@ def wrapped(x1, x2, /, out=None, *, where=True,
69
92
subtract = deco_binary_ufunc (torch .subtract )
70
93
divide = deco_binary_ufunc (torch .divide )
71
94
95
+
96
+
97
+ # unary ufuncs: the list is autogenerated, cf autogen/gen_ufunc_2.py
98
+ absolute = deco_unary_ufunc (torch .absolute )
99
+ #absolute = deco_unary_ufunc(torch.absolute)
100
+ arccos = deco_unary_ufunc (torch .arccos )
101
+ arccosh = deco_unary_ufunc (torch .arccosh )
102
+ arcsin = deco_unary_ufunc (torch .arcsin )
103
+ arcsinh = deco_unary_ufunc (torch .arcsinh )
104
+ arctan = deco_unary_ufunc (torch .arctan )
105
+ arctanh = deco_unary_ufunc (torch .arctanh )
106
+ ceil = deco_unary_ufunc (torch .ceil )
107
+ conjugate = deco_unary_ufunc (torch .conj_physical )
108
+ #conjugate = deco_unary_ufunc(torch.conj_physical)
109
+ cos = deco_unary_ufunc (torch .cos )
110
+ cosh = deco_unary_ufunc (torch .cosh )
111
+ deg2rad = deco_unary_ufunc (torch .deg2rad )
112
+ degrees = deco_unary_ufunc (torch .rad2deg )
113
+ exp = deco_unary_ufunc (torch .exp )
114
+ exp2 = deco_unary_ufunc (torch .exp2 )
115
+ expm1 = deco_unary_ufunc (torch .expm1 )
116
+ fabs = deco_unary_ufunc (torch .absolute )
117
+ floor = deco_unary_ufunc (torch .floor )
118
+ isfinite = deco_unary_ufunc (torch .isfinite )
119
+ isinf = deco_unary_ufunc (torch .isinf )
120
+ isnan = deco_unary_ufunc (torch .isnan )
121
+ log = deco_unary_ufunc (torch .log )
122
+ log10 = deco_unary_ufunc (torch .log10 )
123
+ log1p = deco_unary_ufunc (torch .log1p )
124
+ log2 = deco_unary_ufunc (torch .log2 )
125
+ logical_not = deco_unary_ufunc (torch .logical_not )
126
+ negative = deco_unary_ufunc (torch .negative )
127
+ rad2deg = deco_unary_ufunc (torch .rad2deg )
128
+ radians = deco_unary_ufunc (torch .deg2rad )
129
+ reciprocal = deco_unary_ufunc (torch .reciprocal )
130
+ rint = deco_unary_ufunc (torch .round )
131
+ sign = deco_unary_ufunc (torch .sign )
132
+ signbit = deco_unary_ufunc (torch .signbit )
133
+ sin = deco_unary_ufunc (torch .sin )
134
+ sinh = deco_unary_ufunc (torch .sinh )
135
+ sqrt = deco_unary_ufunc (torch .sqrt )
136
+ square = deco_unary_ufunc (torch .square )
137
+ tan = deco_unary_ufunc (torch .tan )
138
+ tanh = deco_unary_ufunc (torch .tanh )
139
+ trunc = deco_unary_ufunc (torch .trunc )
140
+
141
+ # special cases: torch does not export these names
142
+ def _cbrt (x ):
143
+ return torch .pow (x , 1 / 3 )
144
+
145
+ def _positive (x ):
146
+ return + x
147
+
148
+ cbrt = deco_unary_ufunc (_cbrt )
149
+ positive = deco_unary_ufunc (_positive )
150
+
0 commit comments