Skip to content

Commit 1dd4424

Browse files
committed
MAINT: merge {_unary,_binary}_ufuncs modules
1 parent fb45110 commit 1dd4424

9 files changed

+123
-121
lines changed

torch_np/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from . import linalg, random
2-
from ._binary_ufuncs import *
32
from ._detail._util import AxisError, UFuncTypeError
43
from ._dtypes import *
54
from ._funcs import *
65
from ._getlimits import errstate, finfo, iinfo
76
from ._ndarray import array, asarray, can_cast, ndarray, newaxis, result_type
8-
from ._unary_ufuncs import *
7+
from ._ufuncs import *
98

109
# from . import testing
1110

torch_np/_detail/_binary_ufuncs.py renamed to torch_np/_binary_ufuncs_impl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Export torch work functions for binary ufuncs, rename/tweak to match numpy.
2-
This listing is further exported to public symbols in the `torch_np/_binary_ufuncs.py` module.
2+
This listing is further exported to public symbols in the `torch_np/_ufuncs.py` module.
33
"""
44

55
import torch
@@ -42,7 +42,7 @@
4242
from torch import remainder as mod
4343
from torch import subtract, true_divide
4444

45-
from . import _dtypes_impl, _util
45+
from ._detail import _dtypes_impl, _util
4646

4747

4848
# work around torch limitations w.r.t. numpy

torch_np/_ndarray.py

Lines changed: 44 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import torch
44

5-
from . import _binary_ufuncs, _dtypes, _funcs, _funcs_impl, _helpers, _unary_ufuncs
5+
from . import _dtypes, _funcs, _funcs_impl, _helpers, _ufuncs
66
from ._detail import _dtypes_impl, _util
77

88
newaxis = None
@@ -176,24 +176,24 @@ def __str__(self):
176176
### comparisons ###
177177
def __eq__(self, other):
178178
try:
179-
return _binary_ufuncs.equal(self, other)
179+
return _ufuncs.equal(self, other)
180180
except (RuntimeError, TypeError):
181181
# Failed to convert other to array: definitely not equal.
182182
falsy = torch.full(self.shape, fill_value=False, dtype=bool)
183183
return asarray(falsy)
184184

185185
def __ne__(self, other):
186186
try:
187-
return _binary_ufuncs.not_equal(self, other)
187+
return _ufuncs.not_equal(self, other)
188188
except (RuntimeError, TypeError):
189189
# Failed to convert other to array: definitely not equal.
190190
falsy = torch.full(self.shape, fill_value=True, dtype=bool)
191191
return asarray(falsy)
192192

193-
__gt__ = _binary_ufuncs.greater
194-
__lt__ = _binary_ufuncs.less
195-
__ge__ = _binary_ufuncs.greater_equal
196-
__le__ = _binary_ufuncs.less_equal
193+
__gt__ = _ufuncs.greater
194+
__lt__ = _ufuncs.less
195+
__ge__ = _ufuncs.greater_equal
196+
__le__ = _ufuncs.less_equal
197197

198198
def __bool__(self):
199199
try:
@@ -239,107 +239,107 @@ def __len__(self):
239239
### arithmetic ###
240240

241241
# add, self + other
242-
__add__ = __radd__ = _binary_ufuncs.add
242+
__add__ = __radd__ = _ufuncs.add
243243

244244
def __iadd__(self, other):
245-
return _binary_ufuncs.add(self, other, out=self)
245+
return _ufuncs.add(self, other, out=self)
246246

247247
# sub, self - other
248-
__sub__ = _binary_ufuncs.subtract
248+
__sub__ = _ufuncs.subtract
249249

250250
# XXX: generate a function just for this? AND other non-commutative ops.
251251
def __rsub__(self, other):
252-
return _binary_ufuncs.subtract(other, self)
252+
return _ufuncs.subtract(other, self)
253253

254254
def __isub__(self, other):
255-
return _binary_ufuncs.subtract(self, other, out=self)
255+
return _ufuncs.subtract(self, other, out=self)
256256

257257
# mul, self * other
258-
__mul__ = __rmul__ = _binary_ufuncs.multiply
258+
__mul__ = __rmul__ = _ufuncs.multiply
259259

260260
def __imul__(self, other):
261-
return _binary_ufuncs.multiply(self, other, out=self)
261+
return _ufuncs.multiply(self, other, out=self)
262262

263263
# div, self / other
264-
__truediv__ = _binary_ufuncs.divide
264+
__truediv__ = _ufuncs.divide
265265

266266
def __rtruediv__(self, other):
267-
return _binary_ufuncs.divide(other, self)
267+
return _ufuncs.divide(other, self)
268268

269269
def __itruediv__(self, other):
270-
return _binary_ufuncs.divide(self, other, out=self)
270+
return _ufuncs.divide(self, other, out=self)
271271

272272
# floordiv, self // other
273-
__floordiv__ = _binary_ufuncs.floor_divide
273+
__floordiv__ = _ufuncs.floor_divide
274274

275275
def __rfloordiv__(self, other):
276-
return _binary_ufuncs.floor_divide(other, self)
276+
return _ufuncs.floor_divide(other, self)
277277

278278
def __ifloordiv__(self, other):
279-
return _binary_ufuncs.floor_divide(self, other, out=self)
279+
return _ufuncs.floor_divide(self, other, out=self)
280280

281-
__divmod__ = _binary_ufuncs.divmod
281+
__divmod__ = _ufuncs.divmod
282282

283283
# power, self**exponent
284-
__pow__ = __rpow__ = _binary_ufuncs.float_power
284+
__pow__ = __rpow__ = _ufuncs.float_power
285285

286286
def __rpow__(self, exponent):
287-
return _binary_ufuncs.float_power(exponent, self)
287+
return _ufuncs.float_power(exponent, self)
288288

289289
def __ipow__(self, exponent):
290-
return _binary_ufuncs.float_power(self, exponent, out=self)
290+
return _ufuncs.float_power(self, exponent, out=self)
291291

292292
# remainder, self % other
293-
__mod__ = __rmod__ = _binary_ufuncs.remainder
293+
__mod__ = __rmod__ = _ufuncs.remainder
294294

295295
def __imod__(self, other):
296-
return _binary_ufuncs.remainder(self, other, out=self)
296+
return _ufuncs.remainder(self, other, out=self)
297297

298298
# bitwise ops
299299
# and, self & other
300-
__and__ = __rand__ = _binary_ufuncs.bitwise_and
300+
__and__ = __rand__ = _ufuncs.bitwise_and
301301

302302
def __iand__(self, other):
303-
return _binary_ufuncs.bitwise_and(self, other, out=self)
303+
return _ufuncs.bitwise_and(self, other, out=self)
304304

305305
# or, self | other
306-
__or__ = __ror__ = _binary_ufuncs.bitwise_or
306+
__or__ = __ror__ = _ufuncs.bitwise_or
307307

308308
def __ior__(self, other):
309-
return _binary_ufuncs.bitwise_or(self, other, out=self)
309+
return _ufuncs.bitwise_or(self, other, out=self)
310310

311311
# xor, self ^ other
312-
__xor__ = __rxor__ = _binary_ufuncs.bitwise_xor
312+
__xor__ = __rxor__ = _ufuncs.bitwise_xor
313313

314314
def __ixor__(self, other):
315-
return _binary_ufuncs.bitwise_xor(self, other, out=self)
315+
return _ufuncs.bitwise_xor(self, other, out=self)
316316

317317
# bit shifts
318-
__lshift__ = __rlshift__ = _binary_ufuncs.left_shift
318+
__lshift__ = __rlshift__ = _ufuncs.left_shift
319319

320320
def __ilshift__(self, other):
321-
return _binary_ufuncs.left_shift(self, other, out=self)
321+
return _ufuncs.left_shift(self, other, out=self)
322322

323-
__rshift__ = __rrshift__ = _binary_ufuncs.right_shift
323+
__rshift__ = __rrshift__ = _ufuncs.right_shift
324324

325325
def __irshift__(self, other):
326-
return _binary_ufuncs.right_shift(self, other, out=self)
326+
return _ufuncs.right_shift(self, other, out=self)
327327

328-
__matmul__ = _binary_ufuncs.matmul
328+
__matmul__ = _ufuncs.matmul
329329

330330
def __rmatmul__(self, other):
331-
return _binary_ufuncs.matmul(other, self)
331+
return _ufuncs.matmul(other, self)
332332

333333
def __imatmul__(self, other):
334-
return _binary_ufuncs.matmul(self, other, out=self)
334+
return _ufuncs.matmul(self, other, out=self)
335335

336336
# unary ops
337-
__invert__ = _unary_ufuncs.invert
338-
__abs__ = _unary_ufuncs.absolute
339-
__pos__ = _unary_ufuncs.positive
340-
__neg__ = _unary_ufuncs.negative
337+
__invert__ = _ufuncs.invert
338+
__abs__ = _ufuncs.absolute
339+
__pos__ = _ufuncs.positive
340+
__neg__ = _ufuncs.negative
341341

342-
conjugate = _unary_ufuncs.conjugate
342+
conjugate = _ufuncs.conjugate
343343
conj = conjugate
344344

345345
### methods to match namespace functions

torch_np/_binary_ufuncs.py renamed to torch_np/_ufuncs.py

Lines changed: 70 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22

33
import torch
44

5-
from . import _helpers
6-
from ._detail import _binary_ufuncs
5+
from . import _binary_ufuncs_impl, _helpers, _unary_ufuncs_impl
76
from ._normalizations import ArrayLike, DTypeLike, NDArray, SubokLike, normalizer
87

9-
__all__ = [
8+
# ############# Binary ufuncs ######################
9+
10+
__binary__ = [
1011
name
11-
for name in dir(_binary_ufuncs)
12+
for name in dir(_binary_ufuncs_impl)
1213
if not name.startswith("_") and name not in ["torch", "matmul"]
1314
]
1415

@@ -76,7 +77,7 @@ def matmul(
7677
raise NotImplementedError
7778

7879
# NB: do not broadcast input tensors against the out=... array
79-
result = _binary_ufuncs.matmul(*tensors)
80+
result = _binary_ufuncs_impl.matmul(*tensors)
8081
return result
8182

8283

@@ -110,7 +111,7 @@ def divmod(
110111
(x1, x2), out, True, casting, order, dtype, subok, signature, extobj
111112
)
112113

113-
result = _binary_ufuncs.divmod(*tensors)
114+
result = _binary_ufuncs_impl.divmod(*tensors)
114115

115116
return quot, rem
116117

@@ -119,8 +120,8 @@ def divmod(
119120
# For each torch ufunc implementation, decorate and attach the decorated name
120121
# to this module. Its contents is then exported to the public namespace in __init__.py
121122
#
122-
for name in __all__:
123-
ufunc = getattr(_binary_ufuncs, name)
123+
for name in __binary__:
124+
ufunc = getattr(_binary_ufuncs_impl, name)
124125
decorated = normalizer(deco_binary_ufunc(ufunc))
125126

126127
decorated.__qualname__ = name # XXX: is this really correct?
@@ -133,4 +134,64 @@ def modf(x, /, *args, **kwds):
133134
return rem, quot
134135

135136

136-
__all__ = __all__ + ["divmod", "modf", "matmul"]
137+
__binary__ = __binary__ + ["divmod", "modf", "matmul"]
138+
139+
140+
# ############# Unary ufuncs ######################
141+
142+
143+
__unary__ = [
144+
name
145+
for name in dir(_unary_ufuncs_impl)
146+
if not name.startswith("_") and name != "torch"
147+
]
148+
149+
150+
def deco_unary_ufunc(torch_func):
151+
"""Common infra for unary ufuncs.
152+
153+
Normalize arguments, sort out type casting, broadcasting and delegate to
154+
the pytorch functions for the actual work.
155+
"""
156+
157+
def wrapped(
158+
x: ArrayLike,
159+
/,
160+
out: Optional[NDArray] = None,
161+
*,
162+
where=True,
163+
casting="same_kind",
164+
order="K",
165+
dtype: DTypeLike = None,
166+
subok: SubokLike = False,
167+
signature=None,
168+
extobj=None,
169+
):
170+
tensors = _helpers.ufunc_preprocess(
171+
(x,), out, where, casting, order, dtype, subok, signature, extobj
172+
)
173+
# now broadcast the input tensor against the out=... array
174+
if out is not None:
175+
# XXX: need to filter out noop broadcasts if t.shape == out.shape?
176+
shape = out.shape
177+
tensors = tuple(torch.broadcast_to(t, shape) for t in tensors)
178+
result = torch_func(*tensors)
179+
return result
180+
181+
return wrapped
182+
183+
184+
#
185+
# For each torch ufunc implementation, decorate and attach the decorated name
186+
# to this module. Its contents is then exported to the public namespace in __init__.py
187+
#
188+
for name in __unary__:
189+
ufunc = getattr(_unary_ufuncs_impl, name)
190+
decorated = normalizer(deco_unary_ufunc(ufunc))
191+
192+
decorated.__qualname__ = name # XXX: is this really correct?
193+
decorated.__name__ = name
194+
vars()[name] = decorated
195+
196+
197+
__all__ = __binary__ + __unary__

torch_np/_unary_ufuncs.py

Lines changed: 0 additions & 58 deletions
This file was deleted.

torch_np/_detail/_unary_ufuncs.py renamed to torch_np/_unary_ufuncs_impl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Export torch work functions for unary ufuncs, rename/tweak to match numpy.
2-
This listing is further exported to public symbols in the `torch_np/_unary_ufuncs.py` module.
2+
This listing is further exported to public symbols in the `torch_np/_ufuncs.py` module.
33
"""
44

55
import torch

0 commit comments

Comments
 (0)