|
4 | 4 |
|
5 | 5 | from . import _binary_ufuncs, _dtypes, _funcs, _helpers, _unary_ufuncs
|
6 | 6 | from ._detail import _dtypes_impl, _util
|
| 7 | +from ._normalizations import ArrayLike, normalizer |
7 | 8 |
|
8 | 9 | newaxis = None
|
9 | 10 |
|
@@ -117,6 +118,14 @@ def flags(self):
|
117 | 118 | }
|
118 | 119 | )
|
119 | 120 |
|
| 121 | + @property |
| 122 | + def data(self): |
| 123 | + return self.tensor.data_ptr() |
| 124 | + |
| 125 | + @property |
| 126 | + def nbytes(self): |
| 127 | + return self.tensor.storage().nbytes() |
| 128 | + |
120 | 129 | @property
|
121 | 130 | def T(self):
|
122 | 131 | return self.transpose()
|
@@ -157,7 +166,10 @@ def view(self, dtype):
|
157 | 166 | tview = self.tensor.view(torch_dtype)
|
158 | 167 | return ndarray(tview)
|
159 | 168 |
|
160 |
| - def fill(self, value): |
| 169 | + @normalizer |
| 170 | + def fill(self, value: ArrayLike): |
| 171 | + # Both Pytorch and NumPy accept 0D arrays/tensors and scalars, and |
| 172 | + # error out on D > 0 arrays |
161 | 173 | self.tensor.fill_(value)
|
162 | 174 |
|
163 | 175 | def tolist(self):
|
@@ -395,6 +407,18 @@ def sort(self, axis=-1, kind=None, order=None):
|
395 | 407 | cumprod = _funcs.cumprod
|
396 | 408 |
|
397 | 409 | ### indexing ###
|
| 410 | + def item(self, *args): |
| 411 | + # Mimic NumPy's implementation with three special cases (no arguments, |
| 412 | + # a flat index and a multi-index): |
| 413 | + # https://github.com/numpy/numpy/blob/main/numpy/core/src/multiarray/methods.c#L702 |
| 414 | + if args == (): |
| 415 | + return self.tensor.item() |
| 416 | + elif len(args) == 1: |
| 417 | + # int argument |
| 418 | + return self.ravel()[args[0]] |
| 419 | + else: |
| 420 | + return self.__getitem__(args) |
| 421 | + |
398 | 422 | @staticmethod
|
399 | 423 | def _upcast_int_indices(index):
|
400 | 424 | if isinstance(index, torch.Tensor):
|
|
0 commit comments