Skip to content

Commit 3ec9dda

Browse files
authored
Merge pull request #101 from Quansight-Labs/ndarray_attr
add several ndarray method/properties
2 parents a5d3d23 + 879c329 commit 3ec9dda

File tree

1 file changed

+25
-1
lines changed

1 file changed

+25
-1
lines changed

torch_np/_ndarray.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from . import _binary_ufuncs, _dtypes, _funcs, _helpers, _unary_ufuncs
66
from ._detail import _dtypes_impl, _util
7+
from ._normalizations import ArrayLike, normalizer
78

89
newaxis = None
910

@@ -117,6 +118,14 @@ def flags(self):
117118
}
118119
)
119120

121+
@property
122+
def data(self):
123+
return self.tensor.data_ptr()
124+
125+
@property
126+
def nbytes(self):
127+
return self.tensor.storage().nbytes()
128+
120129
@property
121130
def T(self):
122131
return self.transpose()
@@ -157,7 +166,10 @@ def view(self, dtype):
157166
tview = self.tensor.view(torch_dtype)
158167
return ndarray(tview)
159168

160-
def fill(self, value):
169+
@normalizer
170+
def fill(self, value: ArrayLike):
171+
# Both Pytorch and NumPy accept 0D arrays/tensors and scalars, and
172+
# error out on D > 0 arrays
161173
self.tensor.fill_(value)
162174

163175
def tolist(self):
@@ -395,6 +407,18 @@ def sort(self, axis=-1, kind=None, order=None):
395407
cumprod = _funcs.cumprod
396408

397409
### indexing ###
410+
def item(self, *args):
411+
# Mimic NumPy's implementation with three special cases (no arguments,
412+
# a flat index and a multi-index):
413+
# https://github.com/numpy/numpy/blob/main/numpy/core/src/multiarray/methods.c#L702
414+
if args == ():
415+
return self.tensor.item()
416+
elif len(args) == 1:
417+
# int argument
418+
return self.ravel()[args[0]]
419+
else:
420+
return self.__getitem__(args)
421+
398422
@staticmethod
399423
def _upcast_int_indices(index):
400424
if isinstance(index, torch.Tensor):

0 commit comments

Comments
 (0)