Commit 2ad79e2

Implement forward AD for linalg.svd and improve svd_backward
I included a derivation of the formula in the complex case, as it is particularly tricky. As far as I know, this is the first time this formula has been derived in the literature.

I also implemented a more efficient and more accurate version of svd_backward.

More importantly, I added a lax check in the complex case making sure that the loss function depends only on the subspaces spanned by the pairs of singular vectors, and not on their joint phase.

ghstack-source-id: 2732ea3
Pull Request resolved: #70253
1 parent 514e83a commit 2ad79e2
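
The gauge freedom that the check guards against is easy to see numerically: for a complex matrix, multiplying the k-th column of U and the k-th column of V by the same phase leaves A = U diag(S) V^H unchanged, so a loss through complex SVD is only well defined if it is invariant under that joint phase. A minimal sketch in plain PyTorch (illustrative, not code from this commit):

    import torch

    A = torch.randn(3, 3, dtype=torch.cdouble)
    U, S, Vh = torch.linalg.svd(A)

    # Scale column k of U by phase[k] and row k of Vh by conj(phase[k]):
    # the reconstruction U @ diag(S) @ Vh is unchanged.
    phase = torch.exp(1j * torch.rand(3, dtype=torch.double))
    U2 = U * phase
    Vh2 = phase.conj().unsqueeze(-1) * Vh

    torch.testing.assert_close((U2 * S) @ Vh2, A)

A loss built from phase-invariant quantities (e.g. entrywise absolute values of U and Vh) is unaffected by this transformation; one that reads off the phase of an individual singular vector entry has no well-defined gradient.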

File tree

8 files changed: +480 / -206 lines

aten/src/ATen/native/LinearAlgebra.cpp

Lines changed: 6 additions & 6 deletions
@@ -285,6 +285,9 @@ Tensor linalg_pinv(
     const optional<Tensor>& atol_opt,
     const optional<Tensor>& rtol_opt,
     bool hermitian) {
+  // FIXME: Whenever we have a nice lstsq, we should dispatch this function to simply be
+  // `torch.lstsq(A, torch.eye(A.shape[-1]), atol=atol, rtol=rtol)`
+  // with a driver that supports singular inputs
   NoTF32Guard disable_tf32;
   ScalarType t = input.scalar_type();
   TORCH_CHECK((t == ScalarType::Double || t == ScalarType::Float || t == ScalarType::ComplexFloat || t == ScalarType::ComplexDouble)
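
The lstsq dispatch in the FIXME is aspirational (torch.lstsq does not take atol/rtol), but the idea can already be sketched with the current torch.linalg.lstsq API and a rank-revealing driver. Illustrative only, not code from this diff:

    import torch

    A = torch.randn(4, 4, dtype=torch.double)
    A[:, -1] = A[:, 0]  # make A rank-deficient

    # The minimum-norm least-squares solution of A X = I is the pseudoinverse,
    # provided the driver handles singular inputs ('gelsd', CPU only).
    X = torch.linalg.lstsq(A, torch.eye(4, dtype=torch.double), driver='gelsd').solution
    torch.testing.assert_close(X, torch.linalg.pinv(A))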
@@ -2342,10 +2345,7 @@ Tensor& nuclear_norm_out(const Tensor& self, IntArrayRef dim, bool keepdim, Tensor& result) {

   auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim());
   Tensor p = self.permute(permutation);
-  // NOTE: U and V are computed only if gradmode is enabled, since the backward for nuclear
-  // norm uses svd_backward, which requires them.
-  Tensor result_ = at::sum(std::get<1>(at::svd(p, /*some=*/true,
-    /*compute_uv=*/at::GradMode::is_enabled() && self.requires_grad())), -1, keepdim);
+  Tensor result_ = at::sum(at::linalg_svdvals(p), -1, keepdim);
   if (keepdim) {
     result_.unsqueeze_(-1);
     auto permutation_reverse = create_reverse_permutation(permutation);
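
This simplification works because linalg_svdvals now has its own derivative, so the nuclear norm no longer needs to eagerly compute U and V just in case a backward pass happens. A quick autograd check of the equivalent Python expression (illustrative sketch):

    import torch

    A = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
    nuc = torch.linalg.svdvals(A).sum()  # nuclear norm = sum of singular values
    torch.testing.assert_close(nuc, torch.linalg.matrix_norm(A, ord='nuc'))

    nuc.backward()  # gradient of the nuclear norm is U @ Vh for full-rank A
    U, S, Vh = torch.linalg.svd(A.detach(), full_matrices=False)
    torch.testing.assert_close(A.grad, U @ Vh)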
@@ -2412,7 +2412,7 @@ static Tensor& _linalg_norm_matrix_out(Tensor& result, const Tensor &self, const
   }

   if (std::abs(ord) == 2) {
-    // Need to shift the reduction dims to the back, because at::svd will only operate on
+    // Need to shift the reduction dims to the back, because at::linalg_svdvals will only operate on
     // the last 2 dimensions
     auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim());
     auto permutation_reverse = create_reverse_permutation(permutation);
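
In Python terms: the ord=±2 matrix norms are the extreme singular values, and torch.linalg.svdvals reduces over the last two dimensions, which is why the reduction dims are permuted to the back (illustrative sketch):

    import torch

    A = torch.randn(2, 5, 4, 3)  # batch of 2x5 matrices, each of shape 4x3
    s = torch.linalg.svdvals(A)  # shape (2, 5, 3), sorted in descending order
    torch.testing.assert_close(torch.linalg.matrix_norm(A, ord=2), s[..., 0])
    torch.testing.assert_close(torch.linalg.matrix_norm(A, ord=-2), s[..., -1])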
@@ -2727,7 +2727,7 @@ Tensor linalg_cond(const Tensor& self, const optional<Scalar>& opt_ord) {

   // If ord == None or ord == ±2
   if (std::abs(ord.toDouble()) == 2.0) {
-    auto singular_values = std::get<1>(at::svd(self));
+    auto singular_values = at::linalg_svdvals(self);
     // singular values are sorted in descending order
     auto s_max = at::narrow(singular_values, /*dim=*/-1, /*start=*/0, /*length=*/1);
     auto s_min = at::narrow(singular_values, /*dim=*/-1, /*start=*/-1, /*length=*/1);
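
The ord=±2 condition number is just the ratio of those extreme singular values, e.g. (sketch):

    import torch

    A = torch.randn(4, 4, dtype=torch.double)
    s = torch.linalg.svdvals(A)  # descending order
    torch.testing.assert_close(torch.linalg.cond(A), s[0] / s[-1])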

test/test_ops.py

Lines changed: 3 additions & 0 deletions
@@ -266,6 +266,9 @@ def test_noncontiguous_samples(self, device, dtype, op):
             if not test_grad:
                 continue

+            expected = sample_input.output_process_fn_grad(expected)
+            actual = sample_input.output_process_fn_grad(actual)
+
             if isinstance(expected, torch.Tensor):
                 expected_backward_tensor = expected
                 actual_backward_tensor = actual
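
Mapping the raw outputs through output_process_fn_grad before backpropagating matters for ops like the complex SVD above, whose raw outputs are only defined up to a gauge: the OpInfo sample supplies a function that maps them to a well-defined quantity first. A hypothetical stand-in for that hook (not the real OpInfo machinery):

    import torch

    def output_process_fn_grad(usv):
        # Phase-invariant post-processing: |U| and |Vh| do not change when
        # paired singular vectors pick up a joint phase.
        U, S, Vh = usv
        return U.abs(), S, Vh.abs()

    A = torch.randn(3, 3, dtype=torch.cdouble, requires_grad=True)
    outs = output_process_fn_grad(torch.linalg.svd(A))
    sum(o.sum() for o in outs).backward()  # gradient is now well defined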

tools/autograd/derivatives.yaml

Lines changed: 9 additions & 1 deletion
@@ -1467,7 +1467,15 @@

 # We never call _linalg_svd with compute_uv=False in an autograd context, so we don't even consider it here
 - name: _linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor Vh)
-  A: svd_backward(grads, full_matrices, U, S, Vh)
+  A: "svd_backward(full_matrices && grad_U.defined() ? grad_U.narrow(-1, 0, S.size(-1)) : grad_U,
+                   grad_S,
+                   full_matrices && grad_Vh.defined() ? grad_Vh.narrow(-2, 0, S.size(-1)) : grad_Vh,
+                   full_matrices ? U.narrow(-1, 0, S.size(-1)) : U,
+                   S,
+                   full_matrices ? Vh.narrow(-2, 0, S.size(-1)) : Vh)"
+  U: std::get<0>(linalg_svd_jvp(A_t, U, S, Vh, full_matrices))
+  S: std::get<1>(linalg_svd_jvp(A_t, U, S, Vh, full_matrices))
+  Vh: std::get<2>(linalg_svd_jvp(A_t, U, S, Vh, full_matrices))

 - name: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors)
   self: linalg_eig_backward(grads[0], grads[1], eigenvalues, eigenvectors_return, /*is_hermitian=*/true, /*symeig_eigenvector=*/eigenvectors)
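
With the U, S and Vh entries in place, forward-mode AD flows through torch.linalg.svd. In particular the singular-value tangent should match dS = Re diag(U^H dA V). A small check (illustrative; assumes a build containing this change):

    import torch
    import torch.autograd.forward_ad as fwAD

    A = torch.randn(4, 3, dtype=torch.double)
    T = torch.randn(4, 3, dtype=torch.double)  # tangent: perturbation direction

    with fwAD.dual_level():
        U, S, Vh = torch.linalg.svd(fwAD.make_dual(A, T), full_matrices=False)
        S_tangent = fwAD.unpack_dual(S).tangent
        U_p = fwAD.unpack_dual(U).primal
        Vh_p = fwAD.unpack_dual(Vh).primal

    # dS = diag(U^H dA V); take the real part in the complex case
    expected = torch.diagonal(U_p.transpose(-2, -1) @ T @ Vh_p.transpose(-2, -1))
    torch.testing.assert_close(S_tangent, expected)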
