
Implement forward AD for linalg.svd and improve svd_backward #70253


Closed
wants to merge 32 commits

Conversation

@lezcano (Collaborator) commented Dec 21, 2021

Stack from ghstack:

I included a derivation of the formula in the complex case, as it is
particularly tricky. As far as I know, this is the first time this formula
has been derived in the literature.

I also implemented a more efficient and more accurate version of svd_backward.
More importantly, I added a lax check in the complex case ensuring that the loss
function depends only on the subspaces spanned by the pairs of singular
vectors, and not on their joint phase.

cc @jianyuh @nikitaved @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @lezcano

Differential Revision: D33751982
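
[Editor's note] To make the gauge issue concrete, here is a minimal sketch (illustrative only, not code from this PR) of what the check enforces on complex inputs. Only losses that are invariant under the per-pair phase change (u_k, v_k) -> (u_k e^{i phi}, v_k e^{i phi}) have a well-defined gradient through the SVD; the assumption that the check surfaces as a RuntimeError (as TORCH_CHECK failures do) is mine:

    import torch

    torch.manual_seed(0)
    A = torch.randn(4, 4, dtype=torch.complex128, requires_grad=True)

    # Phase-invariant loss: |U_ij|^2 does not change under u_k -> u_k e^{i phi},
    # so its gradient through the SVD is well-defined.
    w = torch.linspace(1.0, 2.0, 4)
    U, S, Vh = torch.linalg.svd(A)
    ((U.abs().pow(2) * w).sum() + S.sum()).backward()

    # Phase-dependent loss: Re(sum(U)) changes with the phases, so its
    # gradient is ill-defined and the lax check should reject it.
    A.grad = None
    U, S, Vh = torch.linalg.svd(A)
    try:
        U.sum().real.backward()
    except RuntimeError as err:  # assumed error type (TORCH_CHECK failure)
        print(err)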

@pytorch-probot bot commented Dec 21, 2021

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/pytorch/pytorch/blob/6e6d11b8016152ae4557976e8f08029a90b124b6/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default,ciflow/all

Workflows | Labels | Status
Triggered Workflows
caffe2-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk ✅ triggered
docker-builds ciflow/all, ciflow/trunk ✅ triggered
ios-12-5-1-arm64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk ✅ triggered
ios-12-5-1-arm64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk ✅ triggered
ios-12-5-1-arm64-custom-ops ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk ✅ triggered
ios-12-5-1-arm64-full-jit ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk ✅ triggered
ios-12-5-1-arm64-metal ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk ✅ triggered
ios-12-5-1-x86-64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk ✅ triggered
ios-12-5-1-x86-64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk ✅ triggered
ios-12-5-1-x86-64-full-jit ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk ✅ triggered
libtorch-linux-xenial-cuda10.2-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk ✅ triggered
libtorch-linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk ✅ triggered
linux-bionic-cuda10.2-py3.9-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow, ciflow/trunk ✅ triggered
linux-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/trunk ✅ triggered
linux-docs ciflow/all, ciflow/cpu, ciflow/default, ciflow/docs, ciflow/linux, ciflow/trunk ✅ triggered
linux-docs-push ciflow/all, ciflow/cpu, ciflow/linux, ciflow/scheduled ✅ triggered
linux-vulkan-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk, ciflow/vulkan ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7-bazel-test ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7-no-ops ciflow/all, ciflow/cuda, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-build ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-custom-build-static ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-asan ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-onnx ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7-no-ops ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
macos-10-15-py3-arm64 ciflow/all, ciflow/macos, ciflow/trunk ✅ triggered
macos-10-15-py3-lite-interpreter-x86-64 ciflow/all, ciflow/macos, ciflow/trunk ✅ triggered
macos-11-py3-x86-64 ciflow/all, ciflow/macos, ciflow/trunk ✅ triggered
parallelnative-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk ✅ triggered
periodic-libtorch-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled ✅ triggered
periodic-libtorch-linux-xenial-cuda11.1-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled ✅ triggered
periodic-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled ✅ triggered
periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck ✅ triggered
periodic-linux-xenial-cuda11.1-py3.7-gcc7-debug ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled ✅ triggered
periodic-win-vs2019-cuda11.1-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win ✅ triggered
periodic-win-vs2019-cuda11.5-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-build ciflow/all, ciflow/android, ciflow/cpu, ciflow/linux, ciflow/trunk ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
win-vs2019-cpu-py3 ciflow/all, ciflow/cpu, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
win-vs2019-cuda11.3-py3 ciflow/all, ciflow/cuda, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
Skipped Workflows
linux-binary-conda ciflow/binaries, ciflow/binaries/conda 🚫 skipped
linux-binary-libtorch-cxx11-abi ciflow/binaries, ciflow/binaries/libtorch 🚫 skipped
linux-binary-libtorch-pre-cxx11 ciflow/binaries, ciflow/binaries/libtorch 🚫 skipped
linux-binary-manywheel ciflow/binaries, ciflow/binaries/wheel 🚫 skipped
linux-bionic-py3.6-clang9 ciflow/xla 🚫 skipped

You can add a comment to the PR and tag @pytorchbot with the following commands:
# ciflow rerun, "ciflow/default" will always be added automatically
@pytorchbot ciflow rerun

# ciflow rerun with additional labels "-l <ciflow/label_name>", which is equivalent to adding these labels manually and trigger the rerun
@pytorchbot ciflow rerun -l ciflow/scheduled -l ciflow/slow

For more information, please take a look at the CI Flow Wiki.

@facebook-github-bot (Contributor) commented Dec 21, 2021

🔗 Helpful links

💊 CI failures summary and remediations

As of commit aa7ce25 (more details on the Dr. CI page):


  • 6/6 failures possibly* introduced in this PR
    • 1/6 non-scanned failure(s)

🕵️ 5 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See GitHub Actions build periodic-win-vs2019-cuda11.5-py3 / test (default, 1, 2, windows.8xlarge.nvidia.gpu) (1/5)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-01-26T16:39:05.7791635Z RuntimeError: test_ops failed!
2022-01-26T16:38:50.1558629Z Generated XML report: test-reports\python-unittest\test_ops\TEST-TestGradientsCUDA-20220126150636.xml
2022-01-26T16:38:50.1560167Z Generated XML report: test-reports\python-unittest\test_ops\TEST-TestJitCPU-20220126150636.xml
2022-01-26T16:38:50.1561571Z Generated XML report: test-reports\python-unittest\test_ops\TEST-TestJitCUDA-20220126150636.xml
2022-01-26T16:38:50.1563069Z Generated XML report: test-reports\python-unittest\test_ops\TEST-TestMathBitsCPU-20220126150636.xml
2022-01-26T16:38:50.1564645Z Generated XML report: test-reports\python-unittest\test_ops\TEST-TestMathBitsCUDA-20220126150636.xml
2022-01-26T16:39:05.7786680Z Traceback (most recent call last):
2022-01-26T16:39:05.7788159Z   File "run_test.py", line 1101, in <module>
2022-01-26T16:39:05.7788920Z     main()
2022-01-26T16:39:05.7789753Z   File "run_test.py", line 1079, in main
2022-01-26T16:39:05.7790683Z     raise RuntimeError(err_message)
2022-01-26T16:39:05.7791635Z RuntimeError: test_ops failed!
2022-01-26T16:39:06.1579175Z 
2022-01-26T16:39:06.1580110Z (base) C:\actions-runner\_work\pytorch\pytorch\test>if ERRORLEVEL 1 goto fail 
2022-01-26T16:39:06.1583618Z 
2022-01-26T16:39:06.1584628Z (base) C:\actions-runner\_work\pytorch\pytorch\test>exit /b 1 
2022-01-26T16:39:06.1624571Z + cleanup
2022-01-26T16:39:06.1625151Z + retcode=1
2022-01-26T16:39:06.1625606Z + set +x
2022-01-26T16:39:06.1663695Z ##[error]Process completed with exit code 1.
2022-01-26T16:39:06.1936676Z ##[group]Run # -ir => recursive include all files in pattern

See GitHub Actions build periodic-win-vs2019-cuda11.1-py3 / test (default, 1, 2, windows.8xlarge.nvidia.gpu) (2/5)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-01-26T15:38:06.6983633Z RuntimeError: CUDA error: an illegal memory access was encountered
2022-01-26T15:38:06.6973070Z   File "C:\actions-runner\_work\pytorch\pytorch\build\win_tmp\build\torch\autograd\gradcheck.py", line 555, in _get_analytical_vJu_backward_mode
2022-01-26T15:38:06.6974201Z     all_vJ = _check_analytical_jacobian_attributes(inputs, output, nondet_tol, check_grad_dtypes,
2022-01-26T15:38:06.6975441Z   File "C:\actions-runner\_work\pytorch\pytorch\build\win_tmp\build\torch\autograd\gradcheck.py", line 529, in _check_analytical_jacobian_attributes
2022-01-26T15:38:06.6976552Z     vjps2 = _get_analytical_vjps_wrt_specific_output(vjp_fn, output.clone(), v)
2022-01-26T15:38:06.6977714Z   File "C:\actions-runner\_work\pytorch\pytorch\build\win_tmp\build\torch\autograd\gradcheck.py", line 633, in _get_analytical_vjps_wrt_specific_output
2022-01-26T15:38:06.6978768Z     grad_inputs = vjp_fn(v.reshape(sample_output.shape))
2022-01-26T15:38:06.6979722Z   File "C:\actions-runner\_work\pytorch\pytorch\build\win_tmp\build\torch\autograd\gradcheck.py", line 524, in vjp_fn
2022-01-26T15:38:06.6980704Z     return torch.autograd.grad(output, diff_input_list, grad_output,
2022-01-26T15:38:06.6981696Z   File "C:\actions-runner\_work\pytorch\pytorch\build\win_tmp\build\torch\autograd\__init__.py", line 275, in grad
2022-01-26T15:38:06.6982741Z     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
2022-01-26T15:38:06.6983633Z RuntimeError: CUDA error: an illegal memory access was encountered
2022-01-26T15:38:06.6984638Z CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
2022-01-26T15:38:06.6985574Z For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
2022-01-26T15:38:06.6985987Z 
2022-01-26T15:38:06.6986674Z 🚨 ERROR: TestGradientsCUDA.test_fn_fwgrad_bwgrad_gradient_cuda_complex128
2022-01-26T15:38:06.6987373Z None
2022-01-26T15:38:06.6988022Z 🚨 ERROR: TestGradientsCUDA.test_fn_fwgrad_bwgrad_gradient_cuda_float64
2022-01-26T15:38:06.6988649Z None
2022-01-26T15:38:06.6995150Z ✅ 19045 Passed
2022-01-26T15:38:06.6995592Z 💨 13241 Skipped
2022-01-26T15:38:06.6995966Z 🚨 3 Failed

See GitHub Actions build win-vs2019-cuda11.3-py3 / test (default, 1, 2, windows.8xlarge.nvidia.gpu) (3/5)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-01-26T16:43:55.8759728Z RuntimeError: test_ops failed!
2022-01-26T16:43:38.4407197Z Generated XML report: test-reports\python-unittest\test_ops\TEST-TestGradientsCUDA-20220126150935.xml
2022-01-26T16:43:38.4408742Z Generated XML report: test-reports\python-unittest\test_ops\TEST-TestJitCPU-20220126150935.xml
2022-01-26T16:43:38.4410122Z Generated XML report: test-reports\python-unittest\test_ops\TEST-TestJitCUDA-20220126150935.xml
2022-01-26T16:43:38.4411627Z Generated XML report: test-reports\python-unittest\test_ops\TEST-TestMathBitsCPU-20220126150935.xml
2022-01-26T16:43:38.4413509Z Generated XML report: test-reports\python-unittest\test_ops\TEST-TestMathBitsCUDA-20220126150935.xml
2022-01-26T16:43:55.8754264Z Traceback (most recent call last):
2022-01-26T16:43:55.8756265Z   File "run_test.py", line 1101, in <module>
2022-01-26T16:43:55.8757078Z     main()
2022-01-26T16:43:55.8757887Z   File "run_test.py", line 1079, in main
2022-01-26T16:43:55.8758812Z     raise RuntimeError(err_message)
2022-01-26T16:43:55.8759728Z RuntimeError: test_ops failed!
2022-01-26T16:43:56.2974401Z 
2022-01-26T16:43:56.2975753Z (base) C:\actions-runner\_work\pytorch\pytorch\test>if ERRORLEVEL 1 goto fail 
2022-01-26T16:43:56.2978960Z 
2022-01-26T16:43:56.2980136Z (base) C:\actions-runner\_work\pytorch\pytorch\test>exit /b 1 
2022-01-26T16:43:56.3023506Z + cleanup
2022-01-26T16:43:56.3024395Z + retcode=1
2022-01-26T16:43:56.3024858Z + set +x
2022-01-26T16:43:56.3065744Z ##[error]Process completed with exit code 1.
2022-01-26T16:43:56.3408039Z ##[group]Run # -ir => recursive include all files in pattern

See GitHub Actions build linux-bionic-cuda10.2-py3.9-gcc7 / test (default, 1, 2, linux.4xlarge.nvidia.gpu) (4/5)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-01-26T14:23:01.7357182Z RuntimeError: CUDA error: an illegal memory access was encountered
2022-01-26T14:23:01.7352209Z   File "/var/lib/jenkins/workspace/test/test_ops.py", line 1262, in test_neg_view
2022-01-26T14:23:01.7352634Z     self._test_math_view(device, dtype, op, samples, math_op_physical, math_op_view, is_bit_set,
2022-01-26T14:23:01.7353041Z   File "/var/lib/jenkins/workspace/test/test_ops.py", line 1231, in _test_math_view
2022-01-26T14:23:01.7353693Z     self.assertEqual(tensor.grad, cloned1_tensor.grad)
2022-01-26T14:23:01.7354259Z   File "/opt/conda/lib/python3.9/site-packages/torch/testing/_internal/common_utils.py", line 2083, in assertEqual
2022-01-26T14:23:01.7354695Z     result, debug_msg_generic = self._compareTensors(x, y, rtol=rtol, atol=atol,
2022-01-26T14:23:01.7355265Z   File "/opt/conda/lib/python3.9/site-packages/torch/testing/_internal/common_utils.py", line 1948, in _compareTensors
2022-01-26T14:23:01.7355707Z     return _compare_tensors_internal(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
2022-01-26T14:23:01.7356390Z   File "/opt/conda/lib/python3.9/site-packages/torch/testing/_core.py", line 96, in _compare_tensors_internal
2022-01-26T14:23:01.7356819Z     if torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan):
2022-01-26T14:23:01.7357182Z RuntimeError: CUDA error: an illegal memory access was encountered
2022-01-26T14:23:01.7357648Z CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
2022-01-26T14:23:01.7358084Z For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
2022-01-26T14:23:01.7358278Z 
2022-01-26T14:23:01.7358537Z ✅ 7609 Passed
2022-01-26T14:23:01.7358795Z 💨 5859 Skipped
2022-01-26T14:23:01.7359049Z 🚨 1 Failed
2022-01-26T14:23:01.7751238Z ##[group]Run # Remove any previous test jsons if they exist
2022-01-26T14:23:01.7751962Z rm -f test-jsons-*.zip
2022-01-26T14:23:01.7752292Z zip -r "test-jsons-${FILE_SUFFIX}.zip" test -i '*.json'

See GitHub Actions build periodic-linux-xenial-cuda11.1-py3.7-gcc7-debug / test (default, 2, 2, linux.4xlarge.nvidia.gpu) (5/5)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-01-26T14:28:58.7248577Z RuntimeError: CUDA error: an illegal memory access was encountered
2022-01-26T14:28:58.7244412Z   File "test_ops.py", line 1263, in test_neg_view
2022-01-26T14:28:58.7244679Z     lambda x: True)
2022-01-26T14:28:58.7244940Z   File "test_ops.py", line 1231, in _test_math_view
2022-01-26T14:28:58.7245381Z     self.assertEqual(tensor.grad, cloned1_tensor.grad)
2022-01-26T14:28:58.7245916Z   File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_utils.py", line 2085, in assertEqual
2022-01-26T14:28:58.7246292Z     exact_device=exact_device)
2022-01-26T14:28:58.7246774Z   File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_utils.py", line 1948, in _compareTensors
2022-01-26T14:28:58.7247233Z     return _compare_tensors_internal(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
2022-01-26T14:28:58.7247784Z   File "/opt/conda/lib/python3.7/site-packages/torch/testing/_core.py", line 96, in _compare_tensors_internal
2022-01-26T14:28:58.7248193Z     if torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan):
2022-01-26T14:28:58.7248577Z RuntimeError: CUDA error: an illegal memory access was encountered
2022-01-26T14:28:58.7249037Z CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
2022-01-26T14:28:58.7249459Z For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
2022-01-26T14:28:58.7249736Z 		
2022-01-26T14:28:58.7249983Z ✅ 7609 Passed
2022-01-26T14:28:58.7250259Z 💨 5859 Skipped
2022-01-26T14:28:58.7250491Z 🚨 1 Failed
2022-01-26T14:28:58.7637748Z ##[group]Run # Remove any previous test jsons if they exist
2022-01-26T14:28:58.7638446Z rm -f test-jsons-*.zip
2022-01-26T14:28:58.7638774Z zip -r "test-jsons-${FILE_SUFFIX}.zip" test -i '*.json'

ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.


lezcano added a commit that referenced this pull request Dec 21, 2021
ghstack-source-id: 1f6ccb3
Pull Request resolved: #70253
@lezcano requested review from nikitaved and IvanYashchuk and removed the request for soulitzer, December 21, 2021 18:59
@lezcano added the module: derivatives and module: linear algebra labels, Dec 21, 2021
lezcano added a commit that referenced this pull request Dec 22, 2021
ghstack-source-id: 09838d3
Pull Request resolved: #70253
lezcano added a commit that referenced this pull request Dec 29, 2021
ghstack-source-id: 8c581f1
Pull Request resolved: #70253
lezcano added a commit that referenced this pull request Dec 30, 2021
This PR adds checks for the backward of `linalg.eig`, similar to those
deduced in #70253

It also modifies the function so that it does not save the input matrix,
as it's not necessary.

It also corrects the forward AD formula; now all the tests pass for
`linalg.eig` and `linalg.eigvals`.

It also updates the docs to better reflect what is going on here.
lezcano added a commit that referenced this pull request Jan 25, 2022
ghstack-source-id: 2732ea3
Pull Request resolved: #70253
@nikitaved (Collaborator) commented Jan 26, 2022

I have a general question, @lezcano, @IvanYashchuk, @mruberry, @ngimel. Do we want to keep the check for the invariance in the backward? It seems like a significant perf penalty because of the device sync. Maybe it would be better to mention this invariance in the documentation? Granted, checking it from the user side is problematic, but the complex case is quite non-trivial, so the user may already be aware of potential non-uniqueness issues... Or we could introduce an additional input parameter that controls the check, but I do not know whether that is a good solution.
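
[Editor's note] For context on the sync concern, a small illustrative sketch (not this PR's internals): a value-dependent check like allclose must produce a host-side Python bool, so on CUDA the CPU blocks until the device result is ready.

    import torch

    if torch.cuda.is_available():
        a = torch.randn(1024, 1024, device="cuda", dtype=torch.complex64)
        b = a @ a  # enqueued asynchronously; control returns immediately
        # allclose returns a plain Python bool, which forces a device-to-host
        # transfer and hence an implicit synchronization at this point.
        ok = torch.allclose(b.imag, torch.zeros_like(b.imag))
        print(ok)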

@lezcano (Collaborator, Author) commented Jan 26, 2022

I chose this design because, as you said, this function is already rather difficult to use, so whatever we can do to improve the UX will be welcomed by users. Note that JAX (and, by extension, TF) wanted to do this but did not know how; see this comment jax-ml/jax#2748 (comment) in particular, and the rest of that issue and the linked issues for context.

Also, it is true that this check synchronises, but:

  • The SVD is already a very heavy function, so the relative cost of the check is small.
  • The check only happens in the complex case, which is not used that much.
  • We already do this in eig and eigh:
    if (V.is_complex()) {
      // Check invariance of the loss function wrt the transformation V -> V e^{i\phi}
      const auto imdiag_VhgV = at::imag(diag_VhgV);
      TORCH_CHECK(at::allclose(imdiag_VhgV, at::zeros_like(imdiag_VhgV), /*rtol=*/1e-2, /*atol=*/1e-2),
                  is_hermitian ? "linalg_eigh_backward" : "linalg_eig_backward",
                  ": The eigenvectors in the complex case are specified up to multiplication ",
                  "by e^{i phi}. The specified loss function depends on this quantity, so it is ill-defined.");
    }
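
[Editor's note] For the SVD, the analogous condition can be written out: invariance of the loss under (u_k, v_k) -> (u_k e^{i phi}, v_k e^{i phi}) is equivalent to Im(diag(U^H gU + V^H gV)) = 0, where gU and gV are the incoming gradients for U and V. A user-side sketch of such a check (names and tolerances illustrative; not the exact code this PR adds):

    import torch

    def svd_gauge_check(U, V, gU, gV, rtol=1e-2, atol=1e-2):
        # The loss is well-defined iff Im(diag(U^H gU + V^H gV)) vanishes.
        UhgU = U.conj().transpose(-2, -1) @ gU
        VhgV = V.conj().transpose(-2, -1) @ gV
        imdiag = torch.imag(torch.diagonal(UhgU + VhgV, dim1=-2, dim2=-1))
        return torch.allclose(imdiag, torch.zeros_like(imdiag),
                              rtol=rtol, atol=atol)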

lezcano added a commit that referenced this pull request Jan 26, 2022

ghstack-source-id: c74c94f
Pull Request resolved: #70253
@nikitaved (Collaborator) left a comment


LGTM! Thank you, Mario! Your call, @mruberry!


@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@lezcano (Collaborator, Author) commented Jan 27, 2022

Failures seem unrelated

facebook-github-bot pushed a commit that referenced this pull request Jan 27, 2022
Summary:
Pull Request resolved: #70253

Test Plan: Imported from OSS

Reviewed By: mikaylagawarecki

Differential Revision: D33751982

Pulled By: mruberry

fbshipit-source-id: c2a4a92a921a732357e99c01ccb563813b1af512
alexhagiopol pushed a commit that referenced this pull request Jan 28, 2022
(cherry picked from commit 391319e)
@@ -2732,7 +2732,7 @@ Tensor linalg_cond(const Tensor& self, const optional<Scalar>& opt_ord) {

   // If ord == None or ord == ±2
   if (std::abs(ord.toDouble()) == 2.0) {
-    auto singular_values = std::get<1>(at::svd(self));
+    auto singular_values = at::linalg_svdvals(self);
This diff is most likely the reason for the XLA test failures: #71964
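
[Editor's note] At the Python level, the change in this diff corresponds to the following (a sketch; torch.svd is the older API that returns the full decomposition):

    import torch

    A = torch.randn(5, 5)
    # Old: computes the full SVD just to read off the singular values.
    S_old = torch.svd(A).S
    # New: computes only the singular values, which is cheaper and sidesteps
    # the gauge ambiguity of the singular vectors entirely.
    S_new = torch.linalg.svdvals(A)
    assert torch.allclose(S_old, S_new)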

@facebook-github-bot deleted the gh/Lezcano/43/head branch, January 31, 2022 15:16
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 3, 2022
(cherry picked from commit 391319e)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 3, 2022
(cherry picked from commit 391319e)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 9, 2022
(cherry picked from commit 391319e)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 9, 2022
(cherry picked from commit 391319e)
@frederikwilde

Hi, I just stumbled upon this. I'm not sure if this helps, but I implemented the JVP rule for the complex-valued SVD here: jax-ml/jax#5225, and we show how the derivation works in appendix D of our paper (https://arxiv.org/pdf/2209.14328.pdf).
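
[Editor's note] For reference, a sketch of one standard JVP for the square complex SVD, assuming distinct singular values. The diagonal gauge choice below (splitting Im(P_kk) evenly between U and V) is one common convention and not necessarily the one used in this PR or in the JAX implementation:

    import torch

    def svd_jvp(A, dA):
        # Square complex A with distinct singular values assumed.
        U, S, Vh = torch.linalg.svd(A)
        V = Vh.conj().transpose(-2, -1)
        P = U.conj().transpose(-2, -1) @ dA @ V            # P = U^H dA V
        dS = torch.real(torch.diagonal(P, dim1=-2, dim2=-1))
        S2 = S * S
        F = 1.0 / (S2.unsqueeze(-2) - S2.unsqueeze(-1))    # F_ij = 1/(s_j^2 - s_i^2)
        F.diagonal(dim1=-2, dim2=-1).zero_()
        Ph = P.conj().transpose(-2, -1)
        Sd = torch.diag_embed(S).to(P.dtype)
        omU = F * (P @ Sd + Sd @ Ph)                       # Omega_U = U^H dU, off-diagonal
        omV = F * (Sd @ P + Ph @ Sd)                       # Omega_V = V^H dV, off-diagonal
        # Gauge choice for the diagonal entries: split Im(P_kk) evenly.
        imdiag = torch.imag(torch.diagonal(P, dim1=-2, dim2=-1))
        omU.diagonal(dim1=-2, dim2=-1).copy_(1j * imdiag / (2 * S))
        omV.diagonal(dim1=-2, dim2=-1).copy_(-1j * imdiag / (2 * S))
        return U @ omU, dS, V @ omV                        # (dU, dS, dV)

This can be sanity-checked against finite differences or torch.autograd's forward mode on a random complex matrix.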

@lezcano (Collaborator, Author) commented Apr 21, 2023

I have a half-written paper where I discuss how to formalise all this with principal bundles and all that, but I never really found the time to finish writing it :(

@frederikwilde

Nice. I would be curious to see it when it's done.

Labels
cla signed · module: derivatives (related to derivatives of operators) · module: linear algebra (specialized linear algebra operations in PyTorch, including matrix multiply) · open source
8 participants