Skip to content

Commit 77b2b77

Browse files
committed
Fixup CUDA architecture defaults to make them into a list
1 parent c4e2be0 commit 77b2b77

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

src/aedifix/packages/cuda.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from __future__ import annotations
55

66
import os
7+
import re
78
import shutil
89
from argparse import Action, ArgumentParser, Namespace
910
from pathlib import Path
@@ -55,7 +56,7 @@ def map_cuda_arch_names(in_arch: str) -> list[str]:
5556
# TODO(jfaibussowit): rubin?
5657
}
5758
arch = []
58-
for sub_arch in in_arch.split(","):
59+
for sub_arch in re.split(r"[;,]", in_arch):
5960
# support Turing, TURING, and, if the user is feeling spicy, tUrInG
6061
sub_arch_lo = sub_arch.strip().casefold()
6162
if not sub_arch_lo:
@@ -131,7 +132,9 @@ class CUDA(Package):
131132
spec=ArgSpec(
132133
dest="cuda_arch",
133134
required=False,
134-
default=_guess_cuda_architecture(),
135+
default=CudaArchAction.map_cuda_arch_names(
136+
_guess_cuda_architecture()
137+
),
135138
action=CudaArchAction,
136139
help=(
137140
"Specify the target GPU architecture. Available choices are: "

tests/packages/test_cuda.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
("turing,hopper", ["75", "90"]),
1919
("volta,60,all-major", ["70", "60", "all-major"]),
2020
("60,,80", ["60", "80"]),
21+
("50-real;120-real;121", ["50-real", "120-real", "121"]),
2122
)
2223

2324

@@ -33,6 +34,7 @@ class TestCUDA:
3334
("env_var", "env_value", "expected"),
3435
[
3536
("CUDAARCHS", "volta", "volta"),
37+
("CUDAARCHS", "volta;ampere", "volta;ampere"),
3638
("CMAKE_CUDA_ARCHITECTURES", "75", "75"),
3739
("", "", "all-major"),
3840
],

0 commit comments

Comments
 (0)