diff --git a/torch_np/_binary_ufuncs_impl.py b/torch_np/_binary_ufuncs_impl.py
index 98a29996..5bbdb759 100644
--- a/torch_np/_binary_ufuncs_impl.py
+++ b/torch_np/_binary_ufuncs_impl.py
@@ -51,7 +51,7 @@ def matmul(x, y):
     # - RuntimeError: expected scalar type Int but found Double
     # - RuntimeError: "addmm_impl_cpu_" not implemented for 'Bool'
     # - RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
-    dtype = _dtypes_impl.result_type_impl((x.dtype, y.dtype))
+    dtype = _dtypes_impl.result_type_impl(x, y)
 
     is_bool = dtype == torch.bool
     is_half = dtype == torch.float16
diff --git a/torch_np/_dtypes.py b/torch_np/_dtypes.py
index 751de9b3..6a0b86af 100644
--- a/torch_np/_dtypes.py
+++ b/torch_np/_dtypes.py
@@ -6,7 +6,7 @@
 
 import torch
 
-from . import _dtypes_impl
+from . import _dtypes_impl, _util
 
 # more __all__ manipulations at the bottom
 __all__ = ["dtype", "DType", "typecodes", "issubdtype", "set_default_dtype"]
@@ -34,7 +34,7 @@ def __new__(self, value):
             tensor = value.tensor
         else:
             try:
-                tensor = torch.as_tensor(value, dtype=self.torch_dtype)
+                tensor = _util._coerce_to_tensor(value, dtype=self.torch_dtype)
             except RuntimeError as e:
                 if "Overflow" in str(e):
                     raise OverflowError(e.args)
diff --git a/torch_np/_dtypes_impl.py b/torch_np/_dtypes_impl.py
index cf5d8887..f21b83be 100644
--- a/torch_np/_dtypes_impl.py
+++ b/torch_np/_dtypes_impl.py
@@ -39,13 +39,13 @@ def can_cast_impl(from_torch_dtype, to_torch_dtype, casting):
     return _cd._can_cast_dict[casting][from_torch_dtype][to_torch_dtype]
 
 
-def result_type_impl(dtypes):
+def result_type_impl(*tensors):
     # NB: torch dtypes here
-    dtyp = dtypes[0]
-    if len(dtypes) == 1:
+    dtyp = tensors[0].dtype
+    if len(tensors) == 1:
         return dtyp
 
-    for curr in dtypes[1:]:
-        dtyp = _cd._result_type_dict[dtyp][curr]
+    for curr in tensors[1:]:
+        dtyp = _cd._result_type_dict[dtyp][curr.dtype]
 
     return dtyp
diff --git a/torch_np/_funcs_impl.py b/torch_np/_funcs_impl.py
index 09217725..bd5fa428 100644
--- a/torch_np/_funcs_impl.py
+++ b/torch_np/_funcs_impl.py
@@ -91,7 +91,7 @@ def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):
         # figure out the type of the inputs and outputs
         out_dtype = out.dtype.torch_dtype if dtype is None else dtype
     else:
-        out_dtype = _dtypes_impl.result_type_impl([t.dtype for t in tensors])
+        out_dtype = _dtypes_impl.result_type_impl(*tensors)
 
     # cast input arrays if necessary; do not broadcast them agains `out`
     tensors = _util.typecast_tensors(tensors, out_dtype, casting)
@@ -354,9 +354,11 @@ def arange(
     # the dtype of the result
     if dtype is None:
         dtype = _dtypes_impl.default_dtypes.int_dtype
-    dt_list = [_util._coerce_to_tensor(x).dtype for x in (start, stop, step)]
-    dt_list.append(dtype)
-    target_dtype = _dtypes_impl.result_type_impl(dt_list)
+    # XXX: default values do not get normalized
+    start, stop, step = (_util._coerce_to_tensor(x) for x in (start, stop, step))
+
+    dummy = torch.empty(1, dtype=dtype)
+    target_dtype = _dtypes_impl.result_type_impl(start, stop, step, dummy)
 
     # work around RuntimeError: "arange_cpu" not implemented for 'ComplexFloat'
     work_dtype = torch.float64 if target_dtype.is_complex else target_dtype
@@ -571,7 +573,7 @@ def cov(
 
 
 def _conv_corr_impl(a, v, mode):
-    dt = _dtypes_impl.result_type_impl((a.dtype, v.dtype))
+    dt = _dtypes_impl.result_type_impl(a, v)
     a = _util.cast_if_needed(a, dt)
     v = _util.cast_if_needed(v, dt)
@@ -857,15 +859,14 @@ def nanpercentile():
 
 
 def isclose(a: ArrayLike, b: ArrayLike, rtol=1.0e-5, atol=1.0e-8, equal_nan=False):
-    dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
+    dtype = _dtypes_impl.result_type_impl(a, b)
     a = _util.cast_if_needed(a, dtype)
     b = _util.cast_if_needed(b, dtype)
-    result = torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
-    return result
+    return torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
 
 
 def allclose(a: ArrayLike, b: ArrayLike, rtol=1e-05, atol=1e-08, equal_nan=False):
-    dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
+    dtype = _dtypes_impl.result_type_impl(a, b)
     a = _util.cast_if_needed(a, dtype)
     b = _util.cast_if_needed(b, dtype)
     return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
@@ -1175,7 +1176,7 @@ def vdot(a: ArrayLike, b: ArrayLike, /):
     if t_b.ndim > 1:
         t_b = t_b.flatten()
 
-    dtype = _dtypes_impl.result_type_impl((t_a.dtype, t_b.dtype))
+    dtype = _dtypes_impl.result_type_impl(t_a, t_b)
     is_half = dtype == torch.float16
     is_bool = dtype == torch.bool
@@ -1202,7 +1203,7 @@ def tensordot(a: ArrayLike, b: ArrayLike, axes=2):
     if isinstance(axes, (list, tuple)):
         axes = [[ax] if isinstance(ax, int) else ax for ax in axes]
 
-    target_dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
+    target_dtype = _dtypes_impl.result_type_impl(a, b)
     a = _util.cast_if_needed(a, target_dtype)
     b = _util.cast_if_needed(b, target_dtype)
@@ -1210,7 +1211,7 @@ def dot(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):
-    dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
+    dtype = _dtypes_impl.result_type_impl(a, b)
     a = _util.cast_if_needed(a, dtype)
     b = _util.cast_if_needed(b, dtype)
@@ -1222,7 +1223,7 @@ def inner(a: ArrayLike, b: ArrayLike, /):
-    dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
+    dtype = _dtypes_impl.result_type_impl(a, b)
     is_half = dtype == torch.float16
     is_bool = dtype == torch.bool
@@ -1284,11 +1285,7 @@ def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=
         subscripts, array_operands = operands[0], operands[1:]
 
     tensors = [normalize_array_like(op) for op in array_operands]
-    target_dtype = (
-        _dtypes_impl.result_type_impl([op.dtype for op in tensors])
-        if dtype is None
-        else dtype
-    )
+    target_dtype = _dtypes_impl.result_type_impl(*tensors) if dtype is None else dtype
 
     # work around 'bmm' not implemented for 'Half' etc
     is_half = target_dtype == torch.float16
diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py
index df7b54a8..41eac1f4 100644
--- a/torch_np/_ndarray.py
+++ b/torch_np/_ndarray.py
@@ -518,11 +518,14 @@ def can_cast(from_, to, casting="safe"):
 
 
 def result_type(*arrays_and_dtypes):
-    dtypes = []
-
+    tensors = []
     for entry in arrays_and_dtypes:
-        dty = _extract_dtype(entry)
-        dtypes.append(dty.torch_dtype)
+        try:
+            t = asarray(entry).tensor
+        except (RuntimeError, ValueError, TypeError):
+            dty = _dtypes.dtype(entry)
+            t = torch.empty(1, dtype=dty.torch_dtype)
+        tensors.append(t)
 
-    torch_dtype = _dtypes_impl.result_type_impl(dtypes)
+    torch_dtype = _dtypes_impl.result_type_impl(*tensors)
     return _dtypes.dtype(torch_dtype)
diff --git a/torch_np/_reductions.py b/torch_np/_reductions.py
index dc6ac21e..b91a44c7 100644
--- a/torch_np/_reductions.py
+++ b/torch_np/_reductions.py
@@ -342,7 +342,7 @@ def average(
         weights = weights.swapaxes(-1, axis)
 
     # do the work
-    result_dtype = _dtypes_impl.result_type_impl([a.dtype, weights.dtype])
+    result_dtype = _dtypes_impl.result_type_impl(a, weights)
     numerator = sum(a * weights, axis, dtype=result_dtype)
     wsum = sum(weights, axis, dtype=result_dtype)
     result = numerator / wsum
diff --git a/torch_np/_ufuncs.py b/torch_np/_ufuncs.py
index dcb3ac8f..8021d64a 100644
--- a/torch_np/_ufuncs.py
+++ b/torch_np/_ufuncs.py
@@ -17,7 +17,7 @@ def _ufunc_preprocess(tensors, where, casting, order, dtype, subok, signature, extobj):
     if dtype is None:
-        dtype = _dtypes_impl.result_type_impl([t.dtype for t in tensors])
+        dtype = _dtypes_impl.result_type_impl(*tensors)
 
     tensors = _util.typecast_tensors(tensors, dtype, casting)
@@ -26,7 +26,7 @@ def _ufunc_postprocess(result, out, casting):
     if out is not None:
-        (result,) = _util.typecast_tensors((result,), out.dtype.torch_dtype, casting)
+        result = _util.typecast_tensor(result, out.dtype.torch_dtype, casting)
         result = torch.broadcast_to(result, out.shape)
     return result
@@ -198,10 +198,9 @@ def wrapped(
         signature=None,
        extobj=None,
     ):
-        tensors = _ufunc_preprocess(
-            (x,), where, casting, order, dtype, subok, signature, extobj
-        )
-        result = torch_func(*tensors)
+        if dtype is not None:
+            x = _util.typecast_tensor(x, dtype, casting)
+        result = torch_func(x)
         result = _ufunc_postprocess(result, out, casting)
         return result
diff --git a/torch_np/_util.py b/torch_np/_util.py
index e120898c..fc7651dd 100644
--- a/torch_np/_util.py
+++ b/torch_np/_util.py
@@ -135,37 +135,40 @@ def axis_none_flatten(*tensors, axis=None):
     return tensors, axis
 
 
-def typecast_tensors(tensors, target_dtype, casting):
-    """Dtype-cast tensors to target_dtype.
+def typecast_tensor(t, target_dtype, casting):
+    """Dtype-cast tensor to target_dtype.
 
     Parameters
     ----------
-    tensors : iterable
-        tuple or list of torch.Tensors to typecast
-    target_dtype : torch dtype object, optional
+    t : torch.Tensor
+        The tensor to cast
+    target_dtype : torch dtype object
         The array dtype to cast all tensors to
     casting : str
         The casting mode, see `np.can_cast`
 
-    Returns
-    -------
-    a tuple of torch.Tensors with dtype being the PyTorch counterpart
-    of the `target_dtype`
+    Returns
+    -------
+    `torch.Tensor` of the `target_dtype` dtype
+
+    Raises
+    ------
+    TypeError
+        if the argument cannot be cast according to the `casting` rule
+
     """
-    # check if we can dtype-cast all arguments
-    cast_tensors = []
     can_cast = _dtypes_impl.can_cast_impl
-    for tensor in tensors:
-        if not can_cast(tensor.dtype, target_dtype, casting=casting):
-            raise TypeError(
-                f"Cannot cast array data from {tensor.dtype} to"
-                f" {target_dtype} according to the rule '{casting}'"
-            )
-        tensor = cast_if_needed(tensor, target_dtype)
-        cast_tensors.append(tensor)
+    if not can_cast(t.dtype, target_dtype, casting=casting):
+        raise TypeError(
+            f"Cannot cast array data from {t.dtype} to"
+            f" {target_dtype} according to the rule '{casting}'"
+        )
+    return cast_if_needed(t, target_dtype)
+
 
-    return tuple(cast_tensors)
+def typecast_tensors(tensors, target_dtype, casting):
+    return tuple(typecast_tensor(t, target_dtype, casting) for t in tensors)
 
 
 def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):
@@ -193,7 +196,6 @@
     """
     if isinstance(obj, torch.Tensor):
         tensor = obj
-        base = None
     else:
         tensor = torch.as_tensor(obj)
         base = None
diff --git a/torch_np/linalg.py b/torch_np/linalg.py
index 7acd1515..46bfea1b 100644
--- a/torch_np/linalg.py
+++ b/torch_np/linalg.py
@@ -21,7 +21,7 @@ def _atleast_float_1(a):
 
 
 def _atleast_float_2(a, b):
-    dtyp = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
+    dtyp = _dtypes_impl.result_type_impl(a, b)
     if not (dtyp.is_floating_point or dtyp.is_complex):
         dtyp = _dtypes_impl.default_dtypes.float_dtype
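For reviewers, a minimal sketch of the new calling convention (the sample tensors and the `same_kind` cast are illustrative only, not part of the patch): `result_type_impl` now takes tensors varargs-style instead of a sequence of dtypes; dtype-only participants are represented by a one-element dummy tensor, mirroring what `arange` and `result_type` do above; and `typecast_tensor` is the new single-tensor counterpart of `typecast_tensors`.

```python
import torch

from torch_np import _dtypes_impl, _util

x = torch.ones(3, dtype=torch.float32)
y = torch.ones(3, dtype=torch.float64)

# Old: result_type_impl((x.dtype, y.dtype)); new: pass the tensors directly.
dtype = _dtypes_impl.result_type_impl(x, y)
assert dtype == torch.float64

# A bare dtype participates via a one-element dummy tensor.
dummy = torch.empty(1, dtype=torch.complex128)
assert _dtypes_impl.result_type_impl(x, y, dummy) == torch.complex128

# Single-tensor cast; raises TypeError if `casting` forbids the conversion.
x_cast = _util.typecast_tensor(x, dtype, casting="same_kind")
assert x_cast.dtype == torch.float64
```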