-
Notifications
You must be signed in to change notification settings - Fork 53
Add matrix_power specification #112
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For CuPy, only
float32
andfloat64
is allowed. In general, we restrict linalg functions to floating-point dtypes. That practice is followed here.
Only specifying float32/64 sounds right, however the phrasing is a bit off. I'd rather add a note saying that the behaviour is only specified for these dtypes, and that for other dtypes implementation may differ - they will either cast or raise for some or all values of n
.
Re: dtypes. We'd need to add that note more generally at the top of the linalg spec, as this is generally true for most linalg specs, given their current support being "limited" to floating-point dtypes. The restriction to floating-point dtypes was discussed a couple times during meetings and found its support there. |
@rgommers re: dtypes. We could change the language from In this particular case, I would, however, advocate for always returning a floating-point dtype. |
0607525
to
138e963
Compare
This PR has been open for some time without further comment and has been discussed/approved during meetings. Will merge, and we can submit follow-up PRs to resolve any issues/concerns which may arise. |
For reference, inclusion of this API in the standard was discussed here. |
This PR
Notes
The output array dtype for NumPy depends on both the input array and whether
n
is negative. If negative, then the result is floating-point. If non-negative, the result has the same dtype as the input. For CuPy, onlyfloat32
andfloat64
is allowed. In general, we recommend restricting linalg functions to floating-point dtypes. That practice is followed here.Following Torch and maintaining consistency with other linalg interfaces operating in square matrices, a stack of square matrices is accepted.