Skip to content

Commit 14efb91

Browse files
authored
Merge pull request #39 from Quansight-Labs/flips
implement flip-based array manipulations
2 parents 4874861 + ded65f2 commit 14efb91

File tree

8 files changed

+201
-499
lines changed

8 files changed

+201
-499
lines changed

autogen/numpy_api_dump.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,6 @@ def asmatrix(data, dtype=None):
146146
raise NotImplementedError
147147

148148

149-
def average(a, axis=None, weights=None, returned=False, *, keepdims=NoValue):
150-
raise NotImplementedError
151-
152-
153149
def bartlett(M):
154150
raise NotImplementedError
155151

@@ -242,10 +238,6 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
242238
raise NotImplementedError
243239

244240

245-
def cumsum(a, axis=None, dtype=None, out=None):
246-
raise NotImplementedError
247-
248-
249241
def datetime_as_string(arr, unit=None, timezone="naive", casting="same_kind"):
250242
raise NotImplementedError
251243

@@ -330,10 +322,6 @@ def fix(x, out=None):
330322
raise NotImplementedError
331323

332324

333-
def flip(m, axis=None):
334-
raise NotImplementedError
335-
336-
337325
def fliplr(m):
338326
raise NotImplementedError
339327

@@ -770,10 +758,6 @@ def printoptions(*args, **kwargs):
770758
raise NotImplementedError
771759

772760

773-
def product(*args, **kwargs):
774-
raise NotImplementedError
775-
776-
777761
def put(a, ind, v, mode="raise"):
778762
raise NotImplementedError
779763

@@ -818,10 +802,6 @@ def roots(p):
818802
raise NotImplementedError
819803

820804

821-
def rot90(m, k=1, axes=(0, 1)):
822-
raise NotImplementedError
823-
824-
825805
def safe_eval(source):
826806
raise NotImplementedError
827807

torch_np/_detail/_flips.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""Implementations of flip-based routines and related animals.
2+
"""
3+
4+
import torch
5+
6+
from . import _scalar_types, _util
7+
8+
9+
def flip(m_tensor, axis=None):
10+
# XXX: semantic difference: np.flip returns a view, torch.flip copies
11+
if axis is None:
12+
axis = tuple(range(m_tensor.ndim))
13+
else:
14+
axis = _util.normalize_axis_tuple(axis, m_tensor.ndim)
15+
return torch.flip(m_tensor, axis)
16+
17+
18+
def flipud(m_tensor):
19+
return torch.flipud(m_tensor)
20+
21+
22+
def fliplr(m_tensor):
23+
return torch.fliplr(m_tensor)
24+
25+
26+
def rot90(m_tensor, k=1, axes=(0, 1)):
27+
axes = _util.normalize_axis_tuple(axes, m_tensor.ndim)
28+
return torch.rot90(m_tensor, k, axes)
29+
30+
31+
def swapaxes(tensor, axis1, axis2):
32+
return torch.swapaxes(tensor, axis1, axis2)
33+
34+
35+
# Straight vendor from:
36+
# https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1259
37+
#
38+
# Also note this function in NumPy is mostly retained for backwards compat
39+
# (https://stackoverflow.com/questions/29891583/reason-why-numpy-rollaxis-is-so-confusing)
40+
# so let's not touch it unless hard pressed.
41+
def rollaxis(tensor, axis, start=0):
42+
n = tensor.ndim
43+
axis = _util.normalize_axis_index(axis, n)
44+
if start < 0:
45+
start += n
46+
msg = "'%s' arg requires %d <= %s < %d, but %d was passed in"
47+
if not (0 <= start < n + 1):
48+
raise _util.AxisError(msg % ("start", -n, "start", n + 1, start))
49+
if axis < start:
50+
# it's been removed
51+
start -= 1
52+
if axis == start:
53+
# numpy returns a view, here we try returning the tensor itself
54+
# return tensor[...]
55+
return tensor
56+
axes = list(range(0, n))
57+
axes.remove(axis)
58+
axes.insert(start, axis)
59+
return tensor.view(axes)

torch_np/_helpers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,9 @@ def ndarrays_to_tensors(*inputs):
7777
def to_tensors(*inputs):
7878
"""Convert all array_likes from `inputs` to tensors."""
7979
return tuple(asarray(value).get() for value in inputs)
80+
81+
82+
def _outer(x, y):
83+
x_tensor, y_tensor = to_tensors(x, y)
84+
result = torch.outer(x_tensor, y_tensor)
85+
return asarray(result)

torch_np/_ndarray.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
dtype_to_torch,
1111
emulate_out_arg,
1212
)
13-
from ._detail import _reductions, _util
13+
from ._detail import _flips, _reductions, _util
1414

1515
newaxis = None
1616

@@ -264,6 +264,9 @@ def transpose(self, *axes):
264264
raise ValueError("axes don't match array")
265265
return ndarray._from_tensor_and_base(tensor, self)
266266

267+
def swapaxes(self, axis1, axis2):
268+
return _flips.swapaxes(self._tensor, axis1, axis2)
269+
267270
def ravel(self, order="C"):
268271
if order != "C":
269272
raise NotImplementedError

torch_np/_wrapper.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torch
99

1010
from . import _dtypes, _helpers
11-
from ._detail import _reductions, _util
11+
from ._detail import _flips, _reductions, _util
1212
from ._ndarray import (
1313
array,
1414
asarray,
@@ -440,12 +440,22 @@ def expand_dims(a, axis):
440440

441441
@asarray_replacer()
442442
def flip(m, axis=None):
443-
# XXX: semantic difference: np.flip returns a view, torch.flip copies
444-
if axis is None:
445-
axis = tuple(range(m.ndim))
446-
else:
447-
axis = _util.normalize_axis_tuple(axis, m.ndim)
448-
return torch.flip(m, axis)
443+
return _flips.flip(m, axis)
444+
445+
446+
@asarray_replacer()
447+
def flipud(m):
448+
return _flips.flipud(m)
449+
450+
451+
@asarray_replacer()
452+
def fliplr(m):
453+
return _flips.fliplr(m)
454+
455+
456+
@asarray_replacer()
457+
def rot90(m, k=1, axes=(0, 1)):
458+
return _flips.rot90(m, k, axes)
449459

450460

451461
@asarray_replacer()
@@ -466,9 +476,21 @@ def broadcast_arrays(*args, subok=False):
466476

467477
@asarray_replacer()
468478
def moveaxis(a, source, destination):
479+
source = _util.normalize_axis_tuple(source, a.ndim, "source")
480+
destination = _util.normalize_axis_tuple(destination, a.ndim, "destination")
469481
return asarray(torch.moveaxis(a, source, destination))
470482

471483

484+
def swapaxis(a, axis1, axis2):
485+
arr = asarray(a)
486+
return arr.swapaxes(axis1, axis2)
487+
488+
489+
@asarray_replacer()
490+
def rollaxis(a, axis, start=0):
491+
return _flips.rollaxis(a, axis, start)
492+
493+
472494
def unravel_index(indices, shape, order="C"):
473495
# cf https://github.com/pytorch/pytorch/pull/66687
474496
# this version is from
@@ -645,6 +667,9 @@ def prod(
645667
)
646668

647669

670+
product = prod
671+
672+
648673
def cumprod(a, axis=None, dtype=None, out=None):
649674
arr = asarray(a)
650675
return arr.cumprod(axis=axis, dtype=dtype, out=out)

0 commit comments

Comments
 (0)