Skip to content

Add linalg solve function specification #115

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 12, 2021
Merged

Add linalg solve function specification #115

merged 12 commits into from
May 12, 2021

Conversation

kgryte
Copy link
Contributor

@kgryte kgryte commented Jan 19, 2021

This PR

  • specifies the interface for performing solving a system of linear equations.
  • is derived from comparing signatures across array libraries.

Notes

  • Following Torch, MXNet, TF, NumPy, CuPy, and JAX, this proposal allows for providing a stack of square matrices. Dask does not currently support providing stacks.

  • Torch (1.7) swaps A and B array arguments. This PR follows NumPy and others.

  • TF supports a boolean flag indicating whether to solve with the adjoint.

  • Dask, similar to SciPy, supports a boolean flag to indicate whether one can assume the coefficient matrix is symmetric and positive definite.

  • Torch (1.7) returns the LU factorization of the coefficient matrix.

  • NumPy allows providing B having shape (..., M,) (i.e., supports solving a system of linear scalar equations), while Torch (1.7) and TF require (..., M, K). This PR follows NumPy.

@rgommers
Copy link
Member

Torch currently swaps A and B array arguments. This PR follows NumPy and others.

Torch returns the LU factorization of the coefficient matrix.

NumPy allows providing B having shape (..., M,) (i.e., supports solving a system of linear scalar equations), while Torch

https://pytorch.org/docs/master/linalg.html#torch.linalg.solve is the function to look at, it matches NumPy. torch.solve is legacy.

@kgryte
Copy link
Contributor Author

kgryte commented Jan 20, 2021

@rgommers Yeah, I've been basing everything on the latest stable Torch version 1.7.1. Can update once the work in master is stable, mainly as a matter of process.

@rgommers
Copy link
Member

I wouldn't wait, if the PyTorch team has already decided to switch to numpy-compatible APIs, we shouldn't use the legacy API to make new decisions.

@leofang
Copy link
Contributor

leofang commented Jan 21, 2021

I can take a closer look at CuPy tonight. I skimmed over the source code and am under the impression that it supports batches.

@leofang
Copy link
Contributor

leofang commented Jan 27, 2021

I can take a closer look at CuPy tonight. I skimmed over the source code and am under the impression that it supports batches.

Sorry I dropped the ball. Batch solving seems to work on CuPy's master branch:

>>> import cupy as cp
>>> a = cp.random.random((16, 256, 256))
>>> b = cp.random.random((16, 256))
>>> out = cp.linalg.solve(a, b)
>>> out.shape
(16, 256)
>>> out
array([[-7.54122463e-01,  5.12721191e-01,  1.99352410e-01, ...,
        -6.23679757e-01,  1.58717630e+00,  1.20552063e+00],
       [ 1.40216909e+01,  1.60826566e+01,  5.13492282e+00, ...,
         7.86564501e+00, -9.89509721e-01, -6.26412570e+00],
       [ 3.68326684e+00, -2.00001553e+01, -2.38530695e+00, ...,
        -4.56784107e+00,  9.60668021e+00, -5.74262514e+00],
       ...,
       [-4.24311695e-03, -1.63314378e+00, -7.87817536e-01, ...,
        -4.04266646e-02, -2.67829118e-01, -1.10735434e-01],
       [ 2.32818737e-02, -8.14974058e-01, -8.93238457e+00, ...,
        -7.73627020e+00,  5.69240520e-02,  1.83340313e+00],
       [-1.02653492e+00, -3.52034898e-01,  9.79629855e-01, ...,
        -7.32350737e-01,  3.48412196e-01, -1.21837283e+00]])

@kgryte
Copy link
Contributor Author

kgryte commented Jan 28, 2021

@leofang Thanks for checking! I've updated the OP to indicate that CuPy supports providing stacks.

@leofang
Copy link
Contributor

leofang commented Jan 28, 2021

@kgryte btw can we clarify in the API description what K in (..., M, K) stands for? The NumPy documentation, for example, is unclear on this...

@kgryte
Copy link
Contributor Author

kgryte commented Jan 28, 2021

@leofang I believe this concerns support for solving for multiple columns of ordinate values, rather than just a vector.

@rgommers rgommers added the API extension Adds new functions or objects to the API. label Mar 20, 2021
Copy link
Contributor

@leofang leofang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@rgommers rgommers force-pushed the main branch 3 times, most recently from 0607525 to 138e963 Compare April 19, 2021 20:25
@kgryte
Copy link
Contributor Author

kgryte commented May 12, 2021

Thanks, @leofang, for the review! This PR is ready for merge...

@kgryte kgryte merged commit 5677c24 into main May 12, 2021
@kgryte kgryte deleted the solve branch May 12, 2021 04:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
API extension Adds new functions or objects to the API.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants