Skip to content

Commit 9f11675

Browse files
committed
MAINT: add a utility to wrap tensor.to(dtype) when dtype != tensor.dtype
1 parent cf4d5c1 commit 9f11675

File tree

3 files changed

+20
-26
lines changed

3 files changed

+20
-26
lines changed

torch_np/_detail/_reductions.py

Lines changed: 6 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -146,9 +146,7 @@ def std(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue):
146146
raise NotImplementedError
147147

148148
dtype = _atleast_float(dtype, tensor.dtype)
149-
150-
if dtype is not None:
151-
tensor = tensor.to(dtype)
149+
tensor = _util.cast_if_needed(tensor, dtype)
152150
result = tensor.std(dim=axis, correction=ddof)
153151

154152
return result
@@ -159,9 +157,7 @@ def var(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue):
159157
raise NotImplementedError
160158

161159
dtype = _atleast_float(dtype, tensor.dtype)
162-
163-
if dtype is not None:
164-
tensor = tensor.to(dtype)
160+
tensor = _util.cast_if_needed(tensor, dtype)
165161
result = tensor.var(dim=axis, correction=ddof)
166162

167163
return result
@@ -204,10 +200,9 @@ def average(a_tensor, axis, w_tensor):
204200
a_tensor = a_tensor.to(result_dtype)
205201

206202
result_dtype = _dtypes_impl.result_type_impl([a_tensor.dtype, w_tensor.dtype])
207-
if a_tensor.dtype != result_dtype:
208-
a_tensor = a_tensor.to(result_dtype)
209-
if w_tensor.dtype != result_dtype:
210-
w_tensor = w_tensor.to(result_dtype)
203+
204+
a_tensor = _util.cast_if_needed(a_tensor, result_dtype)
205+
w_tensor = _util.cast_if_needed(w_tensor, result_dtype)
211206

212207
# axis
213208
if axis is None:
@@ -258,7 +253,7 @@ def quantile(a_tensor, q_tensor, axis, method):
258253
axis = _util.normalize_axis_tuple(axis, a_tensor.ndim)
259254
axis = _util.allow_only_single_axis(axis)
260255

261-
q_tensor = q_tensor.to(a_tensor.dtype)
256+
q_tensor = _util.cast_if_needed(q_tensor, a_tensor.dtype)
262257

263258
(a_tensor, q_tensor), axis = _util.axis_none_ravel(a_tensor, q_tensor, axis=axis)
264259

torch_np/_detail/_util.py

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,13 @@ class UFuncTypeError(TypeError, RuntimeError):
3434
pass
3535

3636

37+
def cast_if_needed(tensor, dtype):
38+
# NB: no casting if dtype=None
39+
if tensor.dtype != dtype:
40+
tensor = tensor.to(dtype)
41+
return tensor
42+
43+
3744
# a replica of the version in ./numpy/numpy/core/src/multiarray/common.h
3845
def normalize_axis_index(ax, ndim, argname=None):
3946
if not (-ndim <= ax < ndim):
@@ -156,10 +163,7 @@ def cast_dont_broadcast(tensors, target_dtype, casting):
156163
f"Cannot cast array data from {tensor.dtype} to"
157164
f" {target_dtype} according to the rule '{casting}'"
158165
)
159-
160-
# cast if needed
161-
if tensor.dtype != target_dtype:
162-
tensor = tensor.to(target_dtype)
166+
tensor = cast_if_needed(tensor, target_dtype)
163167
cast_tensors.append(tensor)
164168

165169
return tuple(cast_tensors)
@@ -200,8 +204,7 @@ def cast_and_broadcast(tensors, out_param, casting):
200204
)
201205

202206
# cast arr if needed
203-
if tensor.dtype != target_dtype:
204-
tensor = tensor.to(target_dtype)
207+
tensor = cast_if_needed(tensor, target_dtype)
205208

206209
# `out` broadcasts `tensor`
207210
if tensor.shape != target_shape:
@@ -285,8 +288,7 @@ def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):
285288
tensor = torch.as_tensor(obj, dtype=torch_dtype)
286289

287290
# type cast if requested
288-
if dtype is not None:
289-
tensor = tensor.to(dtype)
291+
tensor = cast_if_needed(tensor, dtype)
290292

291293
# adjust ndim if needed
292294
ndim_extra = ndmin - tensor.ndim

torch_np/_detail/implementations.py

Lines changed: 4 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -33,8 +33,8 @@ def tensor_equiv(a1_t, a2_t):
3333

3434
def isclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False):
3535
dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
36-
a = a.to(dtype)
37-
b = b.to(dtype)
36+
a = _util.cast_if_needed(a, dtype)
37+
b = _util.cast_if_needed(b, dtype)
3838
result = torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
3939
return result
4040

@@ -308,9 +308,7 @@ def corrcoef(xy_tensor, *, dtype=None):
308308
# work around torch's "addmm_impl_cpu_" not implemented for 'Half'"
309309
dtype = torch.float32
310310

311-
if dtype is not None:
312-
xy_tensor = xy_tensor.to(dtype)
313-
311+
xy_tensor = _util.cast_if_needed(xy_tensor, dtype)
314312
result = torch.corrcoef(xy_tensor)
315313

316314
if is_half:
@@ -336,8 +334,7 @@ def cov(
336334
# work around torch's "addmm_impl_cpu_" not implemented for 'Half'"
337335
dtype = torch.float32
338336

339-
if dtype is not None:
340-
m_tensor = m_tensor.to(dtype)
337+
m_tensor = _util.cast_if_needed(m_tensor, dtype)
341338

342339
result = torch.cov(
343340
m_tensor, correction=ddof, aweights=aweights_tensor, fweights=fweights_tensor

0 commit comments

Comments (0)