Commit 4874861

Merge pull request #29 from Quansight-Labs/function_base
finish up reductions and related animals
2 parents 6c0dd12 + 45ec94c

11 files changed: +401 −383 lines

autogen/numpy_api_dump.py
Lines changed: 0 additions & 44 deletions

@@ -242,14 +242,6 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
     raise NotImplementedError
 
 
-def cumprod(a, axis=None, dtype=None, out=None):
-    raise NotImplementedError
-
-
-def cumproduct(*args, **kwargs):
-    raise NotImplementedError
-
-
 def cumsum(a, axis=None, dtype=None, out=None):
     raise NotImplementedError
 
@@ -610,10 +602,6 @@ def may_share_memory(a, b, /, max_work=None):
     raise NotImplementedError
 
 
-def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
-    raise NotImplementedError
-
-
 def meshgrid(*xi, copy=True, sparse=False, indexing="xy"):
     raise NotImplementedError
 
@@ -734,20 +722,6 @@ def partition(a, kth, axis=-1, kind="introselect", order=None):
     raise NotImplementedError
 
 
-def percentile(
-    a,
-    q,
-    axis=None,
-    out=None,
-    overwrite_input=False,
-    method="linear",
-    keepdims=False,
-    *,
-    interpolation=None,
-):
-    raise NotImplementedError
-
-
 def piecewise(x, condlist, funclist, *args, **kw):
     raise NotImplementedError
 
@@ -800,10 +774,6 @@ def product(*args, **kwargs):
     raise NotImplementedError
 
 
-def ptp(a, axis=None, out=None, keepdims=NoValue):
-    raise NotImplementedError
-
-
 def put(a, ind, v, mode="raise"):
     raise NotImplementedError
 
@@ -816,20 +786,6 @@ def putmask(a, mask, values):
     raise NotImplementedError
 
 
-def quantile(
-    a,
-    q,
-    axis=None,
-    out=None,
-    overwrite_input=False,
-    method="linear",
-    keepdims=False,
-    *,
-    interpolation=None,
-):
-    raise NotImplementedError
-
-
 def ravel(a, order="C"):
     raise NotImplementedError
 
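Editor note: these entries are dropped from the auto-generated dump of not-yet-implemented NumPy APIs because this PR grows implementations for them (quantile, ptp, cumprod and cumsum appear in the files below; median and percentile presumably land in the remaining changed files not shown in this excerpt). For reference only, a minimal plain-NumPy illustration of the semantics these names carry, with the default method="linear":

    import numpy as np

    a = np.array([1.0, 2.0, 3.0, 4.0])
    np.percentile(a, 25)   # 1.75 -- linear interpolation between 1.0 and 2.0
    np.quantile(a, 0.25)   # 1.75 -- the same thing on the 0..1 scale
    np.ptp(a)              # 3.0  -- a.max() - a.min()
    np.cumprod(a)          # array([ 1.,  2.,  6., 24.])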

torch_np/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -10,6 +10,8 @@
 
 # from . import testing
 
+alltrue = all
+sometrue = any
 
 inf = float("inf")
 nan = float("nan")
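Editor note: these are the legacy NumPy spellings of `all` and `any`. A quick usage sketch (an illustration only, assuming `asarray` and the `all`/`any` reductions are exported from the package namespace as the aliases imply):

    import torch_np as tnp

    a = tnp.asarray([True, True, False])
    assert not tnp.alltrue(a)   # alias of tnp.all(a)
    assert tnp.sometrue(a)      # alias of tnp.any(a)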

torch_np/_decorators.py
Lines changed: 31 additions & 5 deletions

@@ -93,10 +93,7 @@ def axis_keepdims_wrapper(func):
     see them). The `axis` argument we normalize and pass through to pytorch functions.
 
     """
-    # XXX: move this out of _ndarray.py (circular imports)
-    #
-    # TODO: 1. get rid of _helpers.result_or_out
-    #       2. sort out function signatures: how they flow through all decorators etc
+    # TODO: sort out function signatures: how they flow through all decorators etc
     @functools.wraps(func)
     def wrapped(a, axis=None, keepdims=NoValue, *args, **kwds):
         from ._ndarray import asarray, ndarray
@@ -107,7 +104,36 @@ def wrapped(a, axis=None, keepdims=NoValue, *args, **kwds):
         if isinstance(axis, ndarray):
             axis = operator.index(axis)
 
-        result = _util.axis_keepdims(func, tensor, axis, keepdims, *args, **kwds)
+        result = _util.axis_expand_func(func, tensor, axis, *args, **kwds)
+
+        if keepdims:
+            result = _util.apply_keepdims(result, axis, tensor.ndim)
+
+        return result
+
+    return wrapped
+
+
+def axis_none_ravel_wrapper(func):
+    """`func` accepts an array-like as a 1st arg, returns a tensor.
+
+    This decorator implements the generic handling of axis=None acting on a
+    raveled array. One use is cumprod / cumsum. concatenate also uses a
+    similar logic.
+
+    """
+
+    @functools.wraps(func)
+    def wrapped(a, axis=None, *args, **kwds):
+        from ._ndarray import asarray, ndarray
+
+        tensor = asarray(a).get()
+
+        # standardize the axis argument
+        if isinstance(axis, ndarray):
+            axis = operator.index(axis)
+
+        result = _util.axis_ravel_func(func, tensor, axis, *args, **kwds)
         return result
 
     return wrapped
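Editor note: the axis=None behaviour this wrapper reproduces, shown with plain NumPy for reference (illustration only, not part of the diff) -- with axis=None the input is flattened first, so the cumulative reduction runs over the raveled order:

    import numpy as np

    a = np.array([[1, 2], [3, 4]])
    np.cumsum(a)           # array([ 1,  3,  6, 10])  == np.cumsum(a.ravel())
    np.cumsum(a, axis=0)   # array([[1, 2],
                           #        [4, 6]])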

torch_np/_detail/_reductions.py
Lines changed: 99 additions & 0 deletions

@@ -89,6 +89,11 @@ def min(tensor, axis=None, initial=NoValue, where=NoValue):
     return result
 
 
+def ptp(tensor, axis=None):
+    result = tensor.amax(axis) - tensor.amin(axis)
+    return result
+
+
 def sum(tensor, axis=None, dtype=None, initial=NoValue, where=NoValue):
     if initial is not NoValue or where is not NoValue:
         raise NotImplementedError
@@ -161,3 +166,97 @@ def var(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue):
     result = tensor.var(dim=axis, correction=ddof)
 
     return result
+
+
+# cumsum / cumprod are almost reductions:
+#   1. no keepdims
+#   2. axis=None ravels (cf concatenate)
+
+
+def cumprod(tensor, axis, dtype=None):
+    if dtype == torch.bool:
+        dtype = _scalar_types.default_int_type.dtype
+    if dtype is None:
+        dtype = tensor.dtype
+
+    result = tensor.cumprod(axis=axis, dtype=dtype)
+
+    return result
+
+
+def cumsum(tensor, axis, dtype=None):
+    if dtype == torch.bool:
+        dtype = _scalar_types.default_int_type.dtype
+    if dtype is None:
+        dtype = tensor.dtype
+
+    result = tensor.cumsum(axis=axis, dtype=dtype)
+
+    return result
+
+
+def average(a_tensor, axis, w_tensor):
+
+    # dtype
+    # FIXME: 1. use result_type
+    #        2. actually implement multiply w/dtype
+    if not a_tensor.dtype.is_floating_point:
+        result_dtype = torch.float64
+        a_tensor = a_tensor.to(result_dtype)
+
+    # axis
+    if axis is None:
+        (a_tensor, w_tensor), axis = _util.axis_none_ravel(
+            a_tensor, w_tensor, axis=axis
+        )
+
+    # axis & weights
+    if a_tensor.shape != w_tensor.shape:
+        if axis is None:
+            raise TypeError(
+                "Axis must be specified when shapes of a and weights " "differ."
+            )
+        if w_tensor.ndim != 1:
+            raise TypeError("1D weights expected when shapes of a and weights differ.")
+        if w_tensor.shape[0] != a_tensor.shape[axis]:
+            raise ValueError("Length of weights not compatible with specified axis.")
+
+        # setup weight to broadcast along axis
+        w_tensor = torch.broadcast_to(
+            w_tensor, (a_tensor.ndim - 1) * (1,) + w_tensor.shape
+        )
+        w_tensor = w_tensor.swapaxes(-1, axis)
+
+    # do the work
+    numerator = torch.mul(a_tensor, w_tensor).sum(axis)
+    denominator = w_tensor.sum(axis)
+    result = numerator / denominator
+
+    return result, denominator
+
+
+def quantile(a_tensor, q_tensor, axis, method):
+
+    if (0 > q_tensor).any() or (q_tensor > 1).any():
+        raise ValueError("Quantiles must be in range [0, 1], got %s" % q_tensor)
+
+    if not a_tensor.dtype.is_floating_point:
+        dtype = _scalar_types.default_float_type.torch_dtype
+        a_tensor = a_tensor.to(dtype)
+
+    # edge case: torch.quantile only supports float32 and float64
+    if a_tensor.dtype == torch.float16:
+        a_tensor = a_tensor.to(torch.float32)
+
+    # axis
+    if axis is not None:
+        axis = _util.normalize_axis_tuple(axis, a_tensor.ndim)
+        axis = _util.allow_only_single_axis(axis)
+
+    q_tensor = q_tensor.to(a_tensor.dtype)
+
+    (a_tensor, q_tensor), axis = _util.axis_none_ravel(a_tensor, q_tensor, axis=axis)
+
+    result = torch.quantile(a_tensor, q_tensor, axis=axis, interpolation=method)
+
+    return result
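Editor note: a quick standalone sanity check of the formulas above (my own illustration, not part of the diff). `ptp` is the peak-to-peak range `amax - amin`, and `average` is the weighted mean `(a * w).sum(axis) / w.sum(axis)` with the 1-D weights broadcast along the reduction axis:

    import torch

    t = torch.tensor([[1.0, 5.0, 2.0], [4.0, 3.0, 8.0]])
    w = torch.tensor([1.0, 2.0, 3.0])

    ptp_rows = t.amax(-1) - t.amin(-1)       # tensor([4., 5.])

    numerator = (t * w).sum(-1)              # weights broadcast along the last axis
    denominator = w.sum(-1)                  # tensor(6.)
    weighted_mean = numerator / denominator  # tensor([2.8333, 5.6667])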

torch_np/_detail/_util.py
Lines changed: 16 additions & 5 deletions

@@ -111,7 +111,8 @@ def expand_shape(arr_shape, axis):
 def apply_keepdims(tensor, axis, ndim):
     if axis is None:
         # tensor was a scalar
-        tensor = torch.full((1,) * ndim, fill_value=tensor, dtype=tensor.dtype)
+        shape = (1,) * ndim
+        tensor = tensor.expand(shape).contiguous()  # avoid CUDA synchronization
     else:
         shape = expand_shape(tensor.shape, axis)
         tensor = tensor.reshape(shape)
@@ -211,8 +212,8 @@ def cast_and_broadcast(tensors, out_param, casting):
     return tuple(processed_tensors)
 
 
-def axis_keepdims(func, tensor, axis, keepdims, *args, **kwds):
-    """Generically handle axis and keepdims arguments in reductions."""
+def axis_expand_func(func, tensor, axis, *args, **kwds):
+    """Generically handle axis arguments in reductions."""
     if axis is not None:
         if not isinstance(axis, (list, tuple)):
             axis = (axis,)
@@ -225,8 +226,18 @@ def axis_keepdims(func, tensor, axis, keepdims, *args, **kwds):
 
     result = func(tensor, axis=axis, *args, **kwds)
 
-    if keepdims:
-        result = apply_keepdims(result, axis, tensor.ndim)
+    return result
+
+
+def axis_ravel_func(func, tensor, axis, *args, **kwds):
+    """Generically handle axis arguments in cumsum/cumprod."""
+    if axis is not None:
+        axis = normalize_axis_index(axis, tensor.ndim)
+
+    tensors, axis = axis_none_ravel(tensor, axis=axis)
+    tensor = tensors[0]
+
+    result = func(tensor, axis=axis, *args, **kwds)
 
     return result
 
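Editor note: the shape bookkeeping that `apply_keepdims` reproduces, sketched directly in torch (illustration only) -- with an integer axis the reduced dimension is re-inserted with length 1; with axis=None the 0-d result is broadcast back to `(1,) * ndim`:

    import torch

    a = torch.ones(2, 3, 4)

    r = a.sum(dim=1)                       # shape (2, 4): axis 1 is gone
    assert r.reshape(2, 1, 4).shape == a.sum(dim=1, keepdim=True).shape

    s = a.sum()                            # axis=None -> 0-d tensor
    assert s.expand((1,) * a.ndim).shape == (1, 1, 1)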

torch_np/_helpers.py
Lines changed: 16 additions & 3 deletions

@@ -40,13 +40,26 @@ def cast_and_broadcast(tensors, out, casting):
     return tuple(tensors)
 
 
-def result_or_out(result_tensor, out_array=None):
-    """A helper for returns with out= argument."""
+def result_or_out(result_tensor, out_array=None, promote_scalar=False):
+    """A helper for returns with out= argument.
+
+    If `promote_scalar is True`, then:
+        if result_tensor.numel() == 1 and out is zero-dimensional,
+            result_tensor is placed into the out array.
+    This weirdness is used e.g. in `np.percentile`
+    """
     if out_array is not None:
         if not isinstance(out_array, ndarray):
             raise TypeError("Return arrays must be of ArrayType")
         if result_tensor.shape != out_array.shape:
-            raise ValueError("Bad size of the out array.")
+            can_fit = result_tensor.numel() == 1 and out_array.ndim == 0
+            if promote_scalar and can_fit:
+                result_tensor = result_tensor.squeeze()
+            else:
+                raise ValueError(
+                    f"Bad size of the out array: out.shape = {out_array.shape}"
+                    f" while result.shape = {result_tensor.shape}."
+                )
         out_tensor = out_array.get()
         out_tensor.copy_(result_tensor)
         return out_array
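Editor note: a standalone sketch of the `promote_scalar` branch (not the torch_np code itself): a one-element result of shape `(1,)` can still be copied into a zero-dimensional `out` once the length-1 axis is squeezed away.

    import torch

    result = torch.tensor([4.5])   # numel() == 1, shape (1,)
    out = torch.empty(())          # zero-dimensional destination

    if result.shape != out.shape and result.numel() == 1 and out.ndim == 0:
        result = result.squeeze()  # shapes now line up: () == ()
    out.copy_(result)              # tensor(4.5000)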

torch_np/_ndarray.py
Lines changed: 17 additions & 1 deletion

@@ -3,7 +3,13 @@
 import torch
 
 from . import _binary_ufuncs, _dtypes, _helpers, _unary_ufuncs
-from ._decorators import NoValue, axis_keepdims_wrapper, dtype_to_torch, emulate_out_arg
+from ._decorators import (
+    NoValue,
+    axis_keepdims_wrapper,
+    axis_none_ravel_wrapper,
+    dtype_to_torch,
+    emulate_out_arg,
+)
 from ._detail import _reductions, _util
 
 newaxis = None
@@ -274,13 +280,21 @@ def nonzero(self):
     all = emulate_out_arg(axis_keepdims_wrapper(_reductions.all))
     max = emulate_out_arg(axis_keepdims_wrapper(_reductions.max))
     min = emulate_out_arg(axis_keepdims_wrapper(_reductions.min))
+    ptp = emulate_out_arg(axis_keepdims_wrapper(_reductions.ptp))
 
     sum = emulate_out_arg(axis_keepdims_wrapper(dtype_to_torch(_reductions.sum)))
     prod = emulate_out_arg(axis_keepdims_wrapper(dtype_to_torch(_reductions.prod)))
     mean = emulate_out_arg(axis_keepdims_wrapper(dtype_to_torch(_reductions.mean)))
     var = emulate_out_arg(axis_keepdims_wrapper(dtype_to_torch(_reductions.var)))
     std = emulate_out_arg(axis_keepdims_wrapper(dtype_to_torch(_reductions.std)))
 
+    cumprod = emulate_out_arg(
+        axis_none_ravel_wrapper(dtype_to_torch(_reductions.cumprod))
+    )
+    cumsum = emulate_out_arg(
+        axis_none_ravel_wrapper(dtype_to_torch(_reductions.cumsum))
+    )
+
     ### indexing ###
     def __getitem__(self, *args, **kwds):
         t_args = _helpers.ndarrays_to_tensors(*args)
@@ -380,6 +394,8 @@ def result_type(*arrays_and_dtypes):
             dtypes.append(_dtypes.dtype(entry))
         elif isinstance(entry, _dtypes.DType):
            dtypes.append(entry)
+        elif isinstance(entry, str):
+            dtypes.append(_dtypes.dtype(entry))
         else:
             dtypes.append(asarray(entry).dtype)
 
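Editor note: with these bindings, `ptp` and the cumulative reductions become ndarray methods, and `result_type` also accepts dtype names given as strings. A usage sketch (an illustration, assuming `asarray` and `result_type` are reachable from the package namespace):

    import torch_np as tnp

    a = tnp.asarray([[1, 2], [3, 4]])

    a.cumsum()         # axis=None: raveled -> [ 1,  3,  6, 10]
    a.cumsum(axis=1)   # [[1, 3], [3, 7]]
    a.ptp()            # 3, i.e. a.max() - a.min()

    tnp.result_type("float32", a)   # a dtype name as a plain string now works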
