Skip to content

Commit c91138b

Browse files
authored
Merge pull request #466 from amiller27/master
Look for nvcc in CUDA_HOME
2 parents daf9628 + 3294b15 commit c91138b

File tree

1 file changed

+58
-20
lines changed

1 file changed

+58
-20
lines changed

bindings/torch/setup.py

+58-20
Original file line numberDiff line numberDiff line change
@@ -80,26 +80,64 @@ def find_cl_path():
8080
cpp_standard = 14
8181

8282
# Get CUDA version and make sure the targeted compute capability is compatible
83-
if os.system("nvcc --version") == 0:
84-
nvcc_out = subprocess.check_output(["nvcc", "--version"]).decode()
85-
cuda_version = re.search(r"release (\S+),", nvcc_out)
86-
87-
if cuda_version:
88-
cuda_version = parse_version(cuda_version.group(1))
89-
print(f"Detected CUDA version {cuda_version}")
90-
if cuda_version >= parse_version("11.0"):
91-
cpp_standard = 17
92-
93-
supported_compute_capabilities = [
94-
cc for cc in compute_capabilities if cc >= min_supported_compute_capability(cuda_version) and cc <= max_supported_compute_capability(cuda_version)
95-
]
96-
97-
if not supported_compute_capabilities:
98-
supported_compute_capabilities = [max_supported_compute_capability(cuda_version)]
99-
100-
if supported_compute_capabilities != compute_capabilities:
101-
print(f"WARNING: Compute capabilities {compute_capabilities} are not all supported by the installed CUDA version {cuda_version}. Targeting {supported_compute_capabilities} instead.")
102-
compute_capabilities = supported_compute_capabilities
83+
def _maybe_find_nvcc():
84+
# Try PATH first
85+
maybe_nvcc = shutil.which("nvcc")
86+
87+
if maybe_nvcc is not None:
88+
return maybe_nvcc
89+
90+
# Then try CUDA_HOME from torch (cpp_extension.CUDA_HOME is undocumented, which is why we only use
91+
# it as a fallback)
92+
try:
93+
from torch.utils.cpp_extension import CUDA_HOME
94+
except ImportError:
95+
return None
96+
97+
if not CUDA_HOME:
98+
return None
99+
100+
return os.path.join(CUDA_HOME, "bin", "nvcc")
101+
102+
def _maybe_nvcc_version():
103+
maybe_nvcc = _maybe_find_nvcc()
104+
105+
if maybe_nvcc is None:
106+
return None
107+
108+
nvcc_version_result = subprocess.run(
109+
[maybe_nvcc, "--version"],
110+
text=True,
111+
check=False,
112+
stdout=subprocess.PIPE,
113+
)
114+
115+
if nvcc_version_result.returncode != 0:
116+
return None
117+
118+
cuda_version = re.search(r"release (\S+),", nvcc_version_result.stdout)
119+
120+
if not cuda_version:
121+
return None
122+
123+
return parse_version(cuda_version.group(1))
124+
125+
cuda_version = _maybe_nvcc_version()
126+
if cuda_version is not None:
127+
print(f"Detected CUDA version {cuda_version}")
128+
if cuda_version >= parse_version("11.0"):
129+
cpp_standard = 17
130+
131+
supported_compute_capabilities = [
132+
cc for cc in compute_capabilities if cc >= min_supported_compute_capability(cuda_version) and cc <= max_supported_compute_capability(cuda_version)
133+
]
134+
135+
if not supported_compute_capabilities:
136+
supported_compute_capabilities = [max_supported_compute_capability(cuda_version)]
137+
138+
if supported_compute_capabilities != compute_capabilities:
139+
print(f"WARNING: Compute capabilities {compute_capabilities} are not all supported by the installed CUDA version {cuda_version}. Targeting {supported_compute_capabilities} instead.")
140+
compute_capabilities = supported_compute_capabilities
103141

104142
min_compute_capability = min(compute_capabilities)
105143

0 commit comments

Comments
 (0)