
Commit fe0694c

lezcano authored and cyyever committed
Rewrite svd and linalg.svd as structured kernels (#69827)
Summary: Pull Request resolved: pytorch/pytorch#69827

In general, the new pattern allows implementing optimisations for all the backends in a common place (see, for example, the optimisation for empty matrices).

After this PR, `torch.svd` is implemented in terms of `linalg.svd` and `linalg.svdvals`, as expected. This makes it differentiable when `compute_uv=False`, although this is not particularly important, as `torch.svd` will eventually be deprecated.

This PR also instantiates smaller `U` / `V` when calling cusolver_gesvdj with `full_matrices=False` or `compute_uv=False`. In those cases, the memory for the auxiliary `U` and `V` needed by some cuSOLVER routines is allocated via raw allocators rather than through fully fledged tensors, as it is just a blob of memory the algorithm requests. As the code is better structured now, it was easier to see that `U` and `Vh` need not be allocated at all when calling `svd_cusolver_gesvd`.

`linalg.svdvals` now works as expected with respect to the `out=` parameter. Note that in the test `test_svd_memory_allocation` we were passing a tensor of the wrong size and dtype and the test seemed to pass...

This PR also changes the backward formula to avoid saving the input matrix, as it is not necessary. In a follow-up PR, I will clean up the backward formula and make it more numerically stable and efficient.

This PR also makes a number of memory optimisations here and there, and fixes the call to cusolver_gesvd, which was incorrect for m <= n. To test this path, I compiled the code with a flag to unconditionally execute the `if (!gesvdj_convergence_check.empty())` branch, and all the tests passed.

I also took this chance to simplify the tests for these functions in `test_linalg.py`, as we had many tests exercising functionality that is already covered by the corresponding OpInfos. I used xwang233's feature to test both the MAGMA and cuSOLVER backends. This is particularly useful for SVD, as cuSOLVER is always chosen over MAGMA when available, so testing MAGMA otherwise would be tricky.

cc jianyuh nikitaved pearu mruberry walterddr IvanYashchuk xwang233 Lezcano

Test Plan: Imported from OSS

Reviewed By: mikaylagawarecki

Differential Revision: D33751983

Pulled By: mruberry

fbshipit-source-id: 11d48d977946345583d33d14fb11a170a7d14fd2

(cherry picked from commit a1860bd)
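For context on the `torch.svd` / `linalg.svd` relationship described above, here is a minimal Python sketch (illustrative only, not the actual ATen implementation introduced by this PR):

```python
import torch

def svd_via_linalg(A, some=True):
    # Sketch only: torch.svd's `some=True` corresponds to linalg.svd's
    # `full_matrices=False`, and torch.svd returns V rather than its
    # conjugate transpose Vh.
    U, S, Vh = torch.linalg.svd(A, full_matrices=not some)
    return U, S, Vh.transpose(-2, -1).conj()

A = torch.randn(5, 3, dtype=torch.complex128)
_, S, _ = svd_via_linalg(A)
torch.testing.assert_close(S, torch.svd(A).S)            # same singular values as the old API
torch.testing.assert_close(S, torch.linalg.svdvals(A))   # the compute_uv=False / svdvals path
```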
1 parent 6a34818 commit fe0694c

21 files changed: +637 −848 lines

aten/src/ATen/ConjugateFallback.cpp

Lines changed: 2 additions & 0 deletions
@@ -51,6 +51,8 @@ TORCH_LIBRARY_IMPL(aten, Conjugate, m) {
   m.impl("baddbmm", torch::CppFunction::makeFallthrough());
   m.impl("baddbmm_", torch::CppFunction::makeFallthrough());
   m.impl("baddbmm.out", torch::CppFunction::makeFallthrough());
+  m.impl("linalg_svd", torch::CppFunction::makeFallthrough());
+  m.impl("linalg_svd.U", torch::CppFunction::makeFallthrough());
 
   TORCH_VIEW_FNS(m)
   TENSOR_UTILITIES_AND_CONSTRUCTORS(m)
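These fallthrough registrations let `linalg_svd` consume lazily conjugated inputs without the conjugate fallback materializing a copy first. A minimal sketch of the invariant this relies on, using only the public Python API (not part of this diff):

```python
import torch

# The singular values of conj(A) are those of A, so the SVD kernels can accept
# a conj-flagged view directly instead of a materialized conjugated copy.
A = torch.randn(4, 3, dtype=torch.complex128)
Ac = A.conj()
assert Ac.is_conj()  # still a lazy conjugate view

torch.testing.assert_close(torch.linalg.svdvals(Ac), torch.linalg.svdvals(A))
```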

aten/src/ATen/native/BatchLinearAlgebra.cpp

Lines changed: 191 additions & 189 deletions
Large diffs are not rendered by default.

aten/src/ATen/native/BatchLinearAlgebra.h

Lines changed: 11 additions & 0 deletions
@@ -161,6 +161,8 @@ void lapackLuSolve(char trans, int n, int nrhs, scalar_t *a, int lda, int *ipiv,
 template <class scalar_t>
 void lapackLu(int m, int n, scalar_t *a, int lda, int *ipiv, int *info);
 
+template<class scalar_t, class value_t=scalar_t>
+void lapackSvd(char jobz, int m, int n, scalar_t *a, int lda, value_t *s, scalar_t *u, int ldu, scalar_t *vt, int ldvt, scalar_t *work, int lwork, value_t *rwork, int *iwork, int *info);
 #endif
 
 #if AT_BUILD_WITH_BLAS()
@@ -239,5 +241,14 @@ using lu_solve_trans_fn = void (*)(
     TransposeType /*trans*/);
 DECLARE_DISPATCH(lu_solve_trans_fn, lu_solve_trans_stub);
 
+using svd_fn = void (*)(
+    const Tensor& /*A*/,
+    const bool /*full_matrices*/,
+    const bool /*compute_uv*/,
+    const Tensor& /*U*/,
+    const Tensor& /*S*/,
+    const Tensor& /*Vh*/,
+    const Tensor& /*info*/);
+DECLARE_DISPATCH(svd_fn, svd_stub);
 
 }} // namespace at::native

aten/src/ATen/native/BatchLinearAlgebraKernel.cpp

Lines changed: 83 additions & 0 deletions
@@ -943,6 +943,84 @@ void lu_solve_kernel(const Tensor& b, const Tensor& lu, const Tensor& pivots) {
   lu_solve_trans_kernel(b, lu, pivots, TransposeType::NoTranspose);
 }
 
+template <typename scalar_t>
+static void apply_svd(const Tensor& A,
+                      const bool full_matrices,
+                      const bool compute_uv,
+                      const Tensor& U,
+                      const Tensor& S,
+                      const Tensor& Vh,
+                      const Tensor& info) {
+#if !AT_BUILD_WITH_LAPACK()
+  TORCH_CHECK(false, "svd: LAPACK library not found in compilation");
+#else
+  using value_t = typename c10::scalar_value_type<scalar_t>::type;
+  const auto A_data = A.data_ptr<scalar_t>();
+  const auto U_data = compute_uv ? U.data_ptr<scalar_t>() : nullptr;
+  const auto S_data = S.data_ptr<value_t>();
+  const auto info_data = info.data_ptr<int>();
+  const auto Vh_data = compute_uv ? Vh.data_ptr<scalar_t>() : nullptr;
+  const auto A_stride = matrixStride(A);
+  const auto S_stride = S.size(-1);
+  const auto U_stride = compute_uv ? matrixStride(U) : 1;
+  const auto Vh_stride = compute_uv ? matrixStride(Vh) : 1;
+  const auto batchsize = batchCount(A);
+  const char jobz = compute_uv ? (full_matrices ? 'A' : 'S') : 'N';
+
+  const auto m = A.size(-2);
+  const auto n = A.size(-1);
+  const auto lda = A.stride(-1);
+  const auto ldu = compute_uv ? U.stride(-1) : 1;
+  const auto ldvh = compute_uv ? Vh.stride(-1) : 1;
+
+  auto iwork = std::vector<int>(8 * std::min(m, n));
+  auto* const iwork_data = iwork.data();
+
+  // rwork is just used for the complex decomposition
+  auto rwork = std::vector<value_t>{};
+  if (A.is_complex()) {
+    rwork.resize(std::max(computeLRWorkDim(jobz, m, n), int64_t{1}));
+  }
+  auto* const rwork_data = rwork.data();
+
+  // Query svd for the optimal lwork size
+  int lwork = -1;
+  {
+    scalar_t wkopt;
+    lapackSvd<scalar_t, value_t>(jobz, m, n, A_data, lda, S_data, U_data, ldu, Vh_data, ldvh, &wkopt, lwork, rwork_data, iwork_data, info_data);
+    lwork = std::max<int>(1, real_impl<scalar_t, value_t>(wkopt));
+  }
+  auto work = std::vector<scalar_t>(lwork);
+  auto* const work_data = work.data();
+
+  for (const auto i : c10::irange(batchsize)) {
+    auto* const A_working_ptr = &A_data[i * A_stride];
+    auto* const S_working_ptr = &S_data[i * S_stride];
+    auto* const U_working_ptr = compute_uv ? &U_data[i * U_stride] : nullptr;
+    auto* const Vh_working_ptr = compute_uv ? &Vh_data[i * Vh_stride] : nullptr;
+
+    // Compute S, U (optionally) and Vh (optionally)
+    lapackSvd<scalar_t, value_t>(jobz, m, n, A_working_ptr, lda,
+                                 S_working_ptr, U_working_ptr, ldu, Vh_working_ptr, ldvh, work_data, lwork, rwork_data, iwork_data, info_data + i);
+  }
+#endif
+}
+
+void svd_kernel(const Tensor& A,
+                const bool full_matrices,
+                const bool compute_uv,
+                const Tensor& U,
+                const Tensor& S,
+                const Tensor& Vh,
+                const Tensor& infos) {
+  // Need to copy A as column major, as its contents will be destroyed in the LAPACK call.
+  // FIXME It'd be more efficient, rather than cloning A, to copy it into `U` or `Vh` (depending on m > n
+  // or m < n) and call jobz='O'
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "linalg_svd_cpu", [&]{
+    apply_svd<scalar_t>(cloneBatchedColumnMajor(A), full_matrices, compute_uv, U, S, Vh, infos);
+  });
+}
+
 } // anonymous namespace
 
 REGISTER_ARCH_DISPATCH(cholesky_stub, DEFAULT, &cholesky_kernel);
@@ -1023,4 +1101,9 @@ REGISTER_AVX2_DISPATCH(lu_solve_stub, &lu_solve_kernel);
 REGISTER_VSX_DISPATCH(lu_solve_stub, &lu_solve_kernel);
 REGISTER_ZVECTOR_DISPATCH(lu_solve_stub, &lu_solve_kernel);
 
+REGISTER_ARCH_DISPATCH(svd_stub, DEFAULT, &svd_kernel);
+REGISTER_AVX512_DISPATCH(svd_stub, &svd_kernel);
+REGISTER_AVX2_DISPATCH(svd_stub, &svd_kernel);
+REGISTER_VSX_DISPATCH(svd_stub, &svd_kernel);
+REGISTER_ZVECTOR_DISPATCH(svd_stub, &svd_kernel);
 }} // namespace at::native
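For reference, the `jobz` flag above follows the usual LAPACK `?gesdd` convention ('A' for full matrices, 'S' for the thin decomposition, 'N' for singular values only), and the kernel uses the standard two-call pattern of querying the optimal workspace with `lwork = -1` before allocating it. A small sketch of how those modes surface in the Python API (not part of this diff):

```python
import torch

m, n = 5, 3
k = min(m, n)
A = torch.randn(m, n)

# full_matrices=True  -> jobz='A': square U (m, m) and Vh (n, n)
U, S, Vh = torch.linalg.svd(A, full_matrices=True)
assert U.shape == (m, m) and S.shape == (k,) and Vh.shape == (n, n)

# full_matrices=False -> jobz='S': only the leading k columns/rows are computed
U, S, Vh = torch.linalg.svd(A, full_matrices=False)
assert U.shape == (m, k) and Vh.shape == (k, n)

# compute_uv=False    -> jobz='N': singular values only, via linalg.svdvals
S = torch.linalg.svdvals(A)
assert S.shape == (k,)
```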

aten/src/ATen/native/LinearAlgebraUtils.h

Lines changed: 7 additions & 48 deletions
@@ -14,6 +14,7 @@
 #include <cstring>
 #include <cctype>
 
+
 namespace at { namespace native {
 
 // Used as an interface between the different BLAS-like libraries
@@ -248,7 +249,6 @@ void batch_iterator_with_broadcasting(const Tensor& a, const Tensor& b, const fu
   iter.serial_for_each(loop, {0, batchCount(b)});
 }
 
-
 // Returns the epsilon value for floating types except half
 static inline double _get_epsilon(const ScalarType& sc_type) {
   switch (sc_type) {
@@ -468,55 +468,14 @@ static inline std::tuple<std::vector<int64_t>,
   return std::make_tuple(q_sizes, q_strides, n_columns_q);
 }
 
-// Function to generate empty tensors of required size, strides and dtype for the SVD operation
-static inline std::tuple<Tensor, Tensor, Tensor> _create_U_S_VT(const Tensor& input, bool some, bool compute_uv,
-    const bool svd_use_cusolver=false) {
-
-  // U, S, VT are initialized as empty tensors.
-  // For CPU LAPACK and GPU MAGMA backend, the tensors are initialized on CPU.
-  // For GPU cuSOLVER backend, the tensors are initialized on GPU.
-  const auto usvt_device = svd_use_cusolver ? at::kCUDA : at::kCPU;
-
-  auto sizes = input.sizes().vec();
-  int64_t m = input.size(-2), n = input.size(-1);
-
-  sizes[input.dim() - 1] = some ? std::min(m, n) : m;
-  const auto u_strides = contiguous_strides(sizes, /*f-contig*/true);
-
-  // cuSOLVER's gesvdjBatched fails with illegal memory access and
-  // cuSOLVER's gesvdj fails with CUSOLVER_STATUS_EXECUTION_FAILED
-  // if matrices for U and VT are not allocated
-  // even though the result of computation is not used we need to allocate this memory
-
-  Tensor U_empty = (compute_uv || svd_use_cusolver)
-      ? at::empty_strided(sizes, u_strides, input.options().device(usvt_device))
-      : at::empty({0}, input.options().device(usvt_device));
-
-  // VT should be a column-major or a batch of column-major matrices
-  sizes[input.dim() - 2] = some ? std::min(m, n) : n;
-  sizes[input.dim() - 1] = n;
-  const auto vt_strides = contiguous_strides(sizes, /*f-contig*/!svd_use_cusolver);
-  Tensor VT_empty = (compute_uv || svd_use_cusolver)
-      ? at::empty_strided(sizes, vt_strides, input.options().device(usvt_device))
-      : at::empty({0}, input.options().device(usvt_device));
-
-  // U and VT might not get filled in this case
-  if (!some && compute_uv && input.numel() == 0) {
-    U_empty.zero_();
-    VT_empty.zero_();
-    // make U and VT an identity matrix, because they should be orthogonal
-    U_empty.diagonal(0, -2, -1).fill_(1);
-    VT_empty.diagonal(0, -2, -1).fill_(1);
-  }
-
-  sizes.pop_back();
-  sizes[input.dim() - 2] = std::min(m, n);
-  ScalarType dtype = toValueType(input.scalar_type());
-  Tensor S_empty = at::empty(sizes, input.options().dtype(dtype).device(usvt_device));
-
-  return std::tuple<Tensor, Tensor, Tensor>(U_empty, S_empty, VT_empty);
+static inline bool svd_uses_cusolver(const Tensor& A) {
+  // if cusolver is available, it is used unconditionally
+  return A.is_cuda()
+      && at::globalContext().hasCuSOLVER()
+      && at::globalContext().linalgPreferredBackend() != at::LinalgBackend::Magma;
 }
 
+
 // Function used instead of .to so that the original strides are retained
 // .to doesn't retain strides and make the output tensor contiguous
 static inline Tensor same_stride_to(const Tensor& original_tensor, const at::TensorOptions& options) {
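`svd_uses_cusolver` captures the selection rule mentioned in the summary: on CUDA, cuSOLVER is used whenever it is available unless MAGMA is explicitly preferred. A rough sketch of exercising both paths from Python; it assumes a CUDA build with cuSOLVER and that the `torch.backends.cuda.preferred_linalg_library` binding exists in your PyTorch version:

```python
import torch

if torch.cuda.is_available():
    A = torch.randn(64, 32, device="cuda")

    # Default: cuSOLVER is chosen over MAGMA whenever the build has cuSOLVER.
    S_cusolver = torch.linalg.svdvals(A)

    # Explicitly preferring MAGMA makes svd_uses_cusolver return false.
    # (Binding name assumed; availability depends on the PyTorch version.)
    torch.backends.cuda.preferred_linalg_library("magma")
    S_magma = torch.linalg.svdvals(A)
    torch.backends.cuda.preferred_linalg_library("default")

    # The two backends should agree up to floating-point differences.
    torch.testing.assert_close(S_cusolver, S_magma, rtol=1e-4, atol=1e-4)
```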

aten/src/ATen/native/NegateFallback.cpp

Lines changed: 2 additions & 0 deletions
@@ -35,6 +35,8 @@ TORCH_LIBRARY_IMPL(aten, Negative, m) {
   // linear algebra functions
   m.impl("linalg_solve_triangular", torch::CppFunction::makeFallthrough());
   m.impl("linalg_solve_triangular.out", torch::CppFunction::makeFallthrough());
+  m.impl("linalg_svd", torch::CppFunction::makeFallthrough());
+  m.impl("linalg_svd.U", torch::CppFunction::makeFallthrough());
 
   TORCH_VIEW_FNS(m)
   TENSOR_UTILITIES_AND_CONSTRUCTORS(m)
