Skip to content

Commit 410e48f

Browse files
authored
Enable fp8 on sm89 (#3624)
fp8 support has been extended to sm89 as of PTX ISA 8.1; see https://docs.nvidia.com/cuda/parallel-thread-execution/
1 parent 516d590 commit 410e48f

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

csrc/device_lower/analysis/device_version.cpp

+16-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
* SPDX-License-Identifier: BSD-3-Clause
66
*/
77
// clang-format on
8+
#include <cuda.h>
9+
810
#include <device_lower/analysis/device_version.h>
911
#include <device_lower/lower2device.h>
1012
#include <mma_type.h>
@@ -19,9 +21,22 @@ void MinimumDeviceVersion::dispatch(Val* val) {
1921
}
2022
if (val->dtype() == DataType::Float8_e4m3fn ||
2123
val->dtype() == DataType::Float8_e5m2) {
24+
// See release note
25+
// https://docs.nvidia.com/cuda/archive/12.1.0/parallel-thread-execution/index.html#ptx-isa-version-8-1
26+
#if (CUDA_VERSION >= 12010)
2227
ensureVersion(
23-
{9, 0},
28+
{8, 9},
29+
"Fusion contains Float8_xxx values which was introduced in Ada (8.9)");
30+
// See release note
31+
// https://docs.nvidia.com/cuda/archive/11.8.0/parallel-thread-execution/index.html#ptx-isa-version-7-8
32+
#elif (CUDA_VERSION >= 11080)
33+
ensureVersion(
34+
{8, 9},
2435
"Fusion contains Float8_xxx values which was introduced in Hopper (9.0)");
36+
#else
37+
NVF_ERROR(
38+
"Fusion contains Float8_xxx values which was not supported in given CUDA version");
39+
#endif // (CUDA_VERSION >= 12010)
2540
}
2641
IterVisitor::dispatch(val);
2742
}

tests/cpp/test_gpu1.cpp

+7-3
Original file line numberDiff line numberDiff line change
@@ -2711,13 +2711,17 @@ TEST_F(NVFuserTest, FusionFp8CastOps_CUDA) {
27112711
std::vector<c10::IValue> inputs = {input1};
27122712

27132713
KernelExecutor ke;
2714-
2714+
#if (CUDA_VERSION >= 12010)
2715+
if (!deviceMajorMinorCheck(8, 9)) {
2716+
#elif (CUDA_VERSION >= 11080)
27152717
if (!deviceMajorMinorCheck(9)) {
2718+
#else
2719+
if (true) {
2720+
#endif
27162721
ASSERT_THAT(
27172722
[&]() { ke.compile(&fusion, inputs); },
27182723
testing::ThrowsMessage<nvfuser::nvfError>(testing::HasSubstr(
2719-
"Reason: Fusion contains Float8_xxx values which was introduced in Hopper (9.0)")));
2720-
GTEST_SKIP() << "skipping tests on pre-HOPPER GPUs";
2724+
"Reason: Fusion contains Float8_xxx values")));
27212725
} else {
27222726
ke.compile(&fusion, inputs);
27232727
auto outputs = ke.run(inputs);

0 commit comments

Comments
 (0)