Skip to content

Commit ba31a0b

Browse files
committed
MAINT: move tensor manipulations from _ndarray to _impl
1 parent 6de03dd commit ba31a0b

File tree

3 files changed

+74
-37
lines changed

3 files changed

+74
-37
lines changed

torch_np/_detail/implementations.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,6 @@ def dsplit(tensor, indices_or_sections):
216216
return split_helper(tensor, indices_or_sections, 2, strict=True)
217217

218218

219-
220219
def clip(tensor, t_min, t_max):
221220
if t_min is not None:
222221
t_min = torch.broadcast_to(t_min, tensor.shape)
@@ -488,3 +487,64 @@ def full(shape, fill_value, dtype=None):
488487
shape = (shape,)
489488
result = torch.full(shape, fill_value, dtype=dtype)
490489
return result
490+
491+
492+
# ### shape manipulations ###
493+
494+
495+
def roll(tensor, shift, axis=None):
496+
if axis is not None:
497+
axis = _util.normalize_axis_tuple(axis, tensor.ndim, allow_duplicate=True)
498+
if not isinstance(shift, tuple):
499+
shift = (shift,) * len(axis)
500+
result = tensor.roll(shift, axis)
501+
return result
502+
503+
504+
def squeeze(tensor, axis=None):
505+
if axis == ():
506+
result = tensor
507+
elif axis is None:
508+
result = tensor.squeeze()
509+
else:
510+
result = tensor.squeeze(axis)
511+
return result
512+
513+
514+
def reshape(tensor, *shape, order="C"):
515+
if order != "C":
516+
raise NotImplementedError
517+
newshape = shape[0] if len(shape) == 1 else shape
518+
# if sh = (1, 2, 3), numpy allows both .reshape(sh) and .reshape(*sh)
519+
result = tensor.reshape(newshape)
520+
return result
521+
522+
523+
def transpose(tensor, *axes):
524+
# numpy allows both .reshape(sh) and .reshape(*sh)
525+
axes = axes[0] if len(axes) == 1 else axes
526+
if axes == () or axes is None:
527+
axes = tuple(range(tensor.ndim))[::-1]
528+
try:
529+
result = tensor.permute(axes)
530+
except RuntimeError:
531+
raise ValueError("axes don't match array")
532+
return result
533+
534+
535+
# ### Numeric ###
536+
537+
538+
def round(tensor, decimals=0):
539+
if tensor.is_floating_point():
540+
result = torch.round(tensor, decimals=decimals)
541+
elif tensor.is_complex():
542+
# RuntimeError: "round_cpu" not implemented for 'ComplexFloat'
543+
result = (
544+
torch.round(tensor.real, decimals=decimals)
545+
+ torch.round(tensor.imag, decimals=decimals) * 1j
546+
)
547+
else:
548+
# RuntimeError: "round_cpu" not implemented for 'int'
549+
result = tensor
550+
return result

torch_np/_ndarray.py

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,7 @@ def imag(self, value):
139139
self._tensor.imag = asarray(value).get()
140140

141141
def round(self, decimals=0, out=None):
142-
tensor = self._tensor
143-
if torch.is_floating_point(tensor):
144-
result = torch.round(tensor, decimals=decimals)
145-
else:
146-
result = tensor
142+
result = _impl.round(self._tensor, decimals)
147143
return _helpers.result_or_out(result, out)
148144

149145
# ctors
@@ -328,32 +324,16 @@ def __irshift__(self, other):
328324
### methods to match namespace functions
329325

330326
def squeeze(self, axis=None):
331-
if axis == ():
332-
tensor = self._tensor
333-
elif axis is None:
334-
tensor = self._tensor.squeeze()
335-
else:
336-
tensor = self._tensor.squeeze(axis)
337-
return ndarray._from_tensor_and_base(tensor, self)
327+
result = _impl.squeeze(self._tensor, axis)
328+
return ndarray._from_tensor_and_base(result, self)
338329

339330
def reshape(self, *shape, order="C"):
340-
newshape = shape[0] if len(shape) == 1 else shape
341-
# if sh = (1, 2, 3), numpy allows both .reshape(sh) and .reshape(*sh)
342-
if order != "C":
343-
raise NotImplementedError
344-
tensor = self._tensor.reshape(newshape)
345-
return ndarray._from_tensor_and_base(tensor, self)
331+
result = _impl.reshape(self._tensor, *shape, order=order)
332+
return ndarray._from_tensor_and_base(result, self)
346333

347334
def transpose(self, *axes):
348-
# numpy allows both .reshape(sh) and .reshape(*sh)
349-
axes = axes[0] if len(axes) == 1 else axes
350-
if axes == () or axes is None:
351-
axes = tuple(range(self.ndim))[::-1]
352-
try:
353-
tensor = self._tensor.permute(axes)
354-
except RuntimeError:
355-
raise ValueError("axes don't match array")
356-
return ndarray._from_tensor_and_base(tensor, self)
335+
result = _impl.transpose(self._tensor, *axes)
336+
return ndarray._from_tensor_and_base(result, self)
357337

358338
def swapaxes(self, axis1, axis2):
359339
return asarray(_flips.swapaxes(self._tensor, axis1, axis2))

torch_np/_wrapper.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -595,9 +595,9 @@ def flatnonzero(a):
595595

596596

597597
def argwhere(a):
598-
arr = asarray(a)
599-
tensor = arr.get()
600-
return asarray(torch.argwhere(tensor))
598+
tensor = asarray(a).get()
599+
result = torch.argwhere(tensor)
600+
return asarray(result)
601601

602602

603603
from ._decorators import emulate_out_arg
@@ -606,13 +606,10 @@ def argwhere(a):
606606
count_nonzero = emulate_out_arg(axis_keepdims_wrapper(_reductions.count_nonzero))
607607

608608

609-
@asarray_replacer()
610609
def roll(a, shift, axis=None):
611-
if axis is not None:
612-
axis = _util.normalize_axis_tuple(axis, a.ndim, allow_duplicate=True)
613-
if not isinstance(shift, tuple):
614-
shift = (shift,) * len(axis)
615-
return a.roll(shift, axis)
610+
tensor = asarray(a).get()
611+
result = _impl.roll(tensor, shift, axis)
612+
return asarray(result)
616613

617614

618615
def round_(a, decimals=0, out=None):

0 commit comments

Comments
 (0)