Skip to content

Commit 4450f17

Browse files
committed
MAINT: move shape/axes manipulations from ndarray
1 parent c730cb7 commit 4450f17

File tree

5 files changed

+101
-66
lines changed

5 files changed

+101
-66
lines changed

torch_np/_detail/implementations.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -581,21 +581,18 @@ def squeeze(tensor, axis=None):
581581
return result
582582

583583

584-
def reshape(tensor, *shape, order="C"):
584+
def reshape(tensor, shape, order="C"):
585585
if order != "C":
586586
raise NotImplementedError
587-
newshape = shape[0] if len(shape) == 1 else shape
588-
# convert any tnp.ndarray inputs into tensors before passing to torch.Tensor.reshape
589-
t_newshape = _helpers.ndarrays_to_tensors(newshape)
590587
# if sh = (1, 2, 3), numpy allows both .reshape(sh) and .reshape(*sh)
591-
result = tensor.reshape(t_newshape)
588+
newshape = shape[0] if len(shape) == 1 else shape
589+
result = tensor.reshape(newshape)
592590
return result
593591

594592

595-
def transpose(tensor, *axes):
596-
# numpy allows both .reshape(sh) and .reshape(*sh)
597-
axes = axes[0] if len(axes) == 1 else axes
598-
if axes == () or axes is None:
593+
def transpose(tensor, axes=None):
594+
# numpy allows both .tranpose(sh) and .transpose(*sh)
595+
if axes in [(), None, (None,)]:
599596
axes = tuple(range(tensor.ndim))[::-1]
600597
try:
601598
result = tensor.permute(axes)
@@ -604,6 +601,31 @@ def transpose(tensor, *axes):
604601
return result
605602

606603

604+
def ravel(tensor, order="C"):
605+
if order != "C":
606+
raise NotImplementedError
607+
result = tensor.ravel()
608+
return result
609+
610+
611+
# leading underscore since arr.flatten exists but np.flatten does not
612+
def _flatten(tensor, order="C"):
613+
if order != "C":
614+
raise NotImplementedError
615+
# return a copy
616+
result = tensor.ravel().clone()
617+
return result
618+
619+
620+
# ### swap/move/roll axis ###
621+
622+
def moveaxis(tensor, source, destination):
623+
source = _util.normalize_axis_tuple(source, tensor.ndim, "source")
624+
destination = _util.normalize_axis_tuple(destination, tensor.ndim, "destination")
625+
result = torch.moveaxis(tensor, source, destination)
626+
return result
627+
628+
607629
# ### Numeric ###
608630

609631

torch_np/_funcs.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22

33
from . import _decorators, _helpers
4-
from ._detail import _util
4+
from ._detail import _util, _flips
55
from ._detail import implementations as _impl
66

77

@@ -114,3 +114,60 @@ def searchsorted(a, v, side="left", sorter=None):
114114
a_t, v_t, sorter_t = _helpers.to_tensors_or_none(a, v, sorter)
115115
result = torch.searchsorted(a_t, v_t, side=side, sorter=sorter_t)
116116
return _helpers.array_from(result)
117+
118+
119+
# ### swap/move/roll axis ###
120+
121+
122+
def moveaxis(a, source, destination):
123+
tensor, = _helpers.to_tensors(a)
124+
result = _impl.moveaxis(tensor, source, destination)
125+
return _helpers.array_from(result)
126+
127+
128+
def swapaxes(a, axis1, axis2):
129+
tensor, = _helpers.to_tensors(a)
130+
result = _flips.swapaxes(tensor, axis1, axis2)
131+
return _helpers.array_from(result)
132+
133+
134+
def rollaxis(a, axis, start=0):
135+
tensor, = _helpers.to_tensors(a)
136+
result = _flips.rollaxis(a, axis, start)
137+
return _helpers.array_from(result)
138+
139+
140+
# ### shape manipulations ###
141+
142+
def squeeze(a, axis=None):
143+
tensor, = _helpers.to_tensors(a)
144+
result = _impl.squeeze(tensor, axis)
145+
return _helpers.array_from(result, a)
146+
147+
148+
def reshape(a, newshape, order="C"):
149+
tensor, = _helpers.to_tensors(a)
150+
result = _impl.reshape(tensor, newshape, order=order)
151+
return _helpers.array_from(result, a)
152+
153+
154+
def transpose(a, axes=None):
155+
tensor, = _helpers.to_tensors(a)
156+
result = _impl.transpose(tensor, axes)
157+
return _helpers.array_from(result, a)
158+
159+
160+
def ravel(a, order="C"):
161+
tensor, = _helpers.to_tensors(a)
162+
result = _impl.ravel(tensor)
163+
return _helpers.array_from(result, a)
164+
165+
166+
# leading underscore since arr.flatten exists but np.flatten does not
167+
def _flatten(a, order="C"):
168+
tensor, = _helpers.to_tensors(a)
169+
result = _impl._flatten(tensor)
170+
return _helpers.array_from(result, a)
171+
172+
173+

torch_np/_helpers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,10 @@ def result_or_out(result_tensor, out_array=None, promote_scalar=False):
7373
return asarray(result_tensor)
7474

7575

76-
def array_from(tensor):
76+
def array_from(tensor, base=None):
7777
from ._ndarray import ndarray
78-
79-
return ndarray._from_tensor_and_base(tensor, None) # XXX: nuke .base
78+
base = base if isinstance(base, ndarray) else None
79+
return ndarray._from_tensor_and_base(tensor, base) # XXX: nuke .base
8080

8181

8282
def tuple_arrays_from(result):

torch_np/_ndarray.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -348,31 +348,19 @@ def __irshift__(self, other):
348348

349349
### methods to match namespace functions
350350

351-
def squeeze(self, axis=None):
352-
result = _impl.squeeze(self._tensor, axis)
353-
return ndarray._from_tensor_and_base(result, self)
354-
355-
def reshape(self, *shape, order="C"):
356-
result = _impl.reshape(self._tensor, *shape, order=order)
357-
return ndarray._from_tensor_and_base(result, self)
351+
squeeze = _funcs.squeeze
352+
swapaxes = _funcs.swapaxes
358353

359354
def transpose(self, *axes):
360-
result = _impl.transpose(self._tensor, *axes)
361-
return ndarray._from_tensor_and_base(result, self)
355+
# np.transpose(arr, axis=None) but arr.transpose(*axes)
356+
return _funcs.transpose(self, axes)
362357

363-
def swapaxes(self, axis1, axis2):
364-
return asarray(_flips.swapaxes(self._tensor, axis1, axis2))
365-
366-
def ravel(self, order="C"):
367-
if order != "C":
368-
raise NotImplementedError
369-
return ndarray._from_tensor_and_base(self._tensor.ravel(), self)
358+
def reshape(self, *shape, order="C"):
359+
# arr.reshape(shape) and arr.reshape(*shape)
360+
return _funcs.reshape(self, shape, order=order)
370361

371-
def flatten(self, order="C"):
372-
if order != "C":
373-
raise NotImplementedError
374-
result = self._tensor.flatten()
375-
return asarray(result)
362+
ravel = _funcs.ravel
363+
flatten = _funcs._flatten
376364

377365
nonzero = _funcs.nonzero
378366
clip = _funcs.clip

torch_np/_wrapper.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -457,26 +457,11 @@ def size(a, axis=None):
457457
###### shape manipulations and indexing
458458

459459

460-
def transpose(a, axes=None):
461-
arr = asarray(a)
462-
return arr.transpose(axes)
463-
464-
465-
def reshape(a, newshape, order="C"):
466-
arr = asarray(a)
467-
return arr.reshape(*newshape, order=order)
468-
469-
470460
def ravel(a, order="C"):
471461
arr = asarray(a)
472462
return arr.ravel(order=order)
473463

474464

475-
def squeeze(a, axis=None):
476-
arr = asarray(a)
477-
return arr.squeeze(axis)
478-
479-
480465
def expand_dims(a, axis):
481466
a = asarray(a)
482467
shape = _util.expand_shape(a.shape, axis)
@@ -521,23 +506,6 @@ def broadcast_arrays(*args, subok=False):
521506
return tuple(asarray(_) for _ in res)
522507

523508

524-
@asarray_replacer()
525-
def moveaxis(a, source, destination):
526-
source = _util.normalize_axis_tuple(source, a.ndim, "source")
527-
destination = _util.normalize_axis_tuple(destination, a.ndim, "destination")
528-
return asarray(torch.moveaxis(a, source, destination))
529-
530-
531-
def swapaxes(a, axis1, axis2):
532-
arr = asarray(a)
533-
return arr.swapaxes(axis1, axis2)
534-
535-
536-
@asarray_replacer()
537-
def rollaxis(a, axis, start=0):
538-
return _flips.rollaxis(a, axis, start)
539-
540-
541509
def unravel_index(indices, shape, order="C"):
542510
# cf https://github.com/pytorch/pytorch/pull/66687
543511
# this version is from

0 commit comments

Comments
 (0)