-
Notifications
You must be signed in to change notification settings - Fork 53
Add specification for returning the least-squares solution to a linear matrix equation (linalg: lstsq) #119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The JAX docs say that behaviour matches NumPy, empty array can be returned: jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.lstsq.html. The in-progress PR for |
@rgommers Re: JAX. Sorry, I should have clarified. JAX's default behavior does not match NumPy's.
I've updated the OP accordingly. |
Renamed |
Renamed |
The PR looks fine to me, just a few high-level design questions:
I feel it's not very convenient to always request a matrix and forbid vector inputs. I understand we can always broadcast if x1.ndim == x2.ndim + 1:
x2 = x2[..., None] # or use newaxis
assert x1.ndim == x2.ndim
No, NumPy and CuPy return a tuple. Given that we didn't return namedtuple in SVD, perhaps we shouldn't do it here either to be consistent? |
Ah OK, thanks Athan! I missed that and thought |
@leofang Re: matrix/vector input. I've updated the proposal to include support for an ordinate vector. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks @kgryte!
0607525
to
138e963
Compare
Thanks, @leofang, for the review! This PR is ready for merge... |
This PR
Notes
Only TF allows for providing a stack of matrices. Torch, MXNet, CuPy, NumPy, and JAX do not. This proposal follows TF and ensures consistency with other linalg interfaces which currently support stacks.
TF supports
l2_regularizer
andfast
keyword arguments and is alone in doing so.Neither Dask, Torch, nor TF support an
rcond
keyword argument. This proposal includes anrtol
argument (note:rtol
is renamed fromrcond
to unify tolerance keywords acrosspinv
,lstsq
, andmatrix_rank
), similar to the pinv proposal.Similar to pinv, the
rcond
argument can either be afloat
or anarray
and have default values determined by type promotion rules.NumPy, MXNet, CuPy, and JAX all support
b
being specified as either a vector or matrix. TF requires an(..., M,K)
matrix. This PR follows NumPy.Return results:
rank
field which is an array.rank
field which is an integer.residuals
field which is empty for low-rank or over-determined solutions. JAX always returns residuals for JIT purposes, unless one setsnumpy_resid=True
.This proposal returns a namedtuple with a
rank
field which is an array due to support for providing stacks of matrices and also returns that theresiduals
field always be returned, following JAX.