File tree 2 files changed +23
-4
lines changed
csrc/device_lower/analysis
2 files changed +23
-4
lines changed Original file line number Diff line number Diff line change 5
5
* SPDX-License-Identifier: BSD-3-Clause
6
6
*/
7
7
// clang-format on
8
+ #include < cuda.h>
9
+
8
10
#include < device_lower/analysis/device_version.h>
9
11
#include < device_lower/lower2device.h>
10
12
#include < mma_type.h>
@@ -19,9 +21,22 @@ void MinimumDeviceVersion::dispatch(Val* val) {
19
21
}
20
22
if (val->dtype () == DataType::Float8_e4m3fn ||
21
23
val->dtype () == DataType::Float8_e5m2) {
24
+ // See release note
25
+ // https://docs.nvidia.com/cuda/archive/12.1.0/parallel-thread-execution/index.html#ptx-isa-version-8-1
26
+ #if (CUDA_VERSION >= 12010)
22
27
ensureVersion (
23
- {9 , 0 },
28
+ {8 , 9 },
29
+ " Fusion contains Float8_xxx values which was introduced in Ada (8.9)" );
30
+ // See release note
31
+ // https://docs.nvidia.com/cuda/archive/11.8.0/parallel-thread-execution/index.html#ptx-isa-version-7-8
32
+ #elif (CUDA_VERSION >= 11080)
33
+ ensureVersion (
34
+ {8 , 9 },
24
35
" Fusion contains Float8_xxx values which was introduced in Hopper (9.0)" );
36
+ #else
37
+ NVF_ERROR (
38
+ " Fusion contains Float8_xxx values which was not supported in given CUDA version" );
39
+ #endif // (CUDA_VERSION >= 12010)
25
40
}
26
41
IterVisitor::dispatch (val);
27
42
}
Original file line number Diff line number Diff line change @@ -2711,13 +2711,17 @@ TEST_F(NVFuserTest, FusionFp8CastOps_CUDA) {
2711
2711
std::vector<c10::IValue> inputs = {input1};
2712
2712
2713
2713
KernelExecutor ke;
2714
-
2714
+ #if (CUDA_VERSION >= 12010)
2715
+ if (!deviceMajorMinorCheck (8 , 9 )) {
2716
+ #elif (CUDA_VERSION >= 11080)
2715
2717
if (!deviceMajorMinorCheck (9 )) {
2718
+ #else
2719
+ if (true ) {
2720
+ #endif
2716
2721
ASSERT_THAT (
2717
2722
[&]() { ke.compile (&fusion, inputs); },
2718
2723
testing::ThrowsMessage<nvfuser::nvfError>(testing::HasSubstr (
2719
- " Reason: Fusion contains Float8_xxx values which was introduced in Hopper (9.0)" )));
2720
- GTEST_SKIP () << " skipping tests on pre-HOPPER GPUs" ;
2724
+ " Reason: Fusion contains Float8_xxx values" )));
2721
2725
} else {
2722
2726
ke.compile (&fusion, inputs);
2723
2727
auto outputs = ke.run (inputs);
You can’t perform that action at this time.
0 commit comments