Skip to content

ENH: introduce NEP 50 "weak scalars" #140

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
May 19, 2023
Merged

ENH: introduce NEP 50 "weak scalars" #140

merged 17 commits into from
May 19, 2023

Conversation

ev-br
Copy link
Collaborator

@ev-br ev-br commented May 14, 2023

Make python scalars "weak" [1], meaning that in type promotion, they do not type promote arrays:

(np.int8(3) + 4).dtype == int8

  • Note that array scalars (np.int8(3) etc) are 0D arrays, so they are not weak.
  • Converting a weak scalar to an array (asarray(3) etc) makes it not weak.
  • Scalars are only weak in ufuncs. In places like np.dot([1, 2, 3], 4.0), the result is float64.

[1] https://numpy.org/neps/nep-0050-scalar-promotion.html

This is an alternative to gh-137, which it supersedes and closes.

NB: tests fail. The majority of failures are, I believe, in internals of `torch_np.testing` which are vendored from numpy. And that is incredibly messy, so they are due in for a face-lift. Am migrating them to rely on `torch.testing.assert_close`.

Make python scalars "weak": in type promotion, they do not type promote arrays:

   (np.int8(3) + 4).dtype == int8

- Note that array scalars (np.int8(3) etc) are 0D arrays, so they are not weak.
- Converting a weak scalar to an array (asarray(3) etc) makes it not weak.
- Scalars are only weak in ufuncs. In places like `np.dot([1, 2, 3], 4.0)`,
  the result is float64.
@ev-br ev-br requested a review from lezcano May 14, 2023 07:35
@ev-br ev-br marked this pull request as draft May 15, 2023 09:41
@ev-br
Copy link
Collaborator Author

ev-br commented May 15, 2023

Grrr, no. There's more to NEP 50. Converting to draft for now.

@ev-br ev-br marked this pull request as ready for review May 16, 2023 17:53
@ev-br
Copy link
Collaborator Author

ev-br commented May 16, 2023

CI's green, finally!
There is one more known failure, will need to special-case complex <op> float32 promotion path.
Will also add some more tests along the lines of https://gist.github.com/ev-br/27dee81aae8e24193db8082ee886f6e4 .
Other than these two, this is ready.

@ev-br ev-br force-pushed the weak_scalars_2 branch from bc6ebce to 5cacd4b Compare May 17, 2023 14:15
@ev-br ev-br force-pushed the weak_scalars_2 branch from 5cacd4b to 19e96c2 Compare May 17, 2023 14:17
@ev-br ev-br force-pushed the weak_scalars_2 branch from 897ce2d to cacdbdf Compare May 17, 2023 14:48
@ev-br ev-br force-pushed the weak_scalars_2 branch from cacdbdf to ad8752a Compare May 17, 2023 14:49
@ev-br
Copy link
Collaborator Author

ev-br commented May 17, 2023

Based on an off-line discussion with @lezcano , this PR now:

  • all NEP 50 handling is confined to binary ufuncs
  • we define a special annotation/normalization, ArrayLike | Scalar, where python numeric scalars do not get converted to tensors
  • this annotation is only used in binary ufuncs.

Copy link
Collaborator

@lezcano lezcano left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A 70 line implementation of NEP-50. How cool is that?


# detect uint overflow: in PyTorch, uint8(-1) wraps around to 255,
# while NEP50 mandates an exception.
if weak_.dtype == torch.uint8 and weak_.item() != weak:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this happen just for uint8 or for any int dtype?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, using .item() is not kosher. Let's do 0 <= weak < 2**8 before doing the as_tensor.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit different: checking the weak's value does not detect uint(100) + 200. However, numpy warns not raises, so we shouldn't raise either. As discussed, this PR now does what numpy does, sans RuntimeWarnings.

Copy link
Collaborator

@lezcano lezcano May 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, it's a bit annoying that this check is just done for ints. If it were done for all dtypes, we could create the tensors with torch.full, which does check if the number fits in the given type.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that is also a valid choice, numpy already gives a customizable warning anyway if you overflow to inf, so this just seemed the easier/OK thing. ints don't overflow graciously though...

@lezcano
Copy link
Collaborator

lezcano commented May 17, 2023

Perhaps @seberg could have a look at this one as well?

@seberg
Copy link

seberg commented May 17, 2023

The truly complicated stuff is if you have more than 2 operands, without that, this seems fine (no you cannot just do a reduce(two_op_result_type, values)).

I am not actually sure about that dot example, not that dot worries me, but overall the issue remains a bit that some functions effectively call np.asarray() on entry, and thus disable any chance of "weak" handling. And... I have not figured that out fully, we may want to move more functions to weak handling.

(OTOH, I do feel that beyond ufuncs the issue is probably small enough that missing a few functions isn't the end of the world.)

@ev-br ev-br force-pushed the weak_scalars_2 branch from 76bb388 to 2a5ac15 Compare May 18, 2023 09:33
@ev-br
Copy link
Collaborator Author

ev-br commented May 18, 2023

OK, two updates:

  • this all is needed in a subset of binary ufuncs only (add, multiply etc). Comparisons must be exempt (consider np.uint8(1) > -1), not sure about other things (np.hypot?). So guard it with an explicit whitelist
  • this interacts with set_default_dtype :-). Since NEP50 is numpy-specific, I'm just checking if the defaults are those from numpy, and bail out otherwise.

@ev-br
Copy link
Collaborator Author

ev-br commented May 19, 2023

This PR now includes @lezcano 's updates from gh-143

@ev-br ev-br merged commit 0d117e0 into main May 19, 2023
@ev-br ev-br deleted the weak_scalars_2 branch May 19, 2023 08:09
@ev-br ev-br mentioned this pull request May 19, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants