
Commit 391319e

lezcano authored and facebook-github-bot committed
Implement forward AD for linalg.svd and improve svd_backward (#70253)
Summary:
Pull Request resolved: #70253

I included a derivation of the formula in the complex case, as it is particularly tricky. As far as I know, this is the first time this formula has been derived in the literature.

I also implemented a more efficient and more accurate version of svd_backward. More importantly, I added a lax check in the complex case, making sure the loss function depends only on the subspaces spanned by the pairs of singular vectors, and not on their joint phase.

cc jianyuh nikitaved pearu mruberry walterddr IvanYashchuk xwang233 Lezcano

Test Plan: Imported from OSS

Reviewed By: mikaylagawarecki

Differential Revision: D33751982

Pulled By: mruberry

fbshipit-source-id: c2a4a92a921a732357e99c01ccb563813b1af512
1 parent a1860bd · commit 391319e
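For orientation, here is a minimal sketch of what the forward-AD part of this change enables, using PyTorch's public dual-number API (illustrative only, not taken from the diff):

    import torch
    import torch.autograd.forward_ad as fwAD

    A = torch.randn(3, 3, dtype=torch.float64)
    A_t = torch.randn(3, 3, dtype=torch.float64)  # tangent direction for A

    with fwAD.dual_level():
        dual_A = fwAD.make_dual(A, A_t)
        U, S, Vh = torch.linalg.svd(dual_A, full_matrices=False)
        # With this commit, the outputs carry tangents: this is the
        # directional derivative of the singular values along A_t.
        S_t = fwAD.unpack_dual(S).tangent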

File tree: 9 files changed, +501 −206 lines


aten/src/ATen/native/LinearAlgebra.cpp

Lines changed: 6 additions & 6 deletions
@@ -285,6 +285,9 @@ Tensor linalg_pinv(
     const optional<Tensor>& atol_opt,
     const optional<Tensor>& rtol_opt,
     bool hermitian) {
+  // FIXME: Whenever we have a nice lstsq, we should dispatch this function to simply be
+  // `torch.lstsq(A, torch.eye(A.shape[-1]), atol=atol, rtol=rtol)`
+  // with a driver that supports singular inputs
   NoTF32Guard disable_tf32;
   ScalarType t = input.scalar_type();
   TORCH_CHECK((t == ScalarType::Double || t == ScalarType::Float || t == ScalarType::ComplexFloat || t == ScalarType::ComplexDouble)
@@ -2347,10 +2350,7 @@ Tensor& nuclear_norm_out(const Tensor& self, IntArrayRef dim, bool keepdim, Tens

   auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim());
   Tensor p = self.permute(permutation);
-  // NOTE: U and V are computed only if gradmode is enabled, since the backward for nuclear
-  // norm uses svd_backward, which requires them.
-  Tensor result_ = at::sum(std::get<1>(at::svd(p, /*some=*/true,
-                   /*compute_uv=*/at::GradMode::is_enabled() && self.requires_grad())), -1, keepdim);
+  Tensor result_ = at::sum(at::linalg_svdvals(p), -1, keepdim);
   if (keepdim) {
     result_.unsqueeze_(-1);
     auto permutation_reverse = create_reverse_permutation(permutation);
@@ -2417,7 +2417,7 @@ static Tensor& _linalg_norm_matrix_out(Tensor& result, const Tensor &self, const
   }

   if (std::abs(ord) == 2) {
-    // Need to shift the reduction dims to the back, because at::svd will only operate on
+    // Need to shift the reduction dims to the back, because at::linalg_svdvals will only operate on
     // the last 2 dimensions
     auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim());
     auto permutation_reverse = create_reverse_permutation(permutation);
@@ -2732,7 +2732,7 @@ Tensor linalg_cond(const Tensor& self, const optional<Scalar>& opt_ord) {

   // If ord == None or ord == ±2
   if (std::abs(ord.toDouble()) == 2.0) {
-    auto singular_values = std::get<1>(at::svd(self));
+    auto singular_values = at::linalg_svdvals(self);
     // singular values are sorted in descending order
     auto s_max = at::narrow(singular_values, /*dim=*/-1, /*start=*/0, /*length=*/1);
     auto s_min = at::narrow(singular_values, /*dim=*/-1, /*start=*/-1, /*length=*/1);
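The call sites above only consume singular values, so at::svd with its compute_uv bookkeeping can be replaced by at::linalg_svdvals. A quick Python sanity check of the underlying identity (illustrative, not part of the diff):

    import torch

    A = torch.randn(4, 4, dtype=torch.float64)
    # The nuclear norm is the sum of singular values; no singular
    # vectors are needed to compute it.
    nuc = torch.linalg.matrix_norm(A, ord="nuc")
    assert torch.allclose(nuc, torch.linalg.svdvals(A).sum())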

test/test_linalg.py

Lines changed: 21 additions & 0 deletions
@@ -2817,6 +2817,27 @@ def test_svd(self, device, dtype):
         S_s = torch.svd(A, compute_uv=False).S
         self.assertEqual(S_s, S)

+    @skipCUDAIfNoMagmaAndNoCusolver
+    @skipCPUIfNoLapack
+    @dtypes(torch.complex128)
+    def test_invariance_error_spectral_decompositions(self, device, dtype):
+        make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=True)
+        A = make_arg((3, 3))
+        with self.assertRaisesRegex(RuntimeError, "ill-defined"):
+            U, _, Vh = torch.linalg.svd(A, full_matrices=False)
+            (U + Vh).sum().backward()
+
+        A = make_arg((3, 3))
+        with self.assertRaisesRegex(RuntimeError, "ill-defined"):
+            V = torch.linalg.eig(A).eigenvectors
+            V.sum().backward()
+
+        A = make_arg((3, 3))
+        A = A + A.mH
+        with self.assertRaisesRegex(RuntimeError, "ill-defined"):
+            Q = torch.linalg.eigh(A).eigenvectors
+            Q.sum().backward()
+
     @skipCUDAIfNoCusolver  # MAGMA backend doesn't work in this case
     @skipCUDAIfRocm
     @precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})
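For contrast with the failing cases in this test: a loss that depends only on the subspaces spanned by the singular-vector pairs, such as the reconstruction U diag(S) Vh, is invariant to the joint phase and should backpropagate without triggering the check. A sketch under that assumption (not part of the diff):

    import torch

    A = torch.randn(3, 3, dtype=torch.complex128, requires_grad=True)
    U, S, Vh = torch.linalg.svd(A, full_matrices=False)
    # U diag(S) Vh is invariant to the joint phase of each (u_i, v_i)
    # pair, so the "ill-defined" error is not raised.
    recon = (U * S) @ Vh
    recon.abs().sum().backward()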

test/test_ops.py

Lines changed: 3 additions & 0 deletions
@@ -266,6 +266,9 @@ def test_noncontiguous_samples(self, device, dtype, op):
             if not test_grad:
                 continue

+            expected = sample_input.output_process_fn_grad(expected)
+            actual = sample_input.output_process_fn_grad(actual)
+
             if isinstance(expected, torch.Tensor):
                 expected_backward_tensor = expected
                 actual_backward_tensor = actual
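For context: output_process_fn_grad is the OpInfo hook that post-processes an op's output before gradients are taken, so ops with gauge-dependent outputs are compared through a well-defined quantity. A rough sketch of the idea (the lambda below is hypothetical, not from the test suite):

    import torch

    # Hypothetical post-processing step: abs() removes the arbitrary
    # per-eigenvector phase, making the backward target well defined.
    output_process_fn_grad = lambda out: out.abs()

    A = torch.randn(3, 3, dtype=torch.complex128, requires_grad=True)
    V = torch.linalg.eig(A).eigenvectors
    output_process_fn_grad(V).sum().backward()  # phase-invariant, no error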

tools/autograd/derivatives.yaml

Lines changed: 9 additions & 1 deletion
@@ -1471,7 +1471,15 @@

 # We never call _linalg_svd with compute_uv=False in an autograd context, so we don't even consider it here
 - name: _linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor Vh)
-  A: svd_backward(grads, full_matrices, U, S, Vh)
+  A: "svd_backward(full_matrices && grad_U.defined() ? grad_U.narrow(-1, 0, S.size(-1)) : grad_U,
+                   grad_S,
+                   full_matrices && grad_Vh.defined() ? grad_Vh.narrow(-2, 0, S.size(-1)) : grad_Vh,
+                   full_matrices ? U.narrow(-1, 0, S.size(-1)) : U,
+                   S,
+                   full_matrices ? Vh.narrow(-2, 0, S.size(-1)) : Vh)"
+  U: std::get<0>(linalg_svd_jvp(A_t, U, S, Vh, full_matrices))
+  S: std::get<1>(linalg_svd_jvp(A_t, U, S, Vh, full_matrices))
+  Vh: std::get<2>(linalg_svd_jvp(A_t, U, S, Vh, full_matrices))

 - name: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors)
   self: linalg_eig_backward(grads[0], grads[1], eigenvalues, eigenvectors_return, /*is_hermitian=*/true, /*symeig_eigenvector=*/eigenvectors)
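The narrow() calls above restrict everything to the reduced decomposition: with full_matrices=True, only the first k = min(m, n) columns of U (and rows of Vh) are coupled to A through svd_backward; the rest merely complete an orthonormal basis. A small shape sketch (illustrative, not part of the diff):

    import torch

    m, n = 5, 3
    A = torch.randn(m, n, dtype=torch.float64, requires_grad=True)
    U, S, Vh = torch.linalg.svd(A, full_matrices=True)  # U: (5, 5), Vh: (3, 3)
    k = S.size(-1)                                      # k = min(m, n) = 3
    # Only U[..., :k] enters the backward formula, matching the
    # narrow(-1, 0, S.size(-1)) in the derivative definition above.
    U.narrow(-1, 0, k).sum().backward()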
