Skip to content

API: remove ndarray.base #80

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion torch_np/_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def out_shape_dtype(func):
@functools.wraps(func)
def wrapped(*args, out=None, **kwds):
if out is not None:
kwds.update({"out_shape_dtype": (out.get().dtype, out.get().shape)})
kwds.update({"out_shape_dtype": (out.tensor.dtype, out.tensor.shape)})
result_tensor = func(*args, **kwds)
return _helpers.result_or_out(result_tensor, out)

Expand Down
6 changes: 3 additions & 3 deletions torch_np/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __new__(self, value):
value = {"inf": torch.inf, "nan": torch.nan}[value]

if isinstance(value, _ndarray.ndarray):
tensor = value.get()
tensor = value.tensor
else:
try:
tensor = torch.as_tensor(value, dtype=self.torch_dtype)
Expand All @@ -49,7 +49,7 @@ def __new__(self, value):
# and here we follow the second approach and create a new object
# *for all inputs*.
#
return _ndarray.ndarray._from_tensor_and_base(tensor, None)
return _ndarray.ndarray(tensor)


##### these are abstract types
Expand Down Expand Up @@ -317,7 +317,7 @@ def __repr__(self):
@property
def itemsize(self):
elem = self.type(1)
return elem.get().element_size()
return elem.tensor.element_size()

def __getstate__(self):
return self._scalar_type
Expand Down
11 changes: 5 additions & 6 deletions torch_np/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def ufunc_preprocess(

out_shape_dtype = None
if out is not None:
out_shape_dtype = (out.get().dtype, out.get().shape)
out_shape_dtype = (out.tensor.dtype, out.tensor.shape)

tensors = _util.cast_and_broadcast(tensors, out_shape_dtype, casting)

Expand Down Expand Up @@ -77,7 +77,7 @@ def result_or_out(result_tensor, out_array=None, promote_scalar=False):
f"Bad size of the out array: out.shape = {out_array.shape}"
f" while result.shape = {result_tensor.shape}."
)
out_tensor = out_array.get()
out_tensor = out_array.tensor
out_tensor.copy_(result_tensor)
return out_array
else:
Expand All @@ -87,8 +87,7 @@ def result_or_out(result_tensor, out_array=None, promote_scalar=False):
def array_from(tensor, base=None):
from ._ndarray import ndarray

base = base if isinstance(base, ndarray) else None
return ndarray._from_tensor_and_base(tensor, base) # XXX: nuke .base
return ndarray(tensor)


def tuple_arrays_from(result):
Expand All @@ -109,7 +108,7 @@ def ndarrays_to_tensors(*inputs):
elif len(inputs) == 1:
input_ = inputs[0]
if isinstance(input_, ndarray):
return input_.get()
return input_.tensor
elif isinstance(input_, tuple):
result = []
for sub_input in input_:
Expand All @@ -127,4 +126,4 @@ def to_tensors(*inputs):
"""Convert all array_likes from `inputs` to tensors."""
from ._ndarray import asarray, ndarray

return tuple(asarray(value).get() for value in inputs)
return tuple(asarray(value).tensor for value in inputs)
93 changes: 37 additions & 56 deletions torch_np/_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,48 +62,36 @@ def __getitem__(self, key):


class ndarray:
def __init__(self):
self._tensor = torch.Tensor()
self._base = None

@classmethod
def _from_tensor_and_base(cls, tensor, base):
self = cls()
self._tensor = tensor
self._base = base
return self

def get(self):
return self._tensor
def __init__(self, t=None):
if t is None:
self.tensor = torch.Tensor()
else:
self.tensor = torch.as_tensor(t)

@property
def shape(self):
return tuple(self._tensor.shape)
return tuple(self.tensor.shape)

@property
def size(self):
return self._tensor.numel()
return self.tensor.numel()

@property
def ndim(self):
return self._tensor.ndim
return self.tensor.ndim

@property
def dtype(self):
return _dtypes.dtype(self._tensor.dtype)
return _dtypes.dtype(self.tensor.dtype)

@property
def strides(self):
elsize = self._tensor.element_size()
return tuple(stride * elsize for stride in self._tensor.stride())
elsize = self.tensor.element_size()
return tuple(stride * elsize for stride in self.tensor.stride())

@property
def itemsize(self):
return self._tensor.element_size()

@property
def base(self):
return self._base
return self.tensor.element_size()

@property
def flags(self):
Expand All @@ -112,15 +100,15 @@ def flags(self):
# check if F contiguous
from itertools import accumulate

f_strides = tuple(accumulate(list(self._tensor.shape), func=lambda x, y: x * y))
f_strides = tuple(accumulate(list(self.tensor.shape), func=lambda x, y: x * y))
f_strides = (1,) + f_strides[:-1]
is_f_contiguous = f_strides == self._tensor.stride()
is_f_contiguous = f_strides == self.tensor.stride()

return Flags(
{
"C_CONTIGUOUS": self._tensor.is_contiguous(),
"C_CONTIGUOUS": self.tensor.is_contiguous(),
"F_CONTIGUOUS": is_f_contiguous,
"OWNDATA": self._tensor._base is None,
"OWNDATA": self.tensor._base is None,
"WRITEABLE": True, # pytorch does not have readonly tensors
}
)
Expand All @@ -135,38 +123,38 @@ def real(self):

@real.setter
def real(self, value):
self._tensor.real = asarray(value).get()
self.tensor.real = asarray(value).tensor

@property
def imag(self):
return _funcs.imag(self)

@imag.setter
def imag(self, value):
self._tensor.imag = asarray(value).get()
self.tensor.imag = asarray(value).tensor

round = _funcs.round

# ctors
def astype(self, dtype):
newt = ndarray()
torch_dtype = _dtypes.dtype(dtype).torch_dtype
newt._tensor = self._tensor.to(torch_dtype)
newt.tensor = self.tensor.to(torch_dtype)
return newt

def copy(self, order="C"):
if order != "C":
raise NotImplementedError
tensor = self._tensor.clone()
return ndarray._from_tensor_and_base(tensor, None)
tensor = self.tensor.clone()
return ndarray(tensor)

def tolist(self):
return self._tensor.tolist()
return self.tensor.tolist()

### niceties ###
def __str__(self):
return (
str(self._tensor)
str(self.tensor)
.replace("tensor", "array_w")
.replace("dtype=torch.", "dtype=")
)
Expand Down Expand Up @@ -197,7 +185,7 @@ def __ne__(self, other):

def __bool__(self):
try:
return bool(self._tensor)
return bool(self.tensor)
except RuntimeError:
raise ValueError(
"The truth value of an array with more than one "
Expand All @@ -206,35 +194,35 @@ def __bool__(self):

def __index__(self):
try:
return operator.index(self._tensor.item())
return operator.index(self.tensor.item())
except Exception:
mesg = "only integer scalar arrays can be converted to a scalar index"
raise TypeError(mesg)

def __float__(self):
return float(self._tensor)
return float(self.tensor)

def __complex__(self):
try:
return complex(self._tensor)
return complex(self.tensor)
except ValueError as e:
raise TypeError(*e.args)

def __int__(self):
return int(self._tensor)
return int(self.tensor)

# XXX : are single-element ndarrays scalars?
# in numpy, only array scalars have the `is_integer` method
def is_integer(self):
try:
result = int(self._tensor) == self._tensor
result = int(self.tensor) == self.tensor
except Exception:
result = False
return result

### sequence ###
def __len__(self):
return self._tensor.shape[0]
return self.tensor.shape[0]

### arithmetic ###

Expand Down Expand Up @@ -360,8 +348,8 @@ def reshape(self, *shape, order="C"):

def sort(self, axis=-1, kind=None, order=None):
# ndarray.sort works in-place
result = _impl.sort(self._tensor, axis, kind, order)
self._tensor = result
result = _impl.sort(self.tensor, axis, kind, order)
self.tensor = result

argsort = _funcs.argsort
searchsorted = _funcs.searchsorted
Expand Down Expand Up @@ -398,13 +386,13 @@ def _upcast_int_indices(index):
def __getitem__(self, index):
index = _helpers.ndarrays_to_tensors(index)
index = ndarray._upcast_int_indices(index)
return ndarray._from_tensor_and_base(self._tensor.__getitem__(index), self)
return ndarray(self.tensor.__getitem__(index))

def __setitem__(self, index, value):
index = _helpers.ndarrays_to_tensors(index)
index = ndarray._upcast_int_indices(index)
value = _helpers.ndarrays_to_tensors(value)
return self._tensor.__setitem__(index, value)
return self.tensor.__setitem__(index, value)


# This is the ideally the only place which talks to ndarray directly.
Expand All @@ -426,25 +414,22 @@ def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=N
a1 = []
for elem in obj:
if isinstance(elem, ndarray):
a1.append(elem.get().tolist())
a1.append(elem.tensor.tolist())
else:
a1.append(elem)
obj = a1

# is obj an ndarray already?
base = None
if isinstance(obj, ndarray):
obj = obj._tensor
base = obj
obj = obj.tensor

    # is a specific dtype requested?
torch_dtype = None
if dtype is not None:
torch_dtype = _dtypes.dtype(dtype).torch_dtype
base = None

tensor = _util._coerce_to_tensor(obj, torch_dtype, copy, ndmin)
return ndarray._from_tensor_and_base(tensor, base)
return ndarray(tensor)


def asarray(a, dtype=None, order=None, *, like=None):
Expand All @@ -453,10 +438,6 @@ def asarray(a, dtype=None, order=None, *, like=None):
return array(a, dtype=dtype, order=order, like=like, copy=False, ndmin=0)


def maybe_set_base(tensor, base):
return ndarray._from_tensor_and_base(tensor, base)


###### dtype routines


Expand Down
49 changes: 22 additions & 27 deletions torch_np/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

import torch

from . import _decorators, _dtypes, _funcs, _helpers
from . import _funcs, _helpers
from ._detail import _dtypes_impl, _flips, _reductions, _util
from ._detail import implementations as _impl
from ._ndarray import array, asarray, maybe_set_base, ndarray
from ._ndarray import asarray
from ._normalizations import ArrayLike, DTypeLike, NDArray, SubokLike, normalizer

# Things to decide on (punt for now)
Expand Down Expand Up @@ -169,39 +169,34 @@ def stack(
return _helpers.result_or_out(result, out)


def array_split(ary, indices_or_sections, axis=0):
tensor = asarray(ary).get()
base = ary if isinstance(ary, ndarray) else None
result = _impl.split_helper(tensor, indices_or_sections, axis)
return tuple(maybe_set_base(x, base) for x in result)
@normalizer
def array_split(ary: ArrayLike, indices_or_sections, axis=0):
result = _impl.split_helper(ary, indices_or_sections, axis)
return _helpers.tuple_arrays_from(result)


def split(ary, indices_or_sections, axis=0):
tensor = asarray(ary).get()
base = ary if isinstance(ary, ndarray) else None
result = _impl.split_helper(tensor, indices_or_sections, axis, strict=True)
return tuple(maybe_set_base(x, base) for x in result)
@normalizer
def split(ary: ArrayLike, indices_or_sections, axis=0):
result = _impl.split_helper(ary, indices_or_sections, axis, strict=True)
return _helpers.tuple_arrays_from(result)


def hsplit(ary, indices_or_sections):
tensor = asarray(ary).get()
base = ary if isinstance(ary, ndarray) else None
result = _impl.hsplit(tensor, indices_or_sections)
return tuple(maybe_set_base(x, base) for x in result)
@normalizer
def hsplit(ary: ArrayLike, indices_or_sections):
result = _impl.hsplit(ary, indices_or_sections)
return _helpers.tuple_arrays_from(result)


def vsplit(ary, indices_or_sections):
tensor = asarray(ary).get()
base = ary if isinstance(ary, ndarray) else None
result = _impl.vsplit(tensor, indices_or_sections)
return tuple(maybe_set_base(x, base) for x in result)
@normalizer
def vsplit(ary: ArrayLike, indices_or_sections):
result = _impl.vsplit(ary, indices_or_sections)
return _helpers.tuple_arrays_from(result)


def dsplit(ary, indices_or_sections):
tensor = asarray(ary).get()
base = ary if isinstance(ary, ndarray) else None
result = _impl.dsplit(tensor, indices_or_sections)
return tuple(maybe_set_base(x, base) for x in result)
@normalizer
def dsplit(ary: ArrayLike, indices_or_sections):
result = _impl.dsplit(ary, indices_or_sections)
return _helpers.tuple_arrays_from(result)


@normalizer
Expand Down
Loading