@@ -19,17 +19,18 @@ def load(args):
19
19
# Check if cuda 11 is installed for compute capability 8.0
20
20
cc_flag = []
21
21
_ , bare_metal_major , bare_metal_minor = _get_cuda_bare_metal_version (
22
- cpp_extension .CUDA_HOME )
22
+ cpp_extension .CUDA_HOME
23
+ )
23
24
if int (bare_metal_major ) >= 11 :
24
- cc_flag .append (' -gencode' )
25
- cc_flag .append (' arch=compute_80,code=sm_80' )
25
+ cc_flag .append (" -gencode" )
26
+ cc_flag .append (" arch=compute_80,code=sm_80" )
26
27
if int (bare_metal_minor ) >= 7 :
27
- cc_flag .append (' -gencode' )
28
- cc_flag .append (' arch=compute_90,code=sm_90' )
28
+ cc_flag .append (" -gencode" )
29
+ cc_flag .append (" arch=compute_90,code=sm_90" )
29
30
30
31
# Build path
31
32
srcpath = pathlib .Path (__file__ ).parent .absolute ()
32
- buildpath = srcpath / ' build'
33
+ buildpath = srcpath / " build"
33
34
_create_build_dir (buildpath )
34
35
35
36
# Helper function to build the kernels.
@@ -38,46 +39,25 @@ def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
38
39
name = name ,
39
40
sources = sources ,
40
41
build_directory = buildpath ,
41
- extra_cflags = ['-O3' ,],
42
- extra_cuda_cflags = ['-O3' ,
43
- '-gencode' , 'arch=compute_70,code=sm_70' ,
44
- '--use_fast_math' ] + extra_cuda_flags + cc_flag ,
45
- verbose = (args .rank == 0 )
42
+ extra_cflags = [
43
+ "-O3" ,
44
+ ],
45
+ extra_cuda_cflags = [
46
+ "-O3" ,
47
+ "-gencode" ,
48
+ "arch=compute_70,code=sm_70" ,
49
+ "--use_fast_math" ,
50
+ ]
51
+ + extra_cuda_flags
52
+ + cc_flag ,
53
+ verbose = (args .rank == 0 ),
46
54
)
47
55
48
- # ==============
49
- # Fused softmax.
50
- # ==============
51
-
52
- if args .masked_softmax_fusion :
53
- extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__' ,
54
- '-U__CUDA_NO_HALF_CONVERSIONS__' ,
55
- '--expt-relaxed-constexpr' ,
56
- '--expt-extended-lambda' ]
57
-
58
- # Upper triangular softmax.
59
- sources = [srcpath / 'scaled_upper_triang_masked_softmax.cpp' ,
60
- srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu' ]
61
- scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper (
62
- "scaled_upper_triang_masked_softmax_cuda" ,
63
- sources , extra_cuda_flags )
64
-
65
- # Masked softmax.
66
- sources = [srcpath / 'scaled_masked_softmax.cpp' ,
67
- srcpath / 'scaled_masked_softmax_cuda.cu' ]
68
- scaled_masked_softmax_cuda = _cpp_extention_load_helper (
69
- "scaled_masked_softmax_cuda" , sources , extra_cuda_flags )
70
-
71
- # Softmax
72
- sources = [srcpath / 'scaled_softmax.cpp' ,
73
- srcpath / 'scaled_softmax_cuda.cu' ]
74
- scaled_softmax_cuda = _cpp_extention_load_helper (
75
- "scaled_softmax_cuda" , sources , extra_cuda_flags )
76
-
77
56
78
57
def _get_cuda_bare_metal_version (cuda_dir ):
79
- raw_output = subprocess .check_output ([cuda_dir + "/bin/nvcc" , "-V" ],
80
- universal_newlines = True )
58
+ raw_output = subprocess .check_output (
59
+ [cuda_dir + "/bin/nvcc" , "-V" ], universal_newlines = True
60
+ )
81
61
output = raw_output .split ()
82
62
release_idx = output .index ("release" ) + 1
83
63
release = output [release_idx ].split ("." )
0 commit comments