3
3
import torch
4
4
5
5
from . import _binary_ufuncs_impl , _helpers , _unary_ufuncs_impl
6
+ from ._detail import _dtypes_impl , _util
6
7
from ._normalizations import ArrayLike , DTypeLike , NDArray , SubokLike , normalizer
7
8
9
+
10
+ def _ufunc_preprocess (tensors , where , casting , order , dtype , subok , signature , extobj ):
11
+ if order != "K" or not where or signature or extobj :
12
+ raise NotImplementedError
13
+
14
+ if dtype is None :
15
+ dtype = _dtypes_impl .result_type_impl ([t .dtype for t in tensors ])
16
+
17
+ tensors = _util .typecast_tensors (tensors , dtype , casting )
18
+
19
+ return tensors
20
+
21
+
22
+ def _ufunc_postprocess (result , out , casting ):
23
+ if out is not None :
24
+ (result ,) = _util .typecast_tensors ((result ,), out .dtype .torch_dtype , casting )
25
+ result = torch .broadcast_to (result , out .shape )
26
+ return result
27
+
28
+
8
29
# ############# Binary ufuncs ######################
9
30
10
31
_binary = [
@@ -35,16 +56,12 @@ def wrapped(
35
56
signature = None ,
36
57
extobj = None ,
37
58
):
38
- tensors = _helpers . ufunc_preprocess (
39
- (x1 , x2 ), out , where , casting , order , dtype , subok , signature , extobj
59
+ tensors = _ufunc_preprocess (
60
+ (x1 , x2 ), where , casting , order , dtype , subok , signature , extobj
40
61
)
41
- # now broadcast input tensors against the out=... array
42
- if out is not None :
43
- # XXX: need to filter out noop broadcasts if t.shape == out.shape?
44
- shape = out .shape
45
- tensors = tuple (torch .broadcast_to (t , shape ) for t in tensors )
46
-
47
62
result = torch_func (* tensors )
63
+
64
+ result = _ufunc_postprocess (result , out , casting )
48
65
return result
49
66
50
67
wrapped .__qualname__ = torch_func .__name__
@@ -54,8 +71,9 @@ def wrapped(
54
71
55
72
56
73
#
57
- # matmul is special in that its `out=...` array does not broadcast x1 and x2.
58
- # E.g. consider x1.shape = (5, 2) and x2.shape = (2, 3). Then `out.shape` is (5, 3).
74
+ # matmul's signature is _slightly_ different from other ufuncs:
75
+ # - no where=...
76
+ # - additional axis=..., axes=...
59
77
#
60
78
@normalizer
61
79
def matmul (
@@ -73,17 +91,21 @@ def matmul(
73
91
axes = None ,
74
92
axis = None ,
75
93
):
76
- tensors = _helpers . ufunc_preprocess (
77
- (x1 , x2 ), out , True , casting , order , dtype , subok , signature , extobj
94
+ tensors = _ufunc_preprocess (
95
+ (x1 , x2 ), True , casting , order , dtype , subok , signature , extobj
78
96
)
79
97
if axis is not None or axes is not None :
80
98
raise NotImplementedError
81
99
82
- # NB: do not broadcast input tensors against the out=... array
83
100
result = _binary_ufuncs_impl .matmul (* tensors )
101
+
102
+ result = _ufunc_postprocess (result , out , casting )
84
103
return result
85
104
86
105
106
+ #
107
+ # nin=2, nout=2
108
+ #
87
109
def divmod (
88
110
x1 : ArrayLike ,
89
111
x2 : ArrayLike ,
@@ -110,12 +132,14 @@ def divmod(
110
132
if out1 .shape != out2 .shape or out1 .dtype != out2 .dtype :
111
133
raise ValueError ("out1, out2 must be compatible" )
112
134
113
- tensors = _helpers . ufunc_preprocess (
114
- (x1 , x2 ), out , True , casting , order , dtype , subok , signature , extobj
135
+ tensors = _ufunc_preprocess (
136
+ (x1 , x2 ), True , casting , order , dtype , subok , signature , extobj
115
137
)
116
138
117
- result = _binary_ufuncs_impl .divmod (* tensors )
139
+ quot , rem = _binary_ufuncs_impl .divmod (* tensors )
118
140
141
+ quot = _ufunc_postprocess (quot , out1 , casting )
142
+ rem = _ufunc_postprocess (rem , out2 , casting )
119
143
return quot , rem
120
144
121
145
@@ -167,15 +191,11 @@ def wrapped(
167
191
signature = None ,
168
192
extobj = None ,
169
193
):
170
- tensors = _helpers . ufunc_preprocess (
171
- (x ,), out , where , casting , order , dtype , subok , signature , extobj
194
+ tensors = _ufunc_preprocess (
195
+ (x ,), where , casting , order , dtype , subok , signature , extobj
172
196
)
173
- # now broadcast the input tensor against the out=... array
174
- if out is not None :
175
- # XXX: need to filter out noop broadcasts if t.shape == out.shape?
176
- shape = out .shape
177
- tensors = tuple (torch .broadcast_to (t , shape ) for t in tensors )
178
197
result = torch_func (* tensors )
198
+ result = _ufunc_postprocess (result , out , casting )
179
199
return result
180
200
181
201
wrapped .__qualname__ = torch_func .__name__
0 commit comments