Commit 5b8ed80

MAINT: merge torch implementations of _flips.py to _flips
1 parent ec2a7be commit 5b8ed80

File tree

6 files changed: +204 -318 lines changed

torch_np/_detail/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -3,4 +3,3 @@

 # leading underscore (ndarray.flatten yes, np.flatten no)
 from .implementations import *
-from .implementations import _flatten

torch_np/_detail/_flips.py

Lines changed: 0 additions & 33 deletions
@@ -26,36 +26,3 @@ def fliplr(m_tensor):
 def rot90(m_tensor, k=1, axes=(0, 1)):
     axes = _util.normalize_axis_tuple(axes, m_tensor.ndim)
     return torch.rot90(m_tensor, k, axes)
-
-
-def swapaxes(tensor, axis1, axis2):
-    axis1 = _util.normalize_axis_index(axis1, tensor.ndim)
-    axis2 = _util.normalize_axis_index(axis2, tensor.ndim)
-    return torch.swapaxes(tensor, axis1, axis2)
-
-
-# Straight vendor from:
-# https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1259
-#
-# Also note this function in NumPy is mostly retained for backwards compat
-# (https://stackoverflow.com/questions/29891583/reason-why-numpy-rollaxis-is-so-confusing)
-# so let's not touch it unless hard pressed.
-def rollaxis(tensor, axis, start=0):
-    n = tensor.ndim
-    axis = _util.normalize_axis_index(axis, n)
-    if start < 0:
-        start += n
-    msg = "'%s' arg requires %d <= %s < %d, but %d was passed in"
-    if not (0 <= start < n + 1):
-        raise _util.AxisError(msg % ("start", -n, "start", n + 1, start))
-    if axis < start:
-        # it's been removed
-        start -= 1
-    if axis == start:
-        # numpy returns a view, here we try returning the tensor itself
-        # return tensor[...]
-        return tensor
-    axes = list(range(0, n))
-    axes.remove(axis)
-    axes.insert(start, axis)
-    return tensor.view(axes)
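
The vendored rollaxis reduces to a single axis permutation: the NumPy original finishes with a.transpose(axes), and Tensor.permute is torch's equivalent. A minimal sketch of that axis bookkeeping, with hypothetical shapes:

    import torch

    t = torch.empty(3, 4, 5, 6)

    # rollaxis(t, axis=2, start=0): move axis 2 to position 0.
    axis, start = 2, 0
    axes = list(range(t.ndim))  # [0, 1, 2, 3]
    axes.remove(axis)           # [0, 1, 3]
    axes.insert(start, axis)    # [2, 0, 1, 3]

    print(t.permute(axes).shape)  # torch.Size([5, 3, 4, 6])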

torch_np/_detail/implementations.py

Lines changed: 0 additions & 250 deletions
@@ -129,62 +129,6 @@ def triu_indices(n, k=0, m=None):
     return result


-def diag_indices(n, ndim=2):
-    idx = torch.arange(n)
-    return (idx,) * ndim
-
-
-def diag_indices_from(tensor):
-    if not tensor.ndim >= 2:
-        raise ValueError("input array must be at least 2-d")
-    # For more than d=2, the strided formula is only valid for arrays with
-    # all dimensions equal, so we check first.
-    s = tensor.shape
-    if s[1:] != s[:-1]:
-        raise ValueError("All dimensions of input must be of equal length")
-    return diag_indices(s[0], tensor.ndim)
-
-
-def fill_diagonal(tensor, t_val, wrap):
-    # torch.Tensor.fill_diagonal_ only accepts scalars. Thus vendor the numpy source,
-    # https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/index_tricks.py#L786-L917
-
-    if tensor.ndim < 2:
-        raise ValueError("array must be at least 2-d")
-    end = None
-    if tensor.ndim == 2:
-        # Explicit, fast formula for the common case. For 2-d arrays, we
-        # accept rectangular ones.
-        step = tensor.shape[1] + 1
-        # This is needed so tall matrices don't have the diagonal wrap.
-        if not wrap:
-            end = tensor.shape[1] * tensor.shape[1]
-    else:
-        # For more than d=2, the strided formula is only valid for arrays with
-        # all dimensions equal, so we check first.
-        s = tensor.shape
-        if s[1:] != s[:-1]:
-            raise ValueError("All dimensions of input must be of equal length")
-        sz = torch.as_tensor(tensor.shape[:-1])
-        step = 1 + (torch.cumprod(sz, 0)).sum()
-
-    # Write the value out into the diagonal.
-    tensor.ravel()[:end:step] = t_val
-    return tensor
-
-
-def trace(tensor, offset=0, axis1=0, axis2=1, dtype=None, out=None):
-    result = torch.diagonal(tensor, offset, dim1=axis1, dim2=axis2).sum(-1, dtype=dtype)
-    return result
-
-
-def diagonal(tensor, offset=0, axis1=0, axis2=1):
-    axis1 = _util.normalize_axis_index(axis1, tensor.ndim)
-    axis2 = _util.normalize_axis_index(axis2, tensor.ndim)
-    result = torch.diagonal(tensor, offset, axis1, axis2)
-    return result
-
-
 # ### splits ###


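The strided write in fill_diagonal relies on the flat index of element [i, i, ..., i] advancing by a fixed step through the raveled tensor. A quick check of both branches of the step formula (shapes are made up):

    import torch

    # 2-D: the raveled diagonal of an (n, m) matrix sits at indices
    # 0, m + 1, 2 * (m + 1), ..., i.e. step = shape[1] + 1.
    a = torch.zeros(4, 4)
    a.ravel()[:: a.shape[1] + 1] = 7
    assert torch.equal(a, 7 * torch.eye(4))

    # n-D, all dims equal: the flat index of [i, i, i] in a C-contiguous
    # (3, 3, 3) cube is 9*i + 3*i + i = 13*i, which matches
    # step = 1 + cumprod(shape[:-1]).sum() = 1 + (3 + 9) = 13.
    b = torch.zeros(3, 3, 3)
    step = int(1 + torch.cumprod(torch.as_tensor(b.shape[:-1]), 0).sum())
    b.ravel()[::step] = 1
    assert all(b[i, i, i] == 1 for i in range(3))
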
@@ -263,14 +207,6 @@ def dsplit(tensor, indices_or_sections):
     return split_helper(tensor, indices_or_sections, 2, strict=True)


-def clip(tensor, t_min, t_max):
-    if t_min is None and t_max is None:
-        raise ValueError("One of max or min must be given")
-
-    result = tensor.clamp(t_min, t_max)
-    return result
-
-
 def diff(a_tensor, n=1, axis=-1, prepend_tensor=None, append_tensor=None):
     axis = _util.normalize_axis_index(axis, a_tensor.ndim)

@@ -364,14 +300,6 @@ def vstack(tensors, *, dtype=None, casting="same_kind"):
     return result


-def tile(tensor, reps):
-    if isinstance(reps, int):
-        reps = (reps,)
-
-    result = torch.tile(tensor, reps)
-    return result
-
-
 # #### cov & corrcoef


@@ -549,14 +477,6 @@ def arange(start=None, stop=None, step=1, dtype=None):
 # ### empty/full et al ###


-def eye(N, M=None, k=0, dtype=float):
-    if M is None:
-        M = N
-    z = torch.zeros(N, M, dtype=dtype)
-    z.diagonal(k).fill_(1)
-    return z
-
-
 def zeros(shape, dtype=None, order="C"):
     if order != "C":
         raise NotImplementedError
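
The removed eye builds the identity in place: Tensor.diagonal(offset) returns a writable view into the original storage, so fill_ lands directly in the zeros tensor. A small sketch (sizes are made up):

    import torch

    # eye(3, 5, k=1): ones on the first superdiagonal of a 3x5 matrix.
    z = torch.zeros(3, 5)
    z.diagonal(1).fill_(1)  # diagonal() is a view; fill_ mutates z
    print(z)
    # tensor([[0., 1., 0., 0., 0.],
    #         [0., 0., 1., 0., 0.],
    #         [0., 0., 0., 1., 0.]])
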
@@ -637,102 +557,12 @@ def full(shape, fill_value, dtype=None, order="C"):
 # ### shape manipulations ###


-def roll(tensor, shift, axis=None):
-    if axis is not None:
-        axis = _util.normalize_axis_tuple(axis, tensor.ndim, allow_duplicate=True)
-        if not isinstance(shift, tuple):
-            shift = (shift,) * len(axis)
-    result = tensor.roll(shift, axis)
-    return result
-
-
-def squeeze(tensor, axis=None):
-    if axis == ():
-        result = tensor
-    elif axis is None:
-        result = tensor.squeeze()
-    else:
-        if isinstance(axis, tuple):
-            result = tensor
-            for ax in axis:
-                result = result.squeeze(ax)
-        else:
-            result = tensor.squeeze(axis)
-    return result
-
-
-def reshape(tensor, shape, order="C"):
-    if order != "C":
-        raise NotImplementedError
-    # if sh = (1, 2, 3), numpy allows both .reshape(sh) and .reshape(*sh)
-    newshape = shape[0] if len(shape) == 1 else shape
-    result = tensor.reshape(newshape)
-    return result
-
-
-def transpose(tensor, axes=None):
-    # numpy allows both .transpose(sh) and .transpose(*sh)
-    if axes in [(), None, (None,)]:
-        axes = tuple(range(tensor.ndim))[::-1]
-    try:
-        result = tensor.permute(axes)
-    except RuntimeError:
-        raise ValueError("axes don't match array")
-    return result
-
-
-def ravel(tensor, order="C"):
-    if order != "C":
-        raise NotImplementedError
-    result = tensor.ravel()
-    return result
-
-
-# leading underscore since arr.flatten exists but np.flatten does not
-def _flatten(tensor, order="C"):
-    if order != "C":
-        raise NotImplementedError
-    # return a copy
-    result = tensor.flatten()
-    return result
-
-
 # ### swap/move/roll axis ###


-def moveaxis(tensor, source, destination):
-    source = _util.normalize_axis_tuple(source, tensor.ndim, "source")
-    destination = _util.normalize_axis_tuple(destination, tensor.ndim, "destination")
-    result = torch.moveaxis(tensor, source, destination)
-    return result
-
-
 # ### Numeric ###


-def round(tensor, decimals=0):
-    if tensor.is_floating_point():
-        result = torch.round(tensor, decimals=decimals)
-    elif tensor.is_complex():
-        # RuntimeError: "round_cpu" not implemented for 'ComplexFloat'
-        result = (
-            torch.round(tensor.real, decimals=decimals)
-            + torch.round(tensor.imag, decimals=decimals) * 1j
-        )
-    else:
-        # RuntimeError: "round_cpu" not implemented for 'int'
-        result = tensor
-    return result
-
-
-def imag(tensor):
-    if tensor.is_complex():
-        result = tensor.imag
-    else:
-        result = torch.zeros_like(tensor)
-    return result
-
-
 # ### put/take along axis ###


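The comments in the removed round record the motivating errors: torch.round raises for complex and integer inputs, so the helper rounds the real and imaginary parts separately and passes integer tensors through unchanged, mirroring np.round. A sketch of the complex branch (input values are made up):

    import torch

    t = torch.tensor([1.234 + 5.678j, -2.567 + 0.333j])

    # torch.round(t) raises: "round_cpu" not implemented for 'ComplexFloat',
    # so round real and imaginary parts independently and recombine.
    rounded = torch.round(t.real, decimals=1) + torch.round(t.imag, decimals=1) * 1j
    print(rounded)  # tensor([ 1.2000+5.7000j, -2.6000+0.3000j])
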
@@ -753,36 +583,6 @@ def put_along_dim(tensor, t_indices, t_values, axis):
     return result


-# ### sort and partition ###
-
-
-def _sort_helper(tensor, axis, kind, order):
-    if order is not None:
-        # only relevant for structured dtypes; not supported
-        raise NotImplementedError(
-            "'order' keyword is only relevant for structured dtypes"
-        )
-
-    (tensor,), axis = _util.axis_none_ravel(tensor, axis=axis)
-    axis = _util.normalize_axis_index(axis, tensor.ndim)
-
-    stable = kind == "stable"
-
-    return tensor, axis, stable
-
-
-def sort(tensor, axis=-1, kind=None, order=None):
-    tensor, axis, stable = _sort_helper(tensor, axis, kind, order)
-    result = torch.sort(tensor, dim=axis, stable=stable)
-    return result.values
-
-
-def argsort(tensor, axis=-1, kind=None, order=None):
-    tensor, axis, stable = _sort_helper(tensor, axis, kind, order)
-    result = torch.argsort(tensor, dim=axis, stable=stable)
-    return result
-
-
 # ### logic and selection ###


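The removed _sort_helper folds NumPy's sort conventions into torch's: axis=None means "ravel, then sort" (handled by _util.axis_none_ravel), and kind="stable" maps onto the stable= keyword of torch.sort/torch.argsort. A minimal sketch of the mapping with plain torch calls:

    import torch

    # axis=None in NumPy sorts the raveled array.
    t = torch.tensor([[3, 1], [2, 2]])
    print(torch.sort(t.ravel()).values)  # tensor([1, 2, 2, 3])

    # kind="stable" -> stable=True: equal keys keep their original order.
    vals = torch.tensor([2, 1, 2, 1])
    print(torch.sort(vals, stable=True).indices)  # tensor([1, 3, 0, 2])
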
@@ -831,56 +631,6 @@ def inner(t_a, t_b):
     return result


-def vdot(t_a, t_b, /):
-    # 1. torch only accepts 1D arrays, numpy ravels
-    # 2. torch requires matching dtype, while numpy casts (?)
-    t_a, t_b = torch.atleast_1d(t_a, t_b)
-    if t_a.ndim > 1:
-        t_a = t_a.ravel()
-    if t_b.ndim > 1:
-        t_b = t_b.ravel()
-
-    dtype = _dtypes_impl.result_type_impl((t_a.dtype, t_b.dtype))
-    is_half = dtype == torch.float16
-    is_bool = dtype == torch.bool
-
-    # work around torch's "dot" not implemented for 'Half', 'Bool'
-    if is_half:
-        dtype = torch.float32
-    if is_bool:
-        dtype = torch.uint8
-
-    t_a = _util.cast_if_needed(t_a, dtype)
-    t_b = _util.cast_if_needed(t_b, dtype)
-
-    result = torch.vdot(t_a, t_b)
-
-    if is_half:
-        result = result.to(torch.float16)
-    if is_bool:
-        result = result.to(torch.bool)
-
-    return result
-
-
-def dot(t_a, t_b):
-    dtype = _dtypes_impl.result_type_impl((t_a.dtype, t_b.dtype))
-    t_a = _util.cast_if_needed(t_a, dtype)
-    t_b = _util.cast_if_needed(t_b, dtype)
-
-    if t_a.ndim == 0 or t_b.ndim == 0:
-        result = t_a * t_b
-    elif t_a.ndim == 1 and t_b.ndim == 1:
-        result = torch.dot(t_a, t_b)
-    elif t_a.ndim == 1:
-        result = torch.mv(t_b.T, t_a).T
-    elif t_b.ndim == 1:
-        result = torch.mv(t_a, t_b)
-    else:
-        result = torch.matmul(t_a, t_b)
-    return result
-
-
 # ### unique et al ###


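The removed dot dispatches on rank because torch.dot is strictly 1-D/1-D, unlike np.dot: 0-d operands multiply elementwise, the mixed 1-D/2-D cases go through torch.mv, and everything else falls through to torch.matmul. A shape check of the two mv branches (inputs are made up):

    import torch

    m = torch.randn(3, 4)
    v3 = torch.randn(3)
    v4 = torch.randn(4)

    # matrix @ vector, np.dot(m, v4): mv(m, v4) -> shape (3,)
    print(torch.mv(m, v4).shape)    # torch.Size([3])

    # vector @ matrix, np.dot(v3, m) == m.T @ v3: mv(m.T, v3) -> shape (4,)
    print(torch.mv(m.T, v3).shape)  # torch.Size([4])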