
Commit 8a91f0a

Merge pull request #52 from Quansight-Labs/move_to_impl
make better split between wrappers and torch implementations
2 parents 39734f7 + ba31a0b commit 8a91f0a

9 files changed: +617 -424 lines changed


torch_np/_decorators.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ def wrapped(*args, dtype=None, **kwds):
         torch_dtype = None
         if dtype is not None:
             dtype = _dtypes.dtype(dtype)
-            torch_dtype = dtype._scalar_type.torch_dtype
+            torch_dtype = dtype.torch_dtype
         return func(*args, dtype=torch_dtype, **kwds)
 
     return wrapped
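
For orientation, here is a minimal, self-contained sketch of the decorator pattern this hunk touches: a NumPy-style dtype argument is translated to a torch dtype before the wrapped implementation is called. Only the body of `wrapped` comes from the diff; the outer decorator name, the `_NAME_TO_TORCH` dict, and the `zeros` example are illustrative stand-ins (the real code goes through `_dtypes.dtype(...)`, whose objects now expose `.torch_dtype` directly).

import functools

import torch

# hypothetical stand-in for _dtypes.dtype(...): map a dtype name to a torch dtype
_NAME_TO_TORCH = {"float32": torch.float32, "float64": torch.float64, "int64": torch.int64}


def dtype_to_torch(func):
    @functools.wraps(func)
    def wrapped(*args, dtype=None, **kwds):
        torch_dtype = None
        if dtype is not None:
            # after this commit the torch dtype hangs directly off the dtype
            # object; emulated here with a plain dict lookup
            torch_dtype = _NAME_TO_TORCH[dtype]
        return func(*args, dtype=torch_dtype, **kwds)

    return wrapped


@dtype_to_torch
def zeros(shape, dtype=None):
    return torch.zeros(shape, dtype=dtype)


print(zeros((2, 3), dtype="float64").dtype)  # torch.float64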

torch_np/_detail/_reductions.py

Lines changed: 6 additions & 11 deletions
@@ -146,9 +146,7 @@ def std(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue):
         raise NotImplementedError
 
     dtype = _atleast_float(dtype, tensor.dtype)
-
-    if dtype is not None:
-        tensor = tensor.to(dtype)
+    tensor = _util.cast_if_needed(tensor, dtype)
     result = tensor.std(dim=axis, correction=ddof)
 
     return result
@@ -159,9 +157,7 @@ def var(tensor, axis=None, dtype=None, ddof=0, *, where=NoValue):
         raise NotImplementedError
 
     dtype = _atleast_float(dtype, tensor.dtype)
-
-    if dtype is not None:
-        tensor = tensor.to(dtype)
+    tensor = _util.cast_if_needed(tensor, dtype)
     result = tensor.var(dim=axis, correction=ddof)
 
     return result
@@ -204,10 +200,9 @@ def average(a_tensor, axis, w_tensor):
         a_tensor = a_tensor.to(result_dtype)
 
     result_dtype = _dtypes_impl.result_type_impl([a_tensor.dtype, w_tensor.dtype])
-    if a_tensor.dtype != result_dtype:
-        a_tensor = a_tensor.to(result_dtype)
-    if w_tensor.dtype != result_dtype:
-        w_tensor = w_tensor.to(result_dtype)
+
+    a_tensor = _util.cast_if_needed(a_tensor, result_dtype)
+    w_tensor = _util.cast_if_needed(w_tensor, result_dtype)
 
     # axis
     if axis is None:
@@ -258,7 +253,7 @@ def quantile(a_tensor, q_tensor, axis, method):
     axis = _util.normalize_axis_tuple(axis, a_tensor.ndim)
     axis = _util.allow_only_single_axis(axis)
 
-    q_tensor = q_tensor.to(a_tensor.dtype)
+    q_tensor = _util.cast_if_needed(q_tensor, a_tensor.dtype)
 
     (a_tensor, q_tensor), axis = _util.axis_none_ravel(a_tensor, q_tensor, axis=axis)
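
The four hunks above all replace the same inline cast-if-needed dance with the shared helper added to `_util.py` below. A condensed, runnable sketch of the resulting shape of `std` (assuming PyTorch is installed; `_atleast_float` is approximated by a float64 fallback for integer inputs, and the helper is inlined from the hunk below):

import torch


def cast_if_needed(tensor, dtype):
    # NB: no casting if dtype=None
    if tensor.dtype != dtype:
        tensor = tensor.to(dtype)
    return tensor


def std(tensor, axis=None, dtype=None, ddof=0):
    # rough stand-in for _atleast_float: promote integer inputs to float64
    if dtype is None:
        dtype = tensor.dtype if tensor.dtype.is_floating_point else torch.float64
    tensor = cast_if_needed(tensor, dtype)
    return tensor.std(dim=axis, correction=ddof)


print(std(torch.arange(6).reshape(2, 3), axis=1, ddof=1))  # tensor([1., 1.], dtype=torch.float64)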

torch_np/_detail/_util.py

Lines changed: 10 additions & 8 deletions
@@ -34,6 +34,13 @@ class UFuncTypeError(TypeError, RuntimeError):
     pass
 
 
+def cast_if_needed(tensor, dtype):
+    # NB: no casting if dtype=None
+    if tensor.dtype != dtype:
+        tensor = tensor.to(dtype)
+    return tensor
+
+
 # a replica of the version in ./numpy/numpy/core/src/multiarray/common.h
 def normalize_axis_index(ax, ndim, argname=None):
     if not (-ndim <= ax < ndim):
@@ -156,10 +163,7 @@ def cast_dont_broadcast(tensors, target_dtype, casting):
                 f"Cannot cast array data from {tensor.dtype} to"
                 f" {target_dtype} according to the rule '{casting}'"
             )
-
-        # cast if needed
-        if tensor.dtype != target_dtype:
-            tensor = tensor.to(target_dtype)
+        tensor = cast_if_needed(tensor, target_dtype)
         cast_tensors.append(tensor)
 
     return tuple(cast_tensors)
@@ -200,8 +204,7 @@ def cast_and_broadcast(tensors, out_param, casting):
             )
 
         # cast arr if needed
-        if tensor.dtype != target_dtype:
-            tensor = tensor.to(target_dtype)
+        tensor = cast_if_needed(tensor, target_dtype)
 
         # `out` broadcasts `tensor`
         if tensor.shape != target_shape:
@@ -285,8 +288,7 @@ def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):
     tensor = torch.as_tensor(obj, dtype=torch_dtype)
 
     # type cast if requested
-    if dtype is not None:
-        tensor = tensor.to(dtype)
+    tensor = cast_if_needed(tensor, dtype)
 
     # adjust ndim if needed
     ndim_extra = ndmin - tensor.ndim
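
A quick usage sketch of the new helper (assuming the torch_np package from this repo is importable): it returns the tensor untouched when the dtype already matches, and, per the `NB` comment, passing `dtype=None` is meant to leave the values alone as well, since `tensor.to(None)` changes neither dtype nor device.

import torch

from torch_np._detail._util import cast_if_needed

t = torch.arange(3)                          # int64
print(cast_if_needed(t, torch.float32))      # casts: tensor([0., 1., 2.])
print(cast_if_needed(t, torch.int64) is t)   # dtype already matches: True
print(cast_if_needed(t, None).dtype)         # dtype=None -> still torch.int64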
