
MAINT: rework result_type_impl to accept *tensors, not dtypes #139


Merged
merged 1 commit on May 13, 2023
2 changes: 1 addition & 1 deletion torch_np/_binary_ufuncs_impl.py
@@ -51,7 +51,7 @@ def matmul(x, y):
# - RuntimeError: expected scalar type Int but found Double
# - RuntimeError: "addmm_impl_cpu_" not implemented for 'Bool'
# - RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
dtype = _dtypes_impl.result_type_impl((x.dtype, y.dtype))
dtype = _dtypes_impl.result_type_impl(x, y)
is_bool = dtype == torch.bool
is_half = dtype == torch.float16

4 changes: 2 additions & 2 deletions torch_np/_dtypes.py
@@ -6,7 +6,7 @@

import torch

from . import _dtypes_impl
from . import _dtypes_impl, _util

# more __all__ manipulations at the bottom
__all__ = ["dtype", "DType", "typecodes", "issubdtype", "set_default_dtype"]
@@ -34,7 +34,7 @@ def __new__(self, value):
tensor = value.tensor
else:
try:
tensor = torch.as_tensor(value, dtype=self.torch_dtype)
tensor = _util._coerce_to_tensor(value, dtype=self.torch_dtype)
except RuntimeError as e:
if "Overflow" in str(e):
raise OverflowError(e.args)
10 changes: 5 additions & 5 deletions torch_np/_dtypes_impl.py
@@ -39,13 +39,13 @@ def can_cast_impl(from_torch_dtype, to_torch_dtype, casting):
return _cd._can_cast_dict[casting][from_torch_dtype][to_torch_dtype]


def result_type_impl(dtypes):
def result_type_impl(*tensors):
# NB: torch dtypes here
dtyp = dtypes[0]
if len(dtypes) == 1:
dtyp = tensors[0].dtype
if len(tensors) == 1:
return dtyp

for curr in dtypes[1:]:
dtyp = _cd._result_type_dict[dtyp][curr]
for curr in tensors[1:]:
dtyp = _cd._result_type_dict[dtyp][curr.dtype]

return dtyp
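
For reference, a minimal sketch of how a call site changes with the new signature (the tensors below are illustrative, not part of this diff):

import torch
from torch_np import _dtypes_impl

x = torch.ones(3, dtype=torch.int8)
y = torch.ones(3, dtype=torch.float32)

# before: _dtypes_impl.result_type_impl((x.dtype, y.dtype))
# after: pass the tensors themselves; their dtypes are read off inside
dt = _dtypes_impl.result_type_impl(x, y)  # a torch.dtype (float32 for this pair under NumPy-style promotion)
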
33 changes: 15 additions & 18 deletions torch_np/_funcs_impl.py
@@ -91,7 +91,7 @@ def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):
# figure out the type of the inputs and outputs
out_dtype = out.dtype.torch_dtype if dtype is None else dtype
else:
out_dtype = _dtypes_impl.result_type_impl([t.dtype for t in tensors])
out_dtype = _dtypes_impl.result_type_impl(*tensors)

# cast input arrays if necessary; do not broadcast them against `out`
tensors = _util.typecast_tensors(tensors, out_dtype, casting)
@@ -354,9 +354,11 @@ def arange(
# the dtype of the result
if dtype is None:
dtype = _dtypes_impl.default_dtypes.int_dtype
dt_list = [_util._coerce_to_tensor(x).dtype for x in (start, stop, step)]
dt_list.append(dtype)
target_dtype = _dtypes_impl.result_type_impl(dt_list)
# XXX: default values do not get normalized
start, stop, step = (_util._coerce_to_tensor(x) for x in (start, stop, step))

dummy = torch.empty(1, dtype=dtype)
target_dtype = _dtypes_impl.result_type_impl(start, stop, step, dummy)

# work around RuntimeError: "arange_cpu" not implemented for 'ComplexFloat'
work_dtype = torch.float64 if target_dtype.is_complex else target_dtype
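
The one-element dummy tensor above is how an explicit dtype argument now takes part in promotion, since result_type_impl only accepts tensors; a rough sketch with made-up values (not part of this diff):

import torch
from torch_np import _dtypes_impl, _util

start, stop, step = (_util._coerce_to_tensor(x) for x in (0, 5.0, 1))
dummy = torch.empty(1, dtype=torch.int64)  # carries the requested dtype into promotion
target_dtype = _dtypes_impl.result_type_impl(start, stop, step, dummy)
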
@@ -571,7 +573,7 @@ def cov(


def _conv_corr_impl(a, v, mode):
dt = _dtypes_impl.result_type_impl((a.dtype, v.dtype))
dt = _dtypes_impl.result_type_impl(a, v)
a = _util.cast_if_needed(a, dt)
v = _util.cast_if_needed(v, dt)

@@ -857,15 +859,14 @@ def nanpercentile():


def isclose(a: ArrayLike, b: ArrayLike, rtol=1.0e-5, atol=1.0e-8, equal_nan=False):
dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
dtype = _dtypes_impl.result_type_impl(a, b)
a = _util.cast_if_needed(a, dtype)
b = _util.cast_if_needed(b, dtype)
result = torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
return result
return torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)


def allclose(a: ArrayLike, b: ArrayLike, rtol=1e-05, atol=1e-08, equal_nan=False):
dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
dtype = _dtypes_impl.result_type_impl(a, b)
a = _util.cast_if_needed(a, dtype)
b = _util.cast_if_needed(b, dtype)
return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
@@ -1175,7 +1176,7 @@ def vdot(a: ArrayLike, b: ArrayLike, /):
if t_b.ndim > 1:
t_b = t_b.flatten()

dtype = _dtypes_impl.result_type_impl((t_a.dtype, t_b.dtype))
dtype = _dtypes_impl.result_type_impl(t_a, t_b)
is_half = dtype == torch.float16
is_bool = dtype == torch.bool

@@ -1202,15 +1203,15 @@ def tensordot(a: ArrayLike, b: ArrayLike, axes=2):
if isinstance(axes, (list, tuple)):
axes = [[ax] if isinstance(ax, int) else ax for ax in axes]

target_dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
target_dtype = _dtypes_impl.result_type_impl(a, b)
a = _util.cast_if_needed(a, target_dtype)
b = _util.cast_if_needed(b, target_dtype)

return torch.tensordot(a, b, dims=axes)


def dot(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):
dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
dtype = _dtypes_impl.result_type_impl(a, b)
a = _util.cast_if_needed(a, dtype)
b = _util.cast_if_needed(b, dtype)

@@ -1222,7 +1223,7 @@ def dot(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):


def inner(a: ArrayLike, b: ArrayLike, /):
dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
dtype = _dtypes_impl.result_type_impl(a, b)
is_half = dtype == torch.float16
is_bool = dtype == torch.bool

@@ -1284,11 +1285,7 @@ def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=
subscripts, array_operands = operands[0], operands[1:]

tensors = [normalize_array_like(op) for op in array_operands]
target_dtype = (
_dtypes_impl.result_type_impl([op.dtype for op in tensors])
if dtype is None
else dtype
)
target_dtype = _dtypes_impl.result_type_impl(*tensors) if dtype is None else dtype

# work around 'bmm' not implemented for 'Half' etc
is_half = target_dtype == torch.float16
13 changes: 8 additions & 5 deletions torch_np/_ndarray.py
@@ -518,11 +518,14 @@ def can_cast(from_, to, casting="safe"):


def result_type(*arrays_and_dtypes):
dtypes = []

tensors = []
for entry in arrays_and_dtypes:
dty = _extract_dtype(entry)
dtypes.append(dty.torch_dtype)
try:
t = asarray(entry).tensor
except (RuntimeError, ValueError, TypeError):
dty = _dtypes.dtype(entry)
t = torch.empty(1, dtype=dty.torch_dtype)
tensors.append(t)

torch_dtype = _dtypes_impl.result_type_impl(dtypes)
torch_dtype = _dtypes_impl.result_type_impl(*tensors)
return _dtypes.dtype(torch_dtype)
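
With this change, array-likes are routed through asarray(entry).tensor and bare dtype specifiers fall back to a one-element dummy tensor. A usage sketch, assuming result_type is re-exported at the package level as in NumPy:

import torch_np as tnp

# mix of an array, a Python scalar, and a dtype string
dt = tnp.result_type(tnp.asarray([1, 2, 3]), 1.0, "float32")
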
2 changes: 1 addition & 1 deletion torch_np/_reductions.py
@@ -342,7 +342,7 @@ def average(
weights = weights.swapaxes(-1, axis)

# do the work
result_dtype = _dtypes_impl.result_type_impl([a.dtype, weights.dtype])
result_dtype = _dtypes_impl.result_type_impl(a, weights)
numerator = sum(a * weights, axis, dtype=result_dtype)
wsum = sum(weights, axis, dtype=result_dtype)
result = numerator / wsum
11 changes: 5 additions & 6 deletions torch_np/_ufuncs.py
@@ -17,7 +17,7 @@

def _ufunc_preprocess(tensors, where, casting, order, dtype, subok, signature, extobj):
if dtype is None:
dtype = _dtypes_impl.result_type_impl([t.dtype for t in tensors])
dtype = _dtypes_impl.result_type_impl(*tensors)

tensors = _util.typecast_tensors(tensors, dtype, casting)

@@ -26,7 +26,7 @@ def _ufunc_postprocess(result, out, casting):

def _ufunc_postprocess(result, out, casting):
if out is not None:
(result,) = _util.typecast_tensors((result,), out.dtype.torch_dtype, casting)
result = _util.typecast_tensor(result, out.dtype.torch_dtype, casting)
result = torch.broadcast_to(result, out.shape)
return result

@@ -198,10 +198,9 @@ def wrapped(
signature=None,
extobj=None,
):
tensors = _ufunc_preprocess(
(x,), where, casting, order, dtype, subok, signature, extobj
)
result = torch_func(*tensors)
if dtype is not None:
x = _util.typecast_tensor(x, dtype, casting)
result = torch_func(x)
result = _ufunc_postprocess(result, out, casting)
return result

44 changes: 23 additions & 21 deletions torch_np/_util.py
@@ -135,37 +135,40 @@ def axis_none_flatten(*tensors, axis=None):
return tensors, axis


def typecast_tensors(tensors, target_dtype, casting):
"""Dtype-cast tensors to target_dtype.
def typecast_tensor(t, target_dtype, casting):
"""Dtype-cast tensor to target_dtype.
Parameters
----------
tensors : iterable
tuple or list of torch.Tensors to typecast
target_dtype : torch dtype object, optional
t : torch.Tensor
The tensor to cast
target_dtype : torch dtype object
The array dtype to cast all tensors to
casting : str
The casting mode, see `np.can_cast`
Returns
-------
a tuple of torch.Tensors with dtype being the PyTorch counterpart
of the `target_dtype`
Returns
-------
`torch.Tensor` of the `target_dtype` dtype
Raises
------
ValueError
if the argument cannot be cast according to the `casting` rule
"""
# check if we can dtype-cast all arguments
cast_tensors = []
can_cast = _dtypes_impl.can_cast_impl

for tensor in tensors:
if not can_cast(tensor.dtype, target_dtype, casting=casting):
raise TypeError(
f"Cannot cast array data from {tensor.dtype} to"
f" {target_dtype} according to the rule '{casting}'"
)
tensor = cast_if_needed(tensor, target_dtype)
cast_tensors.append(tensor)
if not can_cast(t.dtype, target_dtype, casting=casting):
raise TypeError(
f"Cannot cast array data from {t.dtype} to"
f" {target_dtype} according to the rule '{casting}'"
)
return cast_if_needed(t, target_dtype)


return tuple(cast_tensors)
def typecast_tensors(tensors, target_dtype, casting):
return tuple(typecast_tensor(t, target_dtype, casting) for t in tensors)
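
A small sketch of the single-tensor helper under a casting rule (example values only, not from this diff):

import torch
from torch_np import _util

t = torch.ones(3, dtype=torch.float64)
_util.typecast_tensor(t, torch.float32, "same_kind")  # allowed: float64 -> float32
_util.typecast_tensor(t, torch.int32, "same_kind")    # raises TypeError
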


def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):
@@ -193,7 +196,6 @@ def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):
"""
if isinstance(obj, torch.Tensor):
tensor = obj
base = None
else:
tensor = torch.as_tensor(obj)
base = None
2 changes: 1 addition & 1 deletion torch_np/linalg.py
@@ -21,7 +21,7 @@ def _atleast_float_1(a):


def _atleast_float_2(a, b):
dtyp = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
dtyp = _dtypes_impl.result_type_impl(a, b)
if not (dtyp.is_floating_point or dtyp.is_complex):
dtyp = _dtypes_impl.default_dtypes.float_dtype
