
Commit 906f005

Merge pull request #91 from Quansight-Labs/rebase_stack
add divmod, matmul; add dtype=... arg to ufuncs, add several missing aliases.
2 parents 37d6f4f + 9b3e5ec commit 906f005

14 files changed (+278 -158 lines)

torch_np/_binary_ufuncs.py

Lines changed: 97 additions & 1 deletion
@@ -1,11 +1,15 @@
 from typing import Optional
 
+import torch
+
 from . import _helpers
 from ._detail import _binary_ufuncs
 from ._normalizations import ArrayLike, DTypeLike, NDArray, SubokLike, normalizer
 
 __all__ = [
-    name for name in dir(_binary_ufuncs) if not name.startswith("_") and name != "torch"
+    name
+    for name in dir(_binary_ufuncs)
+    if not name.startswith("_") and name not in ["torch", "matmul"]
 ]
 
 
@@ -33,12 +37,49 @@ def wrapped(
         tensors = _helpers.ufunc_preprocess(
             (x1, x2), out, where, casting, order, dtype, subok, signature, extobj
         )
+        # now broadcast input tensors against the out=... array
+        if out is not None:
+            # XXX: need to filter out noop broadcasts if t.shape == out.shape?
+            shape = out.shape
+            tensors = tuple(torch.broadcast_to(t, shape) for t in tensors)
+
         result = torch_func(*tensors)
         return _helpers.result_or_out(result, out)
 
     return wrapped
 
 
+#
+# matmul is special in that its `out=...` array does not broadcast x1 and x2.
+# E.g. consider x1.shape = (5, 2) and x2.shape = (2, 3). Then `out.shape` is (5, 3).
+#
+@normalizer
+def matmul(
+    x1: ArrayLike,
+    x2: ArrayLike,
+    /,
+    out: Optional[NDArray] = None,
+    *,
+    casting="same_kind",
+    order="K",
+    dtype: DTypeLike = None,
+    subok: SubokLike = False,
+    signature=None,
+    extobj=None,
+    axes=None,
+    axis=None,
+):
+    tensors = _helpers.ufunc_preprocess(
+        (x1, x2), out, True, casting, order, dtype, subok, signature, extobj
+    )
+    if axis is not None or axes is not None:
+        raise NotImplementedError
+
+    # NB: do not broadcast input tensors against the out=... array
+    result = _binary_ufuncs.matmul(*tensors)
+    return _helpers.result_or_out(result, out)
+
+
 #
 # For each torch ufunc implementation, decorate and attach the decorated name
 # to this module. Its contents is then exported to the public namespace in __init__.py
@@ -50,3 +91,58 @@ def wrapped(
     decorated.__qualname__ = name  # XXX: is this really correct?
     decorated.__name__ = name
     vars()[name] = decorated
+
+
+# a stub implementation of divmod, should be improved after
+# https://github.com/pytorch/pytorch/issues/90820 is fixed in pytorch
+#
+# Implementation details: we just call two ufuncs which have been created
+# just above, for x1 // x2 and x1 % x2.
+# This means we are normalizing x1, x2 in each of the ufuncs --- note that there
+# is no @normalizer on divmod.
+
+
+def divmod(
+    x1,
+    x2,
+    /,
+    out=None,
+    *,
+    where=True,
+    casting="same_kind",
+    order="K",
+    dtype=None,
+    subok: SubokLike = False,
+    signature=None,
+    extobj=None,
+):
+    out1, out2 = None, None
+    if out is not None:
+        out1, out2 = out
+
+    kwds = dict(
+        where=where,
+        casting=casting,
+        order=order,
+        dtype=dtype,
+        subok=subok,
+        signature=signature,
+        extobj=extobj,
+    )
+
+    # NB: use local names for
+    quot = floor_divide(x1, x2, out=out1, **kwds)
+    rem = remainder(x1, x2, out=out2, **kwds)
+
+    quot = _helpers.result_or_out(quot.tensor, out1)
+    rem = _helpers.result_or_out(rem.tensor, out2)
+
+    return quot, rem
+
+
+def modf(x, /, *args, **kwds):
+    quot, rem = divmod(x, 1, *args, **kwds)
+    return rem, quot
+
+
+__all__ = __all__ + ["divmod", "modf", "matmul"]

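Taken together, this file now exposes matmul (whose out= array must already have the result shape, since the inputs are not broadcast against it) and the divmod/modf stubs built from floor_divide and remainder. A rough usage sketch follows; the top-level re-exports and the array constructors (ones, empty, arange) are assumed from the rest of the package, not shown in this diff:

import torch_np as np

x1 = np.ones((5, 2))
x2 = np.ones((2, 3))

# matmul: out= must already be (5, 3); it is not used to broadcast x1 and x2
out = np.empty((5, 3))
np.matmul(x1, x2, out=out)

# divmod: a stub composed of floor_divide and remainder (see the comment above)
q, r = np.divmod(np.arange(7), 3)

# modf: fractional and whole parts, implemented as divmod(x, 1) with outputs swapped
frac, whole = np.modf(np.arange(7) / 2.0)
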
torch_np/_detail/_binary_ufuncs.py

Lines changed: 23 additions & 4 deletions
@@ -38,16 +38,35 @@
     not_equal,
 )
 from torch import pow as power
-from torch import remainder, subtract
+from torch import remainder
+from torch import remainder as mod
+from torch import subtract, true_divide
 
 from . import _dtypes_impl, _util
 
 
 # work around torch limitations w.r.t. numpy
 def matmul(x, y):
-    # work around RuntimeError: expected scalar type Int but found Double
+    # work around:
+    # - RuntimeError: expected scalar type Int but found Double
+    # - RuntimeError: "addmm_impl_cpu_" not implemented for 'Bool'
+    # - RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
     dtype = _dtypes_impl.result_type_impl((x.dtype, y.dtype))
-    x = _util.cast_if_needed(x, dtype)
-    y = _util.cast_if_needed(y, dtype)
+    is_bool = dtype == torch.bool
+    is_half = dtype == torch.float16
+
+    work_dtype = dtype
+    if is_bool:
+        work_dtype = torch.uint8
+    if is_half:
+        work_dtype = torch.float32
+
+    x = _util.cast_if_needed(x, work_dtype)
+    y = _util.cast_if_needed(y, work_dtype)
+
     result = torch.matmul(x, y)
+
+    if work_dtype != dtype:
+        result = result.to(dtype)
+
     return result

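The work-dtype shuffle above exists because CPU matmul kernels are missing for bool and float16 (the RuntimeErrors quoted in the comment). A standalone sketch of the same pattern in plain torch; the helper name is made up, and torch.result_type stands in for the project's own result_type_impl:

import torch

def matmul_via_work_dtype(x, y):
    # compute in a wider dtype when the natural result dtype has no CPU kernel,
    # then cast the result back, mirroring the torch_np matmul workaround above
    dtype = torch.result_type(x, y)
    work_dtype = dtype
    if dtype == torch.bool:
        work_dtype = torch.uint8
    elif dtype == torch.float16:
        work_dtype = torch.float32

    result = torch.matmul(x.to(work_dtype), y.to(work_dtype))
    return result.to(dtype) if work_dtype != dtype else result

a = torch.ones(2, 3, dtype=torch.float16)
b = torch.ones(3, 4, dtype=torch.float16)
c = matmul_via_work_dtype(a, b)  # float16 result, computed in float32 on CPU
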
torch_np/_detail/_unary_ufuncs.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
 # renames
 from torch import absolute as fabs
 from torch import arccos, arccosh, arcsin, arcsinh, arctan, arctanh
+from torch import bitwise_not
 from torch import bitwise_not as invert
 from torch import ceil
 from torch import conj_physical as conjugate
@@ -31,6 +32,7 @@
 from torch import rad2deg
 from torch import rad2deg as degrees
 from torch import reciprocal
+from torch import round as fix
 from torch import round as rint
 from torch import sign, signbit, sin, sinh, sqrt, square, tan, tanh, trunc
 

torch_np/_detail/_util.py

Lines changed: 1 addition & 47 deletions
@@ -137,7 +137,7 @@ def axis_none_ravel(*tensors, axis=None):
     return tensors, axis
 
 
-def cast_dont_broadcast(tensors, target_dtype, casting):
+def typecast_tensors(tensors, target_dtype, casting):
     """Dtype-cast tensors to target_dtype.
 
     Parameters
@@ -170,52 +170,6 @@ def cast_dont_broadcast(tensors, target_dtype, casting):
     return tuple(cast_tensors)
 
 
-def cast_and_broadcast(tensors, out_param, casting):
-    """
-    Parameters
-    ----------
-    tensors : iterable
-        tuple or list of torch.Tensors to broadcast/typecast
-    target_dtype : a torch.dtype object
-        The torch dtype to cast all tensors to
-    target_shape : tuple
-        The tensor shape to broadcast all `tensors` to
-    casting : str
-        The casting mode, see `np.can_cast`
-
-    Returns
-    -------
-    a tuple of torch.Tensors with dtype being the PyTorch counterpart
-    of the `target_dtype` and `target_shape`
-    """
-    if out_param is None:
-        return tensors
-
-    target_dtype, target_shape = out_param
-
-    can_cast = _dtypes_impl.can_cast_impl
-
-    processed_tensors = []
-    for tensor in tensors:
-        # check dtypes of x and out
-        if not can_cast(tensor.dtype, target_dtype, casting=casting):
-            raise TypeError(
-                f"Cannot cast array data from {tensor.dtype} to"
-                f" {target_dtype} according to the rule '{casting}'"
-            )
-
-        # cast arr if needed
-        tensor = cast_if_needed(tensor, target_dtype)
-
-        # `out` broadcasts `tensor`
-        if tensor.shape != target_shape:
-            tensor = torch.broadcast_to(tensor, target_shape)
-
-        processed_tensors.append(tensor)
-
-    return tuple(processed_tensors)
-
-
 def axis_expand_func(func, tensor, axis, *args, **kwds):
     """Generically handle axis arguments in reductions."""
     if axis is not None:

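The rename makes the intent explicit: this helper only dtype-casts, while broadcasting against out= now happens in the ufunc wrappers (see torch_np/_binary_ufuncs.py above). A small illustrative call with made-up tensors; the module path follows the file name in this diff:

import torch
from torch_np._detail import _util

t1 = torch.arange(3, dtype=torch.int32)
t2 = torch.ones(3, dtype=torch.float64)

# cast both tensors to a common target dtype; shapes are left untouched
t1c, t2c = _util.typecast_tensors((t1, t2), torch.float64, casting="same_kind")
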
torch_np/_detail/implementations.py

Lines changed: 2 additions & 2 deletions
@@ -308,7 +308,7 @@ def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):
         out_dtype = _dtypes_impl.result_type_impl([t.dtype for t in tensors])
 
     # cast input arrays if necessary; do not broadcast them agains `out`
-    tensors = _util.cast_dont_broadcast(tensors, out_dtype, casting)
+    tensors = _util.typecast_tensors(tensors, out_dtype, casting)
 
     return tensors
 
@@ -467,7 +467,7 @@ def bincount(x, /, weights=None, minlength=0):
         x = x.new_empty(0, dtype=int)
 
     int_dtype = _dtypes_impl.default_int_dtype
-    (x,) = _util.cast_dont_broadcast((x,), int_dtype, casting="safe")
+    (x,) = _util.typecast_tensors((x,), int_dtype, casting="safe")
 
     result = torch.bincount(x, weights, minlength)
     return result

torch_np/_dtypes.py

Lines changed: 3 additions & 0 deletions
@@ -174,7 +174,10 @@ class bool_(generic):
     "double": float64,
     "float_": float64,
     "csingle": complex64,
+    "singlecomplex": complex64,
     "cdouble": complex128,
+    "cfloat": complex128,
+    "complex_": complex128,
 }
 for name, obj in _name_aliases.items():
     globals()[name] = obj

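The three new names mirror numpy's complex dtype aliases. Because the loop shown in the hunk binds every alias as a module global, each one is the same object as the dtype it points to; a quick check, assuming nothing beyond what the hunk shows:

from torch_np import _dtypes

assert _dtypes.singlecomplex is _dtypes.complex64
assert _dtypes.cfloat is _dtypes.complex128
assert _dtypes.complex_ is _dtypes.complex128
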
torch_np/_helpers.py

Lines changed: 17 additions & 43 deletions
@@ -1,58 +1,32 @@
 import torch
 
-from . import _dtypes
-from ._detail import _util
-
-
-def cast_and_broadcast(tensors, out, casting):
-    """Cast dtypes of arrays to out.dtype and broadcast if needed.
-
-    Parameters
-    ----------
-    arrays : sequence of arrays
-        Each element is broadcast against `out` and typecast to out.dtype
-    out : the "output" array
-        Not modified.
-    casting : str
-        One of numpy casting modes
-
-    Returns
-    -------
-    tensors : tuple of Tensors
-        Each tensor is dtype-cast and broadcast agains `out`, as needed
-
-    Notes
-    -----
-    The `out` arrays broadcasts and dtype-casts `arrays`, but not vice versa.
-
-    """
-    if out is None:
-        return tensors
-    else:
-        tensors = _util.cast_and_broadcast(
-            tensors, out.dtype.type.torch_dtype, out.shape, casting
-        )
-
-    return tuple(tensors)
+from ._detail import _dtypes_impl, _util
 
 
 def ufunc_preprocess(
     tensors, out, where, casting, order, dtype, subok, signature, extobj
 ):
+    """
+    Notes
+    -----
+    The `out` array broadcasts `tensors`, but not vice versa.
+    """
     # internal preprocessing or args in ufuncs (cf _unary_ufuncs, _binary_ufuncs)
     if order != "K" or not where or signature or extobj:
         raise NotImplementedError
 
-    # XXX: dtype=... parameter
-    if dtype is not None:
-        raise NotImplementedError
-
-    out_shape_dtype = None
-    if out is not None:
-        out_shape_dtype = (out.tensor.dtype, out.tensor.shape)
-
-    tensors = _util.cast_and_broadcast(tensors, out_shape_dtype, casting)
+    # dtype of the result: depends on both dtype=... and out=... arguments
+    if dtype is None:
+        out_dtype = None if out is None else out.dtype.torch_dtype
+    else:
+        out_dtype = (
+            dtype
+            if out is None
+            else _dtypes_impl.result_type_impl([dtype, out.dtype.torch_dtype])
+        )
 
+    if out_dtype:
+        tensors = _util.typecast_tensors(tensors, out_dtype, casting)
     return tensors
 
 

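This is where the new dtype=... support lands: the working dtype is resolved from dtype= and out= together (dtype alone, out.dtype alone, or their result_type when both are given), and the inputs are typecast before the torch function runs. A rough sketch of what that enables at the call site; the top-level names are assumptions, as above:

import torch_np as np

x = np.arange(3)
y = np.arange(3)

# dtype=... alone: inputs are cast to float64 before torch.add is called
z = np.add(x, y, dtype=np.float64)

# out=... alone: inputs are cast to the dtype of the out array
out = np.empty(3, dtype=np.float64)
np.add(x, y, out=out)

# both given: the working dtype is result_type_impl([dtype, out.dtype])
np.add(x, y, dtype=np.float32, out=out)
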
torch_np/_ndarray.py

Lines changed: 10 additions & 0 deletions
@@ -266,6 +266,8 @@ def __rfloordiv__(self, other):
     def __ifloordiv__(self, other):
         return _binary_ufuncs.floor_divide(self, other, out=self)
 
+    __divmod__ = _binary_ufuncs.divmod
+
     # power, self**exponent
     __pow__ = __rpow__ = _binary_ufuncs.float_power
 
@@ -311,6 +313,14 @@ def __ilshift__(self, other):
     def __irshift__(self, other):
         return _binary_ufuncs.right_shift(self, other, out=self)
 
+    __matmul__ = _binary_ufuncs.matmul
+
+    def __rmatmul__(self, other):
+        return _binary_ufuncs.matmul(other, self)
+
+    def __imatmul__(self, other):
+        return _binary_ufuncs.matmul(self, other, out=self)
+
     # unary ops
     __invert__ = _unary_ufuncs.invert
     __abs__ = _unary_ufuncs.absolute

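With these dunder methods in place, the Python operators route to the new ufuncs. A short sketch, again assuming the usual array constructors are exported:

import torch_np as np

a = np.ones((2, 3))
m = np.ones((3, 3))

c = a @ m                         # __matmul__
d = [[1.0, 2.0]] @ a              # __rmatmul__: the list goes through ArrayLike
a @= m                            # __imatmul__: written into `a` via out=self

q, r = divmod(np.arange(7), 3)    # __divmod__ -> floor_divide and remainder
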
torch_np/_normalizations.py

Lines changed: 2 additions & 2 deletions
@@ -45,7 +45,7 @@ def normalize_dtype(dtype, name=None):
     return torch_dtype
 
 
-def normalize_subok_like(arg, name):
+def normalize_subok_like(arg, name="subok"):
     if arg:
         raise ValueError(f"'{name}' parameter is not supported.")
 
@@ -88,7 +88,7 @@ def maybe_normalize(arg, parm, return_on_failure=_sentinel):
     """Normalize arg if a normalizer is registred."""
     normalizer = normalizers.get(parm.annotation, None)
     try:
-        return normalizer(arg) if normalizer else arg
+        return normalizer(arg, parm.name) if normalizer else arg
     except Exception as exc:
         if return_on_failure is not _sentinel:
             return return_on_failure

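Passing parm.name through means one normalizer can serve several parameters and still point at the right one in its error message. A minimal sketch using only the function shown in this hunk; the name "like" is illustrative:

# normalize_subok_like as defined above; name= is what maybe_normalize
# now supplies from the wrapped function's signature
def normalize_subok_like(arg, name="subok"):
    if arg:
        raise ValueError(f"'{name}' parameter is not supported.")

normalize_subok_like(False)  # nothing to do
try:
    normalize_subok_like(True, name="like")
except ValueError as exc:
    print(exc)  # 'like' parameter is not supported.
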