-
Notifications
You must be signed in to change notification settings - Fork 4
ENH: introduce NEP 50 "weak scalars" #140
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
aa30300
b6246a0
b4e5f73
de0a611
299c047
6bcd353
3656962
19e96c2
4429b2f
ad8752a
863ef4c
22130af
b7f035f
61322fd
2a5ac15
8dbe000
2d1636e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -49,3 +49,80 @@ def result_type_impl(*tensors): | |
dtyp = _cd._result_type_dict[dtyp][curr.dtype] | ||
|
||
return dtyp | ||
|
||
|
||
# ### NEP 50 helpers ### | ||
|
||
SCALAR_TYPES = (int, bool, float, complex) | ||
|
||
|
||
def _dtype_for_scalar(py_type): | ||
return { | ||
bool: torch.bool, | ||
int: torch.int64, | ||
float: torch.float64, | ||
complex: torch.complex128, | ||
}[py_type] | ||
|
||
|
||
categories = [ | ||
(torch.bool,), | ||
(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), | ||
(torch.float16, torch.float32, torch.float64), | ||
(torch.complex64, torch.complex128), | ||
] | ||
|
||
|
||
def category(dtyp): | ||
for j, cat in enumerate(categories): | ||
if dtyp in cat: | ||
return j | ||
raise ValueError(f"unknown dtype {dtyp}") | ||
|
||
|
||
dtype_for_cat = {0: torch.bool, 1: torch.int64, 2: torch.float64, 3: torch.complex128} | ||
|
||
|
||
def nep50_to_tensors(x1, x2): | ||
"""If either of inputs is a python scalar, type-promote with NEP 50. | ||
|
||
NB: NEP 50 mandates RuntimeWarnings on some overflows. We do not emit them: | ||
we either raise OverflowError or just do the computation. | ||
""" | ||
|
||
x1_type, x2_type = type(x1), type(x2) | ||
x1_is_weak = x1_type in SCALAR_TYPES | ||
x2_is_weak = x2_type in SCALAR_TYPES | ||
if x1_is_weak and x2_is_weak: | ||
# two scalars: promote | ||
x1 = torch.as_tensor(x1, dtype=_dtype_for_scalar(x1_type)) | ||
x2 = torch.as_tensor(x2, dtype=_dtype_for_scalar(x2_type)) | ||
return x1, x2 | ||
elif not (x1_is_weak or x2_is_weak): | ||
# two tensors: nothing to do here | ||
return x1, x2 | ||
else: | ||
# scalar <op> scalar: NEP 50 | ||
weak, not_weak = (x1, x2) if x1_is_weak else (x2, x1) | ||
|
||
# find the dtype for the weak's type | ||
weak_dtype = _dtype_for_scalar(type(weak)) | ||
|
||
cat_weak = category(weak_dtype) | ||
cat_not_weak = category(not_weak.dtype) | ||
|
||
dt = not_weak.dtype if cat_weak <= cat_not_weak else dtype_for_cat[cat_weak] | ||
|
||
# special-case complex + float32 | ||
if weak_dtype.is_complex and not_weak.dtype == torch.float32: | ||
dt = torch.complex64 | ||
|
||
# finally, can cast make `weak` into a 0D tensor | ||
weak_ = torch.as_tensor(weak, dtype=dt) | ||
|
||
# detect uint overflow: in PyTorch, uint8(-1) wraps around to 255, | ||
# while NEP50 mandates an exception. | ||
if weak_.dtype == torch.uint8 and weak_.item() != weak: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this happen just for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, using .item() is not kosher. Let's do 0 <= weak < 2**8 before doing the as_tensor. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a bit different: checking the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, it's a bit annoying that this check is just done for ints. If it were done for all dtypes, we could create the tensors with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that is also a valid choice, numpy already gives a customizable warning anyway if you overflow to |
||
raise OverflowError(f"Python integer {weak} out of bounds for {weak_.dtype}") | ||
|
||
return (weak_, not_weak) if x1_is_weak else (not_weak, weak_) |
Uh oh!
There was an error while loading. Please reload this page.