
Commit 4331a01

MAINT: split clip
1 parent 4798f75 commit 4331a01

2 files changed: +17 -18 lines changed

torch_np/_detail/implementations.py

Lines changed: 14 additions & 0 deletions
@@ -72,6 +72,20 @@ def split_helper_list(tensor, indices_or_sections, axis, strict=False):
     return torch.split(tensor, lst, axis)


+def clip(tensor, t_min, t_max):
+    if t_min is not None:
+        t_min = torch.broadcast_to(t_min, tensor.shape)
+
+    if t_max is not None:
+        t_max = torch.broadcast_to(t_max, tensor.shape)
+
+    if t_min is None and t_max is None:
+        raise ValueError("One of max or min must be given")
+
+    result = tensor.clamp(t_min, t_max)
+    return result
+
+
 def diff(a_tensor, n=1, axis=-1, prepend_tensor=None, append_tensor=None):
     axis = _util.normalize_axis_index(axis, a_tensor.ndim)

torch_np/_ndarray.py

Lines changed: 3 additions & 18 deletions
@@ -12,6 +12,7 @@
     emulate_out_arg,
 )
 from ._detail import _dtypes_impl, _flips, _reductions, _util
+from ._detail import implementations as _impl

 newaxis = None

@@ -367,24 +368,8 @@ def nonzero(self):
         return tuple(asarray(_) for _ in tensor.nonzero(as_tuple=True))

     def clip(self, min, max, out=None):
-        tensor = self._tensor
-        a_min, a_max = min, max
-
-        t_min = None
-        if a_min is not None:
-            t_min = asarray(a_min).get()
-            t_min = torch.broadcast_to(t_min, tensor.shape)
-
-        t_max = None
-        if a_max is not None:
-            t_max = asarray(a_max).get()
-            t_max = torch.broadcast_to(t_max, tensor.shape)
-
-        if t_min is None and t_max is None:
-            raise ValueError("One of max or min must be given")
-
-        result = tensor.clamp(t_min, t_max)
-
+        tensor, t_min, t_max = _helpers.to_tensors_or_none(self, min, max)
+        result = _impl.clip(tensor, t_min, t_max)
         return _helpers.result_or_out(result, out)

     argmin = emulate_out_arg(axis_keepdims_wrapper(_reductions.argmin))
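At the call site nothing is meant to change: ndarray.clip now only converts its arguments (via the _helpers.to_tensors_or_none shim referenced in the diff) and defers to _impl.clip. A hedged usage sketch, assuming torch_np is importable and its asarray/clip follow the NumPy-style signature shown above:

import torch_np as tnp

a = tnp.asarray([0, 1, 2, 3, 4, 5])
clipped = a.clip(1, 4)        # both bounds given; values clipped into [1, 4]
upper_only = a.clip(None, 4)  # only an upper bound; t_min stays None in _impl.clip
# a.clip(None, None) would raise ValueError("One of max or min must be given")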
