Skip to content

Commit 07824fa

Browse files
lezcanocyyever
authored andcommitted
Add hascuSOLVER flag to Context (#69825)
Summary: Pull Request resolved: pytorch/pytorch#69825 As per title. cc ngimel jianyuh nikitaved pearu mruberry walterddr IvanYashchuk xwang233 Lezcano Test Plan: Imported from OSS Reviewed By: mikaylagawarecki, ngimel Differential Revision: D33751986 Pulled By: mruberry fbshipit-source-id: 8625c7246d627b5c3680d92d4e8afdd7efc7dd69 (cherry picked from commit 7ca16be)
1 parent fc2d48d commit 07824fa

File tree

4 files changed

+16
-0
lines changed

4 files changed

+16
-0
lines changed

aten/src/ATen/Context.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ class TORCH_API Context {
7474
static long versionCuDNN() {
7575
return detail::getCUDAHooks().versionCuDNN();
7676
}
77+
static bool hasCuSOLVER() {
78+
return detail::getCUDAHooks().hasCuSOLVER();
79+
}
7780
static bool hasHIP() {
7881
return detail::getHIPHooks().hasHIP();
7982
}

aten/src/ATen/cuda/detail/CUDAHooks.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,14 @@ bool CUDAHooks::hasCuDNN() const {
130130
return AT_CUDNN_ENABLED();
131131
}
132132

133+
bool CUDAHooks::hasCuSOLVER() const {
134+
#if defined(CUDART_VERSION) && defined(CUSOLVER_VERSION)
135+
return true;
136+
#else
137+
return false;
138+
#endif
139+
}
140+
133141
#if defined(USE_DIRECT_NVRTC)
134142
static std::pair<std::unique_ptr<at::DynamicLibrary>, at::cuda::NVRTC*> load_nvrtc() {
135143
return std::make_pair(nullptr, at::cuda::load_nvrtc());

aten/src/ATen/cuda/detail/CUDAHooks.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
2626
bool hasCUDA() const override;
2727
bool hasMAGMA() const override;
2828
bool hasCuDNN() const override;
29+
bool hasCuSOLVER() const override;
2930
const at::cuda::NVRTC& nvrtc() const override;
3031
int64_t current_device() const override;
3132
bool hasPrimaryContext(int64_t device_index) const override;

aten/src/ATen/detail/CUDAHooksInterface.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ struct TORCH_API CUDAHooksInterface {
102102
return false;
103103
}
104104

105+
virtual bool hasCuSOLVER() const {
106+
return false;
107+
}
108+
105109
virtual const at::cuda::NVRTC& nvrtc() const {
106110
TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP);
107111
}

0 commit comments

Comments
 (0)