
Expose matrix multiplication operations with conjugate transposes of the inputs #51750


Closed
ngimel opened this issue Feb 4, 2021 · 4 comments
Labels
- enhancement (not as big of a feature, but technically not a bug; should be easy to fix)
- module: complex (related to complex number support in PyTorch)
- module: linear algebra (issues related to specialized linear algebra operations in PyTorch; includes matrix multiply / matmul)
- triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@ngimel
Collaborator

ngimel commented Feb 4, 2021

🚀 Feature

BLAS libraries that PyTorch uses can perform an implicit conjugate transpose of an argument (the 'h' option), but PyTorch has no bindings for it and does not expose a way to call it. This option could speed up the backward pass through matrix multiplication ops, because it would let us avoid materializing the conjugate.
It's not fully clear how best to expose this to the user, or whether it should be exposed at all, as opposed to being called internally in backward when necessary.
PyTorch sets the 't' argument of BLAS calls based on the strides of the input matrices. 'h' cannot be set independently: it can only be used when the physical memory layout corresponds to a transposed matrix, so the UX here is not obvious. This issue is to discuss what exposure we want.
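For concreteness, a minimal sketch of the BLAS-level flag using SciPy's bindings (SciPy is used here only because PyTorch does not expose the flag; in SciPy's zgemm, trans_b=2 requests the conjugate transpose):

```python
import numpy as np
from scipy.linalg.blas import zgemm

x = np.random.randn(3, 4) + 1j * np.random.randn(3, 4)
y = np.random.randn(5, 4) + 1j * np.random.randn(5, 4)

# What backward effectively does today: materialize conj(y) first.
explicit = x @ y.conj().T

# What the BLAS option allows: conjugate-transpose y inside the gemm,
# with no intermediate conjugated copy.
fused = zgemm(1.0, x, y, trans_b=2)

np.testing.assert_allclose(explicit, fused)
```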

Related: we have dot and vdot functions, where vdot implicitly conjugates its first argument.
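For example (an illustrative check; torch.dot does no conjugation, so the caller has to conjugate explicitly):

```python
import torch

a = torch.tensor([1 + 2j, 3 - 1j])
b = torch.tensor([2 - 1j, 1 + 4j])

# vdot applies the conjugate implicitly; dot needs an explicit conj() copy.
assert torch.allclose(torch.vdot(a, b), torch.dot(a.conj(), b))
```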

Also related: #45063, where some comments discuss the possibility of adding conjugate views.

cc @ezyang @anjali411 @dylanbespalko @mruberry @jianyuh @nikitaved @pearu @heitorschueroff @walterddr @IvanYashchuk

@ngimel added the module: linear algebra, triaged, enhancement, and module: complex labels on Feb 4, 2021
@ezyang
Contributor

ezyang commented Feb 5, 2021

If conjugate views happen, we don't have to add new functions, as matmul(x, y.H) will be just as efficient as calling the implicit conjugate (modulo the O(1) overhead of creating the y.H tensor). It might still be a good idea to add a few new functions for explicit conjugation anyway, though naming them well is a challenge.
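A sketch of what that would look like, assuming conjugate views land as discussed in #45063 (y.H here is the proposed conjugate-transpose view, not an attribute that exists at the time of writing):

```python
import torch

x = torch.randn(3, 4, dtype=torch.complex128)
y = torch.randn(5, 4, dtype=torch.complex128)

# Today: conj() materializes a conjugated copy of y before the gemm.
out = torch.matmul(x, y.conj().t())

# Proposed: y.H would be an O(1) view carrying a "conjugated" flag,
# letting matmul dispatch to the BLAS 'h' path with no copy.
# out = torch.matmul(x, y.H)
```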

@ngimel
Collaborator Author

ngimel commented Feb 5, 2021

> If conjugate views happen, we don't have to add new functions, as matmul(x, y.H) will be just as efficient as calling the implicit conjugate (modulo the O(1) overhead of creating the y.H tensor). It might still be a good idea to add a few new functions for explicit conjugation anyway, though naming them well is a challenge.

Conjugation is slightly weirder, though. When you call matmul(x, y.t()), no copy happens as long as the inputs are non-overlapping and the smallest stride is 1. For matmul(x, y.H) there are additional constraints on y's strides, because in BLAS conjugation is tied to the physical memory layout.
Step 0 is still to at least bind those BLAS calls and use them in backward in a non-user-facing way; that's needed irrespective of whether conjugate views happen or not.
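To illustrate the stride point (a sketch; the stride values assume freshly allocated tensors):

```python
import torch

x = torch.randn(3, 4, dtype=torch.complex128)

# Plain transpose: either stride order works, gemm just picks 'n' or 't'.
y = torch.randn(5, 4, dtype=torch.complex128)
torch.matmul(x, y.t())   # y.t() has strides (1, 4): no copy needed

yc = torch.randn(4, 5, dtype=torch.complex128).t()  # strides (1, 5)
torch.matmul(x, yc.t())  # yc.t() has strides (5, 1): still no copy

# Conjugate transpose: BLAS only offers conjugation fused with
# transposition ('h'), so a fused matmul(x, y.H) would be possible for
# a layout like y above, but a layout like yc would still force a
# conjugated copy.
```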

@ezyang
Contributor

ezyang commented Feb 8, 2021

Just to make sure I understand correctly: this is because there are BLAS kernels for row-major and column-major y, but for conjugated y it has to be column major?

@ngimel
Collaborator Author

ngimel commented Feb 8, 2021

Yep, with a minor correction that for conjugated y it has to be row major (because BLAS is column major: 'n' in BLAS means column major, while 't'/'h' means row major).
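A quick way to see the layout convention (an illustrative NumPy check, not PyTorch internals):

```python
import numpy as np

a = np.arange(6).reshape(2, 3)  # row major (C order)

# BLAS assumes column-major storage: reading a row-major buffer as
# column major yields the transpose, which is why a row-major operand
# is exactly what the 't' (and hence 'h') flag describes.
col_major_read = a.ravel().reshape(3, 2, order="F")
assert (col_major_read == a.T).all()
```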
