
Implement forward AD for linalg.svd and improve svd_backward #70253

Closed · wants to merge 32 commits

Commits (32)
94deba4  Implement forward AD for linalg.svd and improve svd_backward (lezcano, Dec 21, 2021)
e4b52df  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Dec 22, 2021)
0550a98  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Dec 22, 2021)
6d7f709  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Dec 29, 2021)
3c7dd72  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Dec 29, 2021)
060a0fc  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Dec 30, 2021)
049df7b  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Dec 31, 2021)
cd66d46  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Jan 3, 2022)
e526fc0  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Jan 7, 2022)
5a98b31  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Jan 7, 2022)
2f287f4  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Jan 7, 2022)
f3986d2  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Jan 11, 2022)
8b31912  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Jan 11, 2022)
248637e  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Jan 11, 2022)
a959c7e  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Jan 11, 2022)
7438204  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Jan 12, 2022)
e829c0c  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Jan 13, 2022)
059816f  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Jan 13, 2022)
0372118  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Jan 13, 2022)
22e11f2  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Jan 14, 2022)
43a5b64  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Jan 18, 2022)
82c4fa6  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Jan 18, 2022)
c9b0a3f  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Jan 18, 2022)
6e6d11b  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Jan 19, 2022)
61adad6  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Jan 24, 2022)
8a3b611  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Jan 24, 2022)
c07d4a9  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Jan 24, 2022)
3fb91cd  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Jan 24, 2022)
3c43821  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Jan 25, 2022)
515e31e  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Jan 25, 2022)
181fd5f  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Jan 25, 2022)
aa7ce25  Update on "Implement forward AD for linalg.svd and improve svd_backward" (lezcano, Jan 26, 2022)
12 changes: 6 additions & 6 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -285,6 +285,9 @@ Tensor linalg_pinv(
const optional<Tensor>& atol_opt,
const optional<Tensor>& rtol_opt,
bool hermitian) {
+ // FIXME: Whenever we have a nice lstsq, we should dispatch this function to simply be
+ // `torch.lstsq(A, torch.eye(A.shape[-1]), atol=atol, rtol=rtol)`
+ // with a driver that supports singular inputs
NoTF32Guard disable_tf32;
ScalarType t = input.scalar_type();
TORCH_CHECK((t == ScalarType::Double || t == ScalarType::Float || t == ScalarType::ComplexFloat || t == ScalarType::ComplexDouble)
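
Aside (not part of the diff): the FIXME above amounts to computing the pseudo-inverse as the minimum-norm least-squares solution of A X = I. A minimal sketch of that idea with today's API, assuming a 2-D CPU input and the rank-revealing "gelsd" driver (torch.linalg.lstsq does not yet expose the atol/rtol the comment asks for):

```python
import torch

def pinv_via_lstsq(A, rcond=None):
    # Illustrative only: pinv(A) is the minimum-norm least-squares solution of A X = I.
    # 2-D CPU sketch; the "gelsd" driver handles rank-deficient inputs.
    eye = torch.eye(A.shape[-2], dtype=A.dtype, device=A.device)
    return torch.linalg.lstsq(A, eye, rcond=rcond, driver="gelsd").solution

A = torch.randn(5, 3, dtype=torch.float64)
torch.testing.assert_close(pinv_via_lstsq(A), torch.linalg.pinv(A))
```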
@@ -2347,10 +2350,7 @@ Tensor& nuclear_norm_out(const Tensor& self, IntArrayRef dim, bool keepdim, Tens

auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim());
Tensor p = self.permute(permutation);
- // NOTE: U and V are computed only if gradmode is enabled, since the backward for nuclear
- // norm uses svd_backward, which requires them.
- Tensor result_ = at::sum(std::get<1>(at::svd(p, /*some=*/true,
-     /*compute_uv=*/at::GradMode::is_enabled() && self.requires_grad())), -1, keepdim);
+ Tensor result_ = at::sum(at::linalg_svdvals(p), -1, keepdim);
if (keepdim) {
result_.unsqueeze_(-1);
auto permutation_reverse = create_reverse_permutation(permutation);
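
Aside (not part of the diff): the one-line replacement relies on the identity that the nuclear norm of a matrix is the sum of its singular values, so only linalg_svdvals is needed here. A quick numerical check of that identity:

```python
import torch

A = torch.randn(5, 3, dtype=torch.float64)
torch.testing.assert_close(
    torch.linalg.matrix_norm(A, ord="nuc"),
    torch.linalg.svdvals(A).sum(-1),
)
```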
@@ -2417,7 +2417,7 @@ static Tensor& _linalg_norm_matrix_out(Tensor& result, const Tensor &self, const
}

if (std::abs(ord) == 2) {
- // Need to shift the reduction dims to the back, because at::svd will only operate on
+ // Need to shift the reduction dims to the back, because at::linalg_svdvals will only operate on
// the last 2 dimensions
auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim());
auto permutation_reverse = create_reverse_permutation(permutation);
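
Aside (not part of the diff): for ord == ±2 the matrix norm is the largest (respectively smallest) singular value, which is why shifting the reduction dims to the back and calling linalg_svdvals suffices. A small illustration:

```python
import torch

A = torch.randn(4, 6, dtype=torch.float64)
s = torch.linalg.svdvals(A)  # singular values are returned in descending order
torch.testing.assert_close(torch.linalg.matrix_norm(A, ord=2), s[0])
torch.testing.assert_close(torch.linalg.matrix_norm(A, ord=-2), s[-1])
```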
@@ -2732,7 +2732,7 @@ Tensor linalg_cond(const Tensor& self, const optional<Scalar>& opt_ord) {

// If ord == None or ord == ±2
if (std::abs(ord.toDouble()) == 2.0) {
- auto singular_values = std::get<1>(at::svd(self));
+ auto singular_values = at::linalg_svdvals(self);
Review comment (Collaborator): This diff is most likely the reason for the XLA tests failure. #71964

// singular values are sorted in descending order
auto s_max = at::narrow(singular_values, /*dim=*/-1, /*start=*/0, /*length=*/1);
auto s_min = at::narrow(singular_values, /*dim=*/-1, /*start=*/-1, /*length=*/1);
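
Aside (not part of the diff): the 2-norm condition number is the ratio of the largest to the smallest singular value, which is exactly what the two narrow calls above extract. The same computation in Python:

```python
import torch

A = torch.randn(4, 4, dtype=torch.float64)
s = torch.linalg.svdvals(A)  # descending order
torch.testing.assert_close(torch.linalg.cond(A, p=2), s[..., 0] / s[..., -1])
```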
21 changes: 21 additions & 0 deletions test/test_linalg.py
@@ -2817,6 +2817,27 @@ def test_svd(self, device, dtype):
S_s = torch.svd(A, compute_uv=False).S
self.assertEqual(S_s, S)

+ @skipCUDAIfNoMagmaAndNoCusolver
+ @skipCPUIfNoLapack
+ @dtypes(torch.complex128)
+ def test_invariance_error_spectral_decompositions(self, device, dtype):
+     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=True)
+     A = make_arg((3, 3))
+     with self.assertRaisesRegex(RuntimeError, "ill-defined"):
+         U, _, Vh = torch.linalg.svd(A, full_matrices=False)
+         (U + Vh).sum().backward()
+
+     A = make_arg((3, 3))
+     with self.assertRaisesRegex(RuntimeError, "ill-defined"):
+         V = torch.linalg.eig(A).eigenvectors
+         V.sum().backward()
+
+     A = make_arg((3, 3))
+     A = A + A.mH
+     with self.assertRaisesRegex(RuntimeError, "ill-defined"):
+         Q = torch.linalg.eigh(A).eigenvectors
+         Q.sum().backward()
+
@skipCUDAIfNoCusolver # MAGMA backend doesn't work in this case
@skipCUDAIfRocm
@precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})
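
Aside (not part of the diff): the new test checks that backward through gauge-dependent outputs (U + Vh of a complex SVD, raw eigenvectors of eig/eigh) raises an explicit "ill-defined" error instead of returning an arbitrary gradient. Phase-invariant quantities remain differentiable; a minimal example of the supported path:

```python
import torch

A = torch.randn(3, 3, dtype=torch.complex128, requires_grad=True)
# Singular values do not depend on the phase ambiguity of U and Vh,
# so their backward is well defined even for complex inputs.
torch.linalg.svdvals(A).sum().backward()
print(A.grad.shape)  # torch.Size([3, 3])
```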
3 changes: 3 additions & 0 deletions test/test_ops.py
@@ -266,6 +266,9 @@ def test_noncontiguous_samples(self, device, dtype, op):
if not test_grad:
continue

+ expected = sample_input.output_process_fn_grad(expected)
+ actual = sample_input.output_process_fn_grad(actual)
+
if isinstance(expected, torch.Tensor):
expected_backward_tensor = expected
actual_backward_tensor = actual
10 changes: 9 additions & 1 deletion tools/autograd/derivatives.yaml
@@ -1467,7 +1467,15 @@

# We never call _linalg_svd with compute_uv=False in an autograd context, so we don't even consider it here
- name: _linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor Vh)
- A: svd_backward(grads, full_matrices, U, S, Vh)
+ A: "svd_backward(full_matrices && grad_U.defined() ? grad_U.narrow(-1, 0, S.size(-1)) : grad_U,
+     grad_S,
+     full_matrices && grad_Vh.defined() ? grad_Vh.narrow(-2, 0, S.size(-1)) : grad_Vh,
+     full_matrices ? U.narrow(-1, 0, S.size(-1)) : U,
+     S,
+     full_matrices ? Vh.narrow(-2, 0, S.size(-1)) : Vh)"
Review comment on lines +1472 to +1475 (Collaborator): the narrowing is no longer done in the backward?

Reply (Collaborator, author): I moved it here to remove the full_matrices keyword from the backward. svd_backward is used in a number of places, and I wanted to make sure no one could call it with full_matrices=True, since the usual caveat applies: svd with full_matrices=True is not a well-defined function, so in particular it is not differentiable.
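
Aside (not part of the diff): with full_matrices=True, only the first k = min(m, n) columns of U and rows of Vh carry information about A; the rest is an arbitrary orthonormal completion, which is what makes the full factorization ill-defined as a function and what the narrow calls above discard. A small sketch of that fact:

```python
import torch

A = torch.randn(5, 3, dtype=torch.float64)
U, S, Vh = torch.linalg.svd(A, full_matrices=True)  # U: (5, 5), Vh: (3, 3)
k = S.size(-1)
# Only the leading k columns of U (and k rows of Vh) enter the reconstruction;
# these are the same slices the derivative formula above narrows to.
A_rec = (U.narrow(-1, 0, k) * S) @ Vh.narrow(-2, 0, k)
torch.testing.assert_close(A_rec, A)
```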

+ U: std::get<0>(linalg_svd_jvp(A_t, U, S, Vh, full_matrices))
+ S: std::get<1>(linalg_svd_jvp(A_t, U, S, Vh, full_matrices))
+ Vh: std::get<2>(linalg_svd_jvp(A_t, U, S, Vh, full_matrices))

- name: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors)
self: linalg_eig_backward(grads[0], grads[1], eigenvalues, eigenvectors_return, /*is_hermitian=*/true, /*symeig_eigenvector=*/eigenvectors)
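
Aside (not part of the diff): once this lands, the forward-mode rule can be exercised with torch.autograd.forward_ad. A sketch that pushes a tangent dA through torch.linalg.svd and checks the singular-value tangent against the textbook formula dS_i = u_iᵀ dA v_i (real input with distinct singular values assumed; variable names are mine):

```python
import torch
import torch.autograd.forward_ad as fwAD

A = torch.randn(5, 3, dtype=torch.float64)
dA = torch.randn_like(A)

with fwAD.dual_level():
    A_dual = fwAD.make_dual(A, dA)
    U, S, Vh = torch.linalg.svd(A_dual, full_matrices=False)
    dS = fwAD.unpack_dual(S).tangent

# Reference: dS_i = u_i^T dA v_i, computed from the reduced SVD of the primal.
U0, S0, Vh0 = torch.linalg.svd(A, full_matrices=False)
dS_ref = torch.diagonal(U0.transpose(-2, -1) @ dA @ Vh0.transpose(-2, -1))
torch.testing.assert_close(dS, dS_ref)
```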