Implement forward AD for linalg.svd and improve svd_backward #70253
@@ -285,6 +285,9 @@ Tensor linalg_pinv(
     const optional<Tensor>& atol_opt,
     const optional<Tensor>& rtol_opt,
     bool hermitian) {
+  // FIXME: Whenever we have a nice lstsq, we should dispatch this function to simply be
+  // `torch.lstsq(A, torch.eye(A.shape[-1]), atol=atol, rtol=rtol)`
+  // with a driver that supports singular inputs
   NoTF32Guard disable_tf32;
   ScalarType t = input.scalar_type();
   TORCH_CHECK((t == ScalarType::Double || t == ScalarType::Float || t == ScalarType::ComplexFloat || t == ScalarType::ComplexDouble)
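For context, here is a minimal Python sketch of a pseudo-inverse computed through the SVD with a relative cutoff, which is essentially what `linalg_pinv` does today and what the `lstsq` dispatch in the FIXME would eventually replace. The helper name and the default `rtol` below are illustrative, not the actual ATen code:

```python
import torch

def pinv_via_svd(A, rtol=None):
    # Pseudo-inverse through the SVD: A = U diag(S) Vh, so
    # A+ = V diag(1/S) Uh, with singular values below a relative
    # cutoff treated as exactly zero. The cutoff is what keeps the
    # function well-defined on singular inputs.
    U, S, Vh = torch.linalg.svd(A, full_matrices=False)
    if rtol is None:
        # Illustrative default, mirroring the usual max(m, n) * eps choice.
        rtol = max(A.shape[-2:]) * torch.finfo(S.dtype).eps
    cutoff = rtol * S.max(dim=-1, keepdim=True).values
    S_inv = torch.where(S > cutoff, S.reciprocal(), torch.zeros_like(S))
    return Vh.transpose(-2, -1).conj() @ (S_inv.unsqueeze(-1) * U.transpose(-2, -1).conj())

A = torch.randn(3, 5, dtype=torch.float64)
print(torch.allclose(pinv_via_svd(A), torch.linalg.pinv(A)))
```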
@@ -2347,10 +2350,7 @@ Tensor& nuclear_norm_out(const Tensor& self, IntArrayRef dim, bool keepdim, Tens

   auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim());
   Tensor p = self.permute(permutation);
-  // NOTE: U and V are computed only if gradmode is enabled, since the backward for nuclear
-  // norm uses svd_backward, which requires them.
-  Tensor result_ = at::sum(std::get<1>(at::svd(p, /*some=*/true,
-    /*compute_uv=*/at::GradMode::is_enabled() && self.requires_grad())), -1, keepdim);
+  Tensor result_ = at::sum(at::linalg_svdvals(p), -1, keepdim);
   if (keepdim) {
     result_.unsqueeze_(-1);
     auto permutation_reverse = create_reverse_permutation(permutation);
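To illustrate the change: `torch.linalg.svdvals` is differentiable on its own, so the nuclear norm no longer needs to compute `U` and `V` eagerly just in case `svd_backward` asks for them later. A quick sketch; the gradient identity assumes distinct, nonzero singular values:

```python
import torch

A = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
# Nuclear norm = sum of singular values, computed without U and V.
torch.linalg.svdvals(A).sum().backward()

# At a matrix with distinct, nonzero singular values the gradient of
# the nuclear norm is U @ Vh (reduced SVD factors).
U, S, Vh = torch.linalg.svd(A.detach(), full_matrices=False)
print(torch.allclose(A.grad, U @ Vh))
```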
@@ -2417,7 +2417,7 @@ static Tensor& _linalg_norm_matrix_out(Tensor& result, const Tensor &self, const
   }

   if (std::abs(ord) == 2) {
-    // Need to shift the reduction dims to the back, because at::svd will only operate on
+    // Need to shift the reduction dims to the back, because at::linalg_svdvals will only operate on
     // the last 2 dimensions
     auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim());
     auto permutation_reverse = create_reverse_permutation(permutation);
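The `ord == ±2` matrix norms are just the extreme singular values, which is why `at::linalg_svdvals` is sufficient here. A small illustration through the public Python API:

```python
import torch

A = torch.randn(5, 3, dtype=torch.float64)
S = torch.linalg.svdvals(A)  # returned in descending order
# ord=2 is the largest singular value, ord=-2 the smallest.
print(torch.allclose(torch.linalg.matrix_norm(A, ord=2), S[0]))
print(torch.allclose(torch.linalg.matrix_norm(A, ord=-2), S[-1]))
```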
@@ -2732,7 +2732,7 @@ Tensor linalg_cond(const Tensor& self, const optional<Scalar>& opt_ord) {
   // If ord == None or ord == ±2
   if (std::abs(ord.toDouble()) == 2.0) {
-    auto singular_values = std::get<1>(at::svd(self));
+    auto singular_values = at::linalg_svdvals(self);

     // singular values are sorted in descending order
     auto s_max = at::narrow(singular_values, /*dim=*/-1, /*start=*/0, /*length=*/1);
     auto s_min = at::narrow(singular_values, /*dim=*/-1, /*start=*/-1, /*length=*/1);

Review comment: This diff is most likely the reason for the XLA tests failure. #71964
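Same idea for the condition number: `cond_2(A) = s_max / s_min`, with both values read off the two ends of the sorted output of `svdvals`. A small check:

```python
import torch

A = torch.randn(4, 4, dtype=torch.float64)
S = torch.linalg.svdvals(A)  # descending order
# 2-norm condition number: largest over smallest singular value.
print(torch.allclose(torch.linalg.cond(A, p=2), S[0] / S[-1]))
```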
@@ -1467,7 +1467,15 @@

 # We never call _linalg_svd with compute_uv=False in an autograd context, so we don't even consider it here
 - name: _linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor Vh)
-  A: svd_backward(grads, full_matrices, U, S, Vh)
+  A: "svd_backward(full_matrices && grad_U.defined() ? grad_U.narrow(-1, 0, S.size(-1)) : grad_U,
+                   grad_S,
+                   full_matrices && grad_Vh.defined() ? grad_Vh.narrow(-2, 0, S.size(-1)) : grad_Vh,
+                   full_matrices ? U.narrow(-1, 0, S.size(-1)) : U,
+                   S,
+                   full_matrices ? Vh.narrow(-2, 0, S.size(-1)) : Vh)"
+  U: std::get<0>(linalg_svd_jvp(A_t, U, S, Vh, full_matrices))
+  S: std::get<1>(linalg_svd_jvp(A_t, U, S, Vh, full_matrices))
+  Vh: std::get<2>(linalg_svd_jvp(A_t, U, S, Vh, full_matrices))

 - name: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors)
   self: linalg_eig_backward(grads[0], grads[1], eigenvalues, eigenvectors_return, /*is_hermitian=*/true, /*symeig_eigenvector=*/eigenvectors)

Review comment on lines +1472 to +1475: the narrowing is no longer done in the backward?

Reply: I moved it here to remove the …
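With this entry in place, forward-mode AD through `torch.linalg.svd` is driven by `linalg_svd_jvp`, and the `narrow` calls above hand `svd_backward` the reduced factors even when `full_matrices=True` (the derivative formulas only involve the first `S.size(-1)` columns of `U` and rows of `Vh`). A sketch of what this enables, checking the tangent of `S` against the analytic formula `dS = diag(U^T dA V)`; valid for real inputs with distinct singular values, and it assumes a PyTorch build that includes this PR:

```python
import torch
import torch.autograd.forward_ad as fwAD

A = torch.randn(4, 4, dtype=torch.float64)
dA = torch.randn(4, 4, dtype=torch.float64)  # tangent direction

with fwAD.dual_level():
    A_dual = fwAD.make_dual(A, dA)
    # Forward AD through the SVD, dispatching to linalg_svd_jvp.
    U, S, Vh = torch.linalg.svd(A_dual)
    dS = fwAD.unpack_dual(S).tangent

# Analytic JVP of the singular values: dS = diag(U^T dA V).
U0, S0, Vh0 = torch.linalg.svd(A)
V0 = Vh0.transpose(-2, -1)
print(torch.allclose(dS, torch.diagonal(U0.transpose(-2, -1) @ dA @ V0)))
```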