Skip to content

Add specification for returning the least-squares solution to a linear matrix equation (linalg: lstsq) #119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
May 12, 2021

Conversation

kgryte
Copy link
Contributor

@kgryte kgryte commented Jan 25, 2021

This PR

  • specifies the interface for returning the least-squares solution to a linear matrix equation.
  • is derived from comparing signatures across array libraries.

Notes

  • Only TF allows for providing a stack of matrices. Torch, MXNet, CuPy, NumPy, and JAX do not. This proposal follows TF and ensures consistency with other linalg interfaces which currently support stacks.

  • TF supports l2_regularizer and fast keyword arguments and is alone in doing so.

  • Neither Dask, Torch, nor TF support an rcond keyword argument. This proposal includes an rtol argument (note: rtol is renamed from rcond to unify tolerance keywords across pinv, lstsq, and matrix_rank), similar to the pinv proposal.

  • Similar to pinv, the rcond argument can either be a float or an array and have default values determined by type promotion rules.

  • NumPy, MXNet, CuPy, and JAX all support b being specified as either a vector or matrix. TF requires an (..., M,K) matrix. This PR follows NumPy.

  • Return results:

    • TF only returns an array containing solutions.
    • Torch returns a namedtuple of solutions and the QR factorization.
    • NumPy et al return a tuple.
    • Dask returns a tuple with a rank field which is an array.
    • NumPy et al return a rank field which is an integer.
    • NumPy returns a residuals field which is empty for low-rank or over-determined solutions. JAX always returns residuals for JIT purposes, unless one sets numpy_resid=True.

    This proposal returns a namedtuple with a rank field which is an array due to support for providing stacks of matrices and also returns that the residuals field always be returned, following JAX.

@rgommers
Copy link
Member

NumPy returns a residuals field which is empty for low-rank or over-determined solutions. JAX always returns residuals for JIT purposes.

The JAX docs say that behaviour matches NumPy, empty array can be returned: jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.lstsq.html. The in-progress PR for torch.linalg.lstsq will also match NumPy mostly - returning empty residuals probably (pytorch/pytorch#49093 (comment))

@kgryte
Copy link
Contributor Author

kgryte commented Jan 26, 2021

@rgommers Re: JAX. Sorry, I should have clarified. JAX's default behavior does not match NumPy's.

LAX-backend implementation of lstsq(). It has two important differences:

In numpy.linalg.lstsq, the default rcond is -1, and warns that in the future the default will be None. Here, the default rcond is None.

In np.linalg.lstsq the returned residuals are empty for low-rank or over-determined solutions. Here, the residuals are returned in all cases, to make the function compatible with jit. The non-jit compatible numpy behavior can be recovered by passing numpy_resid=True.

I've updated the OP accordingly.

@kgryte
Copy link
Contributor Author

kgryte commented Feb 16, 2021

Renamed rcond to tol to unify keyword arguments across pinv, lstsq, and matrix_rank APIs.

@kgryte
Copy link
Contributor Author

kgryte commented Mar 4, 2021

Renamed tol to rtol to more explicitly indicate relative tolerance and pave the way for future specification evolution (e.g., atol).

@leofang
Copy link
Contributor

leofang commented Mar 11, 2021

The PR looks fine to me, just a few high-level design questions:

  • NumPy, MXNet, CuPy, and JAX all support b being specified as either a vector or matrix. TF requires an (..., M,K) matrix. This PR follows TF to reduce API surface area and ensure consistency across all invocations.

I feel it's not very convenient to always request a matrix and forbid vector inputs. I understand we can always broadcast (..., M) to (..., M, 1) to make it work, but can't we do this internally (likely done in several libraries) to give users a bit more flexibility? For example, we can do pre-processing like this:

if x1.ndim == x2.ndim + 1:
    x2 = x2[..., None]  # or use newaxis
assert x1.ndim == x2.ndim 

NumPy et al return a namedtuple.

No, NumPy and CuPy return a tuple. Given that we didn't return namedtuple in SVD, perhaps we shouldn't do it here either to be consistent?

@kgryte
Copy link
Contributor Author

kgryte commented Mar 11, 2021

@leofang Re: namedtuple and NumPy. You are correct. I misread the NumPy docs. I updated the OP. However, for the SVD proposal, we do return a namedtuple (see here). So returning one here is consistent with that proposal.

@leofang
Copy link
Contributor

leofang commented Mar 11, 2021

However, for the SVD proposal, we do return a namedtuple (see here).

Ah OK, thanks Athan! I missed that and thought _Tuple\[ ... refers to (unnamed) tuple.

@rgommers rgommers added the API extension Adds new functions or objects to the API. label Mar 20, 2021
@kgryte
Copy link
Contributor Author

kgryte commented Mar 24, 2021

@leofang Re: matrix/vector input. I've updated the proposal to include support for an ordinate vector.

Copy link
Contributor

@leofang leofang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks @kgryte!

@rgommers rgommers force-pushed the main branch 2 times, most recently from 0607525 to 138e963 Compare April 19, 2021 20:25
@kgryte
Copy link
Contributor Author

kgryte commented May 12, 2021

Thanks, @leofang, for the review! This PR is ready for merge...

@kgryte kgryte merged commit e32a6a8 into main May 12, 2021
@kgryte kgryte deleted the lstsq branch May 12, 2021 04:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
API extension Adds new functions or objects to the API.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants