Skip to content

Commit 7924195

Browse files
committed
add several ndarray method/properties
1 parent 4fc9380 commit 7924195

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

torch_np/_ndarray.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from . import _binary_ufuncs, _dtypes, _funcs, _helpers, _unary_ufuncs
66
from ._detail import _dtypes_impl, _util
7+
from ._normalizations import ArrayLike, normalizer
78

89
newaxis = None
910

@@ -117,6 +118,14 @@ def flags(self):
117118
}
118119
)
119120

121+
@property
122+
def data(self):
123+
return self.tensor.data_ptr()
124+
125+
@property
126+
def nbytes(self):
127+
return self.tensor.storage().nbytes()
128+
120129
@property
121130
def T(self):
122131
return self.transpose()
@@ -157,7 +166,8 @@ def view(self, dtype):
157166
tview = self.tensor.view(torch_dtype)
158167
return ndarray(tview)
159168

160-
def fill(self, value):
169+
@normalizer
170+
def fill(self, value: ArrayLike):
161171
self.tensor.fill_(value)
162172

163173
def tolist(self):
@@ -395,6 +405,15 @@ def sort(self, axis=-1, kind=None, order=None):
395405
cumprod = _funcs.cumprod
396406

397407
### indexing ###
408+
def item(self, *args):
409+
if args == ():
410+
return self.tensor.item()
411+
elif len(args) == 1:
412+
# int argument
413+
return self.ravel()[args[0]]
414+
else:
415+
return self.__getitem__(args)
416+
398417
@staticmethod
399418
def _upcast_int_indices(index):
400419
if isinstance(index, torch.Tensor):

0 commit comments

Comments
 (0)