Skip to content

implement flip-based array manipulations #39

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 0 additions & 20 deletions autogen/numpy_api_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,6 @@ def asmatrix(data, dtype=None):
raise NotImplementedError


def average(a, axis=None, weights=None, returned=False, *, keepdims=NoValue):
raise NotImplementedError


def bartlett(M):
raise NotImplementedError

Expand Down Expand Up @@ -242,10 +238,6 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
raise NotImplementedError


def cumsum(a, axis=None, dtype=None, out=None):
raise NotImplementedError


def datetime_as_string(arr, unit=None, timezone="naive", casting="same_kind"):
raise NotImplementedError

Expand Down Expand Up @@ -330,10 +322,6 @@ def fix(x, out=None):
raise NotImplementedError


def flip(m, axis=None):
raise NotImplementedError


def fliplr(m):
raise NotImplementedError

Expand Down Expand Up @@ -770,10 +758,6 @@ def printoptions(*args, **kwargs):
raise NotImplementedError


def product(*args, **kwargs):
raise NotImplementedError


def put(a, ind, v, mode="raise"):
raise NotImplementedError

Expand Down Expand Up @@ -818,10 +802,6 @@ def roots(p):
raise NotImplementedError


def rot90(m, k=1, axes=(0, 1)):
raise NotImplementedError


def safe_eval(source):
raise NotImplementedError

Expand Down
59 changes: 59 additions & 0 deletions torch_np/_detail/_flips.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Implementations of flip-based routines and related animals.
"""

import torch

from . import _scalar_types, _util


def flip(m_tensor, axis=None):
# XXX: semantic difference: np.flip returns a view, torch.flip copies
if axis is None:
axis = tuple(range(m_tensor.ndim))
else:
axis = _util.normalize_axis_tuple(axis, m_tensor.ndim)
return torch.flip(m_tensor, axis)


def flipud(m_tensor):
return torch.flipud(m_tensor)


def fliplr(m_tensor):
return torch.fliplr(m_tensor)


def rot90(m_tensor, k=1, axes=(0, 1)):
axes = _util.normalize_axis_tuple(axes, m_tensor.ndim)
return torch.rot90(m_tensor, k, axes)


def swapaxes(tensor, axis1, axis2):
return torch.swapaxes(tensor, axis1, axis2)


# Straight vendor from:
# https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1259
#
# Also note this function in NumPy is mostly retained for backwards compat
# (https://stackoverflow.com/questions/29891583/reason-why-numpy-rollaxis-is-so-confusing)
# so let's not touch it unless hard pressed.
def rollaxis(tensor, axis, start=0):
n = tensor.ndim
axis = _util.normalize_axis_index(axis, n)
if start < 0:
start += n
msg = "'%s' arg requires %d <= %s < %d, but %d was passed in"
if not (0 <= start < n + 1):
raise _util.AxisError(msg % ("start", -n, "start", n + 1, start))
if axis < start:
# it's been removed
start -= 1
if axis == start:
# numpy returns a view, here we try returning the tensor itself
# return tensor[...]
return tensor
axes = list(range(0, n))
axes.remove(axis)
axes.insert(start, axis)
return tensor.view(axes)
6 changes: 6 additions & 0 deletions torch_np/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,9 @@ def ndarrays_to_tensors(*inputs):
def to_tensors(*inputs):
"""Convert all array_likes from `inputs` to tensors."""
return tuple(asarray(value).get() for value in inputs)


def _outer(x, y):
x_tensor, y_tensor = to_tensors(x, y)
result = torch.outer(x_tensor, y_tensor)
return asarray(result)
5 changes: 4 additions & 1 deletion torch_np/_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
dtype_to_torch,
emulate_out_arg,
)
from ._detail import _reductions, _util
from ._detail import _flips, _reductions, _util

newaxis = None

Expand Down Expand Up @@ -264,6 +264,9 @@ def transpose(self, *axes):
raise ValueError("axes don't match array")
return ndarray._from_tensor_and_base(tensor, self)

def swapaxes(self, axis1, axis2):
return _flips.swapaxes(self._tensor, axis1, axis2)

def ravel(self, order="C"):
if order != "C":
raise NotImplementedError
Expand Down
39 changes: 32 additions & 7 deletions torch_np/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch

from . import _dtypes, _helpers
from ._detail import _reductions, _util
from ._detail import _flips, _reductions, _util
from ._ndarray import (
array,
asarray,
Expand Down Expand Up @@ -440,12 +440,22 @@ def expand_dims(a, axis):

@asarray_replacer()
def flip(m, axis=None):
# XXX: semantic difference: np.flip returns a view, torch.flip copies
if axis is None:
axis = tuple(range(m.ndim))
else:
axis = _util.normalize_axis_tuple(axis, m.ndim)
return torch.flip(m, axis)
return _flips.flip(m, axis)


@asarray_replacer()
def flipud(m):
return _flips.flipud(m)


@asarray_replacer()
def fliplr(m):
return _flips.fliplr(m)


@asarray_replacer()
def rot90(m, k=1, axes=(0, 1)):
return _flips.rot90(m, k, axes)


@asarray_replacer()
Expand All @@ -466,9 +476,21 @@ def broadcast_arrays(*args, subok=False):

@asarray_replacer()
def moveaxis(a, source, destination):
source = _util.normalize_axis_tuple(source, a.ndim, "source")
destination = _util.normalize_axis_tuple(destination, a.ndim, "destination")
return asarray(torch.moveaxis(a, source, destination))


def swapaxis(a, axis1, axis2):
arr = asarray(a)
return arr.swapaxes(axis1, axis2)


@asarray_replacer()
def rollaxis(a, axis, start=0):
return _flips.rollaxis(a, axis, start)


def unravel_index(indices, shape, order="C"):
# cf https://github.com/pytorch/pytorch/pull/66687
# this version is from
Expand Down Expand Up @@ -645,6 +667,9 @@ def prod(
)


product = prod


def cumprod(a, axis=None, dtype=None, out=None):
arr = asarray(a)
return arr.cumprod(axis=axis, dtype=dtype, out=out)
Expand Down
Loading