Skip to content

Commit 02c6229

Browse files
committed
Merge branch 'jbarker/apex_softmax' into 'main'
Switch custom fused softmax kernels to apex See merge request ADLR/megatron-lm!570
2 parents 040eac9 + cd12636 commit 02c6229

20 files changed

+158
-1838
lines changed

megatron/fused_kernels/__init__.py

+22-42
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,18 @@ def load(args):
1919
# Check if cuda 11 is installed for compute capability 8.0
2020
cc_flag = []
2121
_, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(
22-
cpp_extension.CUDA_HOME)
22+
cpp_extension.CUDA_HOME
23+
)
2324
if int(bare_metal_major) >= 11:
24-
cc_flag.append('-gencode')
25-
cc_flag.append('arch=compute_80,code=sm_80')
25+
cc_flag.append("-gencode")
26+
cc_flag.append("arch=compute_80,code=sm_80")
2627
if int(bare_metal_minor) >= 7:
27-
cc_flag.append('-gencode')
28-
cc_flag.append('arch=compute_90,code=sm_90')
28+
cc_flag.append("-gencode")
29+
cc_flag.append("arch=compute_90,code=sm_90")
2930

3031
# Build path
3132
srcpath = pathlib.Path(__file__).parent.absolute()
32-
buildpath = srcpath / 'build'
33+
buildpath = srcpath / "build"
3334
_create_build_dir(buildpath)
3435

3536
# Helper function to build the kernels.
@@ -38,46 +39,25 @@ def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
3839
name=name,
3940
sources=sources,
4041
build_directory=buildpath,
41-
extra_cflags=['-O3',],
42-
extra_cuda_cflags=['-O3',
43-
'-gencode', 'arch=compute_70,code=sm_70',
44-
'--use_fast_math'] + extra_cuda_flags + cc_flag,
45-
verbose=(args.rank == 0)
42+
extra_cflags=[
43+
"-O3",
44+
],
45+
extra_cuda_cflags=[
46+
"-O3",
47+
"-gencode",
48+
"arch=compute_70,code=sm_70",
49+
"--use_fast_math",
50+
]
51+
+ extra_cuda_flags
52+
+ cc_flag,
53+
verbose=(args.rank == 0),
4654
)
4755

48-
# ==============
49-
# Fused softmax.
50-
# ==============
51-
52-
if args.masked_softmax_fusion:
53-
extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
54-
'-U__CUDA_NO_HALF_CONVERSIONS__',
55-
'--expt-relaxed-constexpr',
56-
'--expt-extended-lambda']
57-
58-
# Upper triangular softmax.
59-
sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
60-
srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu']
61-
scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper(
62-
"scaled_upper_triang_masked_softmax_cuda",
63-
sources, extra_cuda_flags)
64-
65-
# Masked softmax.
66-
sources=[srcpath / 'scaled_masked_softmax.cpp',
67-
srcpath / 'scaled_masked_softmax_cuda.cu']
68-
scaled_masked_softmax_cuda = _cpp_extention_load_helper(
69-
"scaled_masked_softmax_cuda", sources, extra_cuda_flags)
70-
71-
# Softmax
72-
sources=[srcpath / 'scaled_softmax.cpp',
73-
srcpath / 'scaled_softmax_cuda.cu']
74-
scaled_softmax_cuda = _cpp_extention_load_helper(
75-
"scaled_softmax_cuda", sources, extra_cuda_flags)
76-
7756

7857
def _get_cuda_bare_metal_version(cuda_dir):
79-
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
80-
universal_newlines=True)
58+
raw_output = subprocess.check_output(
59+
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
60+
)
8161
output = raw_output.split()
8262
release_idx = output.index("release") + 1
8363
release = output[release_idx].split(".")

megatron/fused_kernels/scaled_masked_softmax.cpp

-83
This file was deleted.

0 commit comments

Comments
 (0)