Implement forward AD for linalg.svd and improve svd_backward #70253
Conversation
I included a derivation of the formula in the complex case, as it is particularly tricky. As far as I know, this is the first time this formula has been derived in the literature. I also implemented a more efficient and more accurate version of svd_backward. More importantly, I added a lax check in the complex case making sure the loss function depends only on the subspaces spanned by the pairs of singular vectors, and not on their joint phase.
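To illustrate the phase ambiguity this check guards against, here is a small NumPy sketch (my own illustration, not code from this PR): rotating each pair of singular vectors by a joint phase yields another valid SVD of the same matrix, so only losses invariant to that joint phase have a well-defined gradient.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3)) + 1j * rng.standard_normal((3, 3))

U, s, Vh = np.linalg.svd(A)
V = Vh.conj().T

# Rotate each singular pair (u_i, v_i) by a joint phase e^{i t_i}.
D = np.diag(np.exp(1j * rng.uniform(0.0, 2.0 * np.pi, size=3)))
U2, V2 = U @ D, V @ D

# (U2, s, V2) is another valid SVD of A, since D @ diag(s) @ D^H == diag(s).
assert np.allclose(U2 @ np.diag(s) @ V2.conj().T, A)

# A loss built from each pair jointly is phase-invariant, e.g. u_i^H B v_i
# (the phases e^{-i t_i} and e^{i t_i} cancel):
B = rng.standard_normal((3, 3))
assert np.allclose(np.diag(U.conj().T @ B @ V),
                   np.diag(U2.conj().T @ B @ V2))

# But a loss that sees the phase of U alone is not well defined:
assert not np.allclose(np.real(U), np.real(U2))
```

In other words, the factorization is unique only up to these joint phases, which is exactly why the backward needs either the invariance assumption or a check that the incoming gradients respect it.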
💊 CI failures summary (as of commit aa7ce25, more details on the Dr. CI page): 🕵️ 5 new failures recognized by patterns; these CI failures do not appear to be due to upstream breakages.
This PR adds checks for the backward of `linalg.eig`, similar to those deduced in #70253. It also modifies the function so that it does not save the input matrix, as that is not necessary, and fixes the forward AD formula. Now all the tests pass for `linalg.eig` and `linalg.eigvals`. It also updates the docs to better reflect what is going on here.
I have a general question, @lezcano, @IvanYashchuk, @mruberry, @ngimel. Do we want to keep the check for the invariance in the backward? That seems like a significant perf penalty because of the device sync. Maybe it is better to mention this invariance in the documentation? Granted, checking it from the user interface is problematic, but then the complex case is quite non-trivial, so maybe the user is already aware of potential non-uniqueness issues... Or maybe we can introduce an additional input parameter that controls the check, but I do not know whether that is a good solution.
I chose this design because, as you said, this function is already rather difficult, so whatever we can do to improve the UX will be welcomed by users. Note that JAX (and by extension TF) wanted to do this, but they did not know how. See this comment jax-ml/jax#2748 (comment) in particular, and the rest of that issue / linked issues for context. Also, it is true that this synchronises, but:
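For reference, the invariance condition being discussed can be sketched as follows (my own summary; the exact formulation in the implementation may differ). If the loss satisfies $L(UD_\theta, \Sigma, VD_\theta) = L(U, \Sigma, V)$ for every $D_\theta = \operatorname{diag}(e^{i\theta_1}, \dots, e^{i\theta_k})$, then differentiating at $\theta = 0$ yields a condition on the incoming gradients $\bar U, \bar V$ that can be checked cheaply:

$$
\operatorname{Im}\bigl(\operatorname{diag}(U^{H}\bar U + V^{H}\bar V)\bigr) = 0,
$$

since $\partial U / \partial \theta_i = i\,u_i e_i^{T}$ and $\partial V / \partial \theta_i = i\,v_i e_i^{T}$, so the directional derivative of $L$ along the joint phase rotation is $-\operatorname{Im}\bigl((U^{H}\bar U + V^{H}\bar V)_{ii}\bigr)$.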
LGTM! Thank you, Mario! Your call, @mruberry!
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Failures seem unrelated.
Summary: Pull Request resolved: #70253 (description as above). Test Plan: Imported from OSS. Reviewed By: mikaylagawarecki. Differential Revision: D33751982. Pulled By: mruberry. fbshipit-source-id: c2a4a92a921a732357e99c01ccb563813b1af512 (cherry picked from commit 391319e)
@@ -2732,7 +2732,7 @@ Tensor linalg_cond(const Tensor& self, const optional<Scalar>& opt_ord) {
   // If ord == None or ord == ±2
   if (std::abs(ord.toDouble()) == 2.0) {
-    auto singular_values = std::get<1>(at::svd(self));
+    auto singular_values = at::linalg_svdvals(self);
This diff is most likely the reason for the XLA test failures. #71964
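For context on the change in the diff above: `linalg_svdvals` computes singular values without forming the singular vectors, which is all the 2-norm condition number needs. A rough NumPy analogue of that code path (my own illustration, not the ATen code):

```python
import numpy as np

def cond_2(A):
    # 2-norm condition number: largest over smallest singular value.
    s = np.linalg.svd(A, compute_uv=False)  # values only, like linalg_svdvals
    return s.max() / s.min()

# For a diagonal matrix, the singular values are the absolute diagonal entries.
A = np.diag([4.0, 2.0, 1.0])
assert abs(cond_2(A) - 4.0) < 1e-12
```

Skipping the vectors avoids an unnecessary (and more expensive) full decomposition, and also sidesteps the differentiability subtleties of the vectors themselves.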
Hi, I just stumbled upon this. I'm not sure if this helps, but I implemented the JVP rule for the complex-valued SVD here: jax-ml/jax#5225, and we showed how the derivation works in our paper (https://arxiv.org/pdf/2209.14328.pdf), appendix D.
I have a half-written paper where I discuss how to formalise all this with principal bundles and the like, but I never found the time to finish writing it :(
Nice. I would be curious to see it when it's done.
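As a concrete check of the first-order behaviour discussed in this thread: for a matrix with simple (distinct) singular values, the differential of the singular values is $d\sigma_i = \operatorname{Re}(u_i^{H}\,dA\,v_i)$, independent of the phase choice. A NumPy sketch verifying this against central finite differences (my own illustration; it assumes distinct singular values, which holds generically for a random matrix):

```python
import numpy as np

rng = np.random.default_rng(1)
n = 4
A = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))
dA = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))

U, s, Vh = np.linalg.svd(A)
# Analytic JVP of the singular values: ds_i = Re(u_i^H dA v_i),
# i.e. the real part of the diagonal of U^H dA V.
ds = np.real(np.diag(U.conj().T @ dA @ Vh.conj().T))

# Central finite differences on the singular values.
eps = 1e-6
svals = lambda M: np.linalg.svd(M, compute_uv=False)
ds_fd = (svals(A + eps * dA) - svals(A - eps * dA)) / (2 * eps)

assert np.allclose(ds, ds_fd, atol=1e-5)
```

Note that the joint phases of $(u_i, v_i)$ cancel in $u_i^{H}\,dA\,v_i$, which is why the singular values (unlike the vectors) have an unambiguous forward derivative.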
Stack from ghstack:
cc @jianyuh @nikitaved @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @lezcano
Differential Revision: D33751982