@@ -80,26 +80,64 @@ def find_cl_path():
80
80
cpp_standard = 14
81
81
82
82
# Get CUDA version and make sure the targeted compute capability is compatible
83
- if os .system ("nvcc --version" ) == 0 :
84
- nvcc_out = subprocess .check_output (["nvcc" , "--version" ]).decode ()
85
- cuda_version = re .search (r"release (\S+)," , nvcc_out )
86
-
87
- if cuda_version :
88
- cuda_version = parse_version (cuda_version .group (1 ))
89
- print (f"Detected CUDA version { cuda_version } " )
90
- if cuda_version >= parse_version ("11.0" ):
91
- cpp_standard = 17
92
-
93
- supported_compute_capabilities = [
94
- cc for cc in compute_capabilities if cc >= min_supported_compute_capability (cuda_version ) and cc <= max_supported_compute_capability (cuda_version )
95
- ]
96
-
97
- if not supported_compute_capabilities :
98
- supported_compute_capabilities = [max_supported_compute_capability (cuda_version )]
99
-
100
- if supported_compute_capabilities != compute_capabilities :
101
- print (f"WARNING: Compute capabilities { compute_capabilities } are not all supported by the installed CUDA version { cuda_version } . Targeting { supported_compute_capabilities } instead." )
102
- compute_capabilities = supported_compute_capabilities
83
+ def _maybe_find_nvcc ():
84
+ # Try PATH first
85
+ maybe_nvcc = shutil .which ("nvcc" )
86
+
87
+ if maybe_nvcc is not None :
88
+ return maybe_nvcc
89
+
90
+ # Then try CUDA_HOME from torch (cpp_extension.CUDA_HOME is undocumented, which is why we only use
91
+ # it as a fallback)
92
+ try :
93
+ from torch .utils .cpp_extension import CUDA_HOME
94
+ except ImportError :
95
+ return None
96
+
97
+ if not CUDA_HOME :
98
+ return None
99
+
100
+ return os .path .join (CUDA_HOME , "bin" , "nvcc" )
101
+
102
+ def _maybe_nvcc_version ():
103
+ maybe_nvcc = _maybe_find_nvcc ()
104
+
105
+ if maybe_nvcc is None :
106
+ return None
107
+
108
+ nvcc_version_result = subprocess .run (
109
+ [maybe_nvcc , "--version" ],
110
+ text = True ,
111
+ check = False ,
112
+ stdout = subprocess .PIPE ,
113
+ )
114
+
115
+ if nvcc_version_result .returncode != 0 :
116
+ return None
117
+
118
+ cuda_version = re .search (r"release (\S+)," , nvcc_version_result .stdout )
119
+
120
+ if not cuda_version :
121
+ return None
122
+
123
+ return parse_version (cuda_version .group (1 ))
124
+
125
+ cuda_version = _maybe_nvcc_version ()
126
+ if cuda_version is not None :
127
+ print (f"Detected CUDA version { cuda_version } " )
128
+ if cuda_version >= parse_version ("11.0" ):
129
+ cpp_standard = 17
130
+
131
+ supported_compute_capabilities = [
132
+ cc for cc in compute_capabilities if cc >= min_supported_compute_capability (cuda_version ) and cc <= max_supported_compute_capability (cuda_version )
133
+ ]
134
+
135
+ if not supported_compute_capabilities :
136
+ supported_compute_capabilities = [max_supported_compute_capability (cuda_version )]
137
+
138
+ if supported_compute_capabilities != compute_capabilities :
139
+ print (f"WARNING: Compute capabilities { compute_capabilities } are not all supported by the installed CUDA version { cuda_version } . Targeting { supported_compute_capabilities } instead." )
140
+ compute_capabilities = supported_compute_capabilities
103
141
104
142
min_compute_capability = min (compute_capabilities )
105
143
0 commit comments