-
Notifications
You must be signed in to change notification settings - Fork 53
Add linalg solve function specification #115
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
https://pytorch.org/docs/master/linalg.html#torch.linalg.solve is the function to look at, it matches NumPy. |
@rgommers Yeah, I've been basing everything on the latest |
I wouldn't wait, if the PyTorch team has already decided to switch to numpy-compatible APIs, we shouldn't use the legacy API to make new decisions. |
I can take a closer look at CuPy tonight. I skimmed over the source code and am under the impression that it supports batches. |
Sorry I dropped the ball. Batch solving seems to work on CuPy's master branch: >>> import cupy as cp
>>> a = cp.random.random((16, 256, 256))
>>> b = cp.random.random((16, 256))
>>> out = cp.linalg.solve(a, b)
>>> out.shape
(16, 256)
>>> out
array([[-7.54122463e-01, 5.12721191e-01, 1.99352410e-01, ...,
-6.23679757e-01, 1.58717630e+00, 1.20552063e+00],
[ 1.40216909e+01, 1.60826566e+01, 5.13492282e+00, ...,
7.86564501e+00, -9.89509721e-01, -6.26412570e+00],
[ 3.68326684e+00, -2.00001553e+01, -2.38530695e+00, ...,
-4.56784107e+00, 9.60668021e+00, -5.74262514e+00],
...,
[-4.24311695e-03, -1.63314378e+00, -7.87817536e-01, ...,
-4.04266646e-02, -2.67829118e-01, -1.10735434e-01],
[ 2.32818737e-02, -8.14974058e-01, -8.93238457e+00, ...,
-7.73627020e+00, 5.69240520e-02, 1.83340313e+00],
[-1.02653492e+00, -3.52034898e-01, 9.79629855e-01, ...,
-7.32350737e-01, 3.48412196e-01, -1.21837283e+00]]) |
@leofang Thanks for checking! I've updated the OP to indicate that CuPy supports providing stacks. |
@kgryte btw can we clarify in the API description what |
@leofang I believe this concerns support for solving for multiple columns of ordinate values, rather than just a vector. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
0607525
to
138e963
Compare
Thanks, @leofang, for the review! This PR is ready for merge... |
This PR
Notes
Following Torch, MXNet, TF, NumPy, CuPy, and JAX, this proposal allows for providing a stack of square matrices. Dask does not currently support providing stacks.
Torch (1.7) swaps
A
andB
array arguments. This PR follows NumPy and others.TF supports a boolean flag indicating whether to solve with the adjoint.
Dask, similar to SciPy, supports a boolean flag to indicate whether one can assume the coefficient matrix is symmetric and positive definite.
Torch (1.7) returns the LU factorization of the coefficient matrix.
NumPy allows providing
B
having shape(..., M,)
(i.e., supports solving a system of linear scalar equations), while Torch (1.7) and TF require(..., M, K)
. This PR follows NumPy.