
INTERNAL ASSERT FAILED at "nvfuser/csrc/device_lower/analysis/sync_information.cpp":812 Inconsistent parallelization found between TV48 and TV2 #3609

Open
crcrpar opened this issue Dec 18, 2024 · 0 comments

crcrpar commented Dec 18, 2024

This might be equivalent to #3498.

# CUDA devices:
#  0: NVIDIA H100 80GB HBM3
#  1: NVIDIA H100 80GB HBM3
#  2: NVIDIA H100 80GB HBM3
#  3: NVIDIA H100 80GB HBM3
#  4: NVIDIA H100 80GB HBM3
#  5: NVIDIA H100 80GB HBM3
#  6: NVIDIA H100 80GB HBM3
#  7: NVIDIA H100 80GB HBM3
# torch version: 2.6.0a0+git45ed7c1
# nvfuser version: 0.2.23+git21e3617
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id2(fd: FusionDefinition) -> None:
    T0 = fd.define_tensor(shape=[64], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T1 = fd.define_tensor(shape=[16, 64], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T5 = fd.ops.broadcast_in_dim(T0, shape=[1, 64], broadcast_dims=[1])
    T9 = fd.ops.broadcast_in_dim(T5, shape=[16, 64], broadcast_dims=[0, 1])
    T10 = fd.ops.cast(T1, dtype=DataType.Float)
    T11 = fd.ops.cast(T9, dtype=DataType.Float)
    T12 = fd.ops.add(T10, T11)
    T13 = fd.ops.cast(T12, dtype=DataType.BFloat16)
    S14 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T15 = fd.ops.gt(T13, S14)
    S16 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T17 = fd.ops.where(T15, T13, S16)
    T18 = fd.ops.cast(T17, dtype=DataType.Float)
    T19 = fd.ops.abs(T18)
    T20 = fd.ops.max(T19, dims=[0, 1], keepdim=False, dtype=DataType.Null)
    T21 = fd.ops.cast(T20, dtype=DataType.Double)
    T22 = fd.ops.ne(T21, T21)
    S23 = fd.define_scalar(1.00000e-12, dtype=DataType.Double)
    T24 = fd.ops.gt(T21, S23)
    S25 = fd.define_scalar(1.00000e-12, dtype=DataType.Double)
    T26 = fd.ops.where(T24, T21, S25)
    T27 = fd.ops.where(T22, T21, T26)
    S28 = fd.define_scalar(448.000, dtype=DataType.Double)
    T29 = fd.ops.reciprocal(T27)
    T30 = fd.ops.mul(S28, T29)
    T31 = fd.ops.cast(T30, dtype=DataType.Float)
    T35 = fd.ops.broadcast_in_dim(T31, shape=[16, 64], broadcast_dims=[])
    T36 = fd.ops.mul(T18, T35)
    T37 = fd.ops.ne(T36, T36)
    S38 = fd.define_scalar(-448.000, dtype=DataType.Double)
    T39 = fd.ops.gt(T36, S38)
    S40 = fd.define_scalar(-448.000, dtype=DataType.Double)
    T41 = fd.ops.where(T39, T36, S40)
    T42 = fd.ops.where(T37, T36, T41)
    T43 = fd.ops.ne(T42, T42)
    S44 = fd.define_scalar(448.000, dtype=DataType.Double)
    T45 = fd.ops.lt(T42, S44)
    S46 = fd.define_scalar(448.000, dtype=DataType.Double)
    T47 = fd.ops.where(T45, T42, S46)
    T48 = fd.ops.where(T43, T42, T47)
    T49 = fd.ops.cast(T48, dtype=DataType.Float8_e4m3fn)
    T53 = fd.ops.reshape(T49, new_shape=[16, 64])
    T54 = fd.ops.reciprocal(T31)
    fd.add_output(T15)
    fd.add_output(T31)
    fd.add_output(T49)
    fd.add_output(T53)
    fd.add_output(T54)

with FusionDefinition() as fd:
    nvfuser_fusion_id2(fd)

inputs = [
    torch.testing.make_tensor((64,), dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor((16, 64), dtype=torch.bfloat16, device='cuda:0'),
]
fd.execute(inputs)
Traceback (most recent call last):
  File "/opt/pytorch/nvfuser/nvfuser/__init__.py", line 319, in execute
    results = self._execute(
              ^^^^^^^^^^^^^^
RuntimeError:  INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/device_lower/analysis/sync_information.cpp":812, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Inconsistent parallelization found between TV48 (T48_l___bfloat[iS390{1}, iUS392{1}, ithreadIdx.x393{8}_p, iV389{8}]) and TV2(T2_l___bfloat[iS160{1}, iUS162{1}, ithreadIdx.x163{8}_p, iS159{8}] ca_pos( 4 )). Producer is required to be in Global or Shared Memory based on parallelization strategy. RAW flags: (threadIdx.x)
Exception raised from SyncMap at /opt/pytorch/nvfuser/csrc/device_lower/analysis/sync_information.cpp:812 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x103 (0x7f30da79e519 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x62 (0x7f30dabb8162 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #2: nvfuser::SyncMap::SyncMap(nvfuser::Fusion*) + 0x1909 (0x7f30daa4bf19 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0x40a636 (0x7f30daa79636 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #4: nvfuser::GpuLower::GpuLower(nvfuser::Fusion*, nvfuser::CompileParams const&) + 0x12bf (0x7f30daa7b6ff in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #5: nvfuser::KernelExecutor::compile(nvfuser::Fusion*, nvfuser::KernelArgumentHolder const&, nvfuser::LaunchParams const&, nvfuser::CompileParams, nvfuser::SchedulerType) + 0x83d (0x7f30daef7f3d in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #6: <unknown function> + 0x89118f (0x7f30daf0018f in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #7: <unknown function> + 0x8c8d74 (0x7f30daf37d74 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #8: nvfuser::FusionKernelRuntime::compileFusionParallel(nvfuser::KernelArgumentHolder) + 0x423 (0x7f30daf3b423 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #9: nvfuser::FusionExecutorCache::runFusionWithInputs(c10::ArrayRef<c10::IValue> const&, std::optional<nvfuser::PrimDataType>, std::optional<signed char>) + 0x1cb (0x7f30daf30c2b in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #10: nvfuser::python_frontend::FusionDefinition::execute(c10::ArrayRef<c10::IValue> const&, std::optional<signed char>, bool, bool, bool, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >) const + 0xbc4 (0x7f30db0eb7a4 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #11: <unknown function> + 0x1accd4 (0x7f30da81bcd4 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #12: <unknown function> + 0x281173 (0x7f30da8f0173 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #13: <unknown function> + 0x2e0835 (0x7f30da94f835 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #14: python() [0x5820ff]
<omitting python frames>
frame #17: python() [0x54cae4]
frame #21: python() [0x5a3698]
frame #24: python() [0x54ca1d]
frame #26: python() [0x54ca1d]
frame #29: python() [0x5a3698]
frame #33: python() [0x608b52]
frame #34: python() [0x6b4d83]
frame #39: <unknown function> + 0x2a1ca (0x7f44b2b011ca in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #40: __libc_start_main + 0x8b (0x7f44b2b0128b in /usr/lib/x86_64-linux-gnu/libc.so.6)
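
For orientation, the fusion appears to be torchao's amax-based float8_e4m3fn quantization of a bias-add + ReLU (448.0 is the float8_e4m3fn max). Below is a rough eager-mode reading of the ops above; it is a sketch reconstructed from the fusion definition for readability, not a verified numerical equivalent:

import torch

def fusion_in_eager(bias_bf16: torch.Tensor, x_bf16: torch.Tensor):
    # T10-T13: bias-add computed in fp32, cast back to bf16
    s = (x_bf16.float() + bias_bf16.float()).to(torch.bfloat16)
    mask = s > 0.0                                          # T15 (also a fusion output)
    y = torch.where(mask, s, torch.zeros_like(s)).float()   # T17/T18: ReLU
    amax = y.abs().max().double()                           # T19-T21: global amax
    amax = torch.where(amax.isnan(), amax, amax.clamp_min(1e-12))  # T22-T27
    scale = (448.0 / amax).float()                          # T28-T31
    clamped = (y * scale).clamp(-448.0, 448.0)              # T35-T48; clamp lets NaN through
    fp8 = clamped.to(torch.float8_e4m3fn)                   # T49 (T53 is a no-op reshape)
    return mask, scale, fp8, fp8.reshape(16, 64), scale.reciprocal()  # T54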

Traceback (most recent call last):
  File "/opt/pytorch/lightning-thunder/ao_fp8.py", line 42, in <module>
    main()
  File "/opt/pytorch/lightning-thunder/ao_fp8.py", line 25, in main
    actual = jitted(x)
             ^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/core/module.py", line 80, in forward
    res = self._forward_fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 748, in wrapped
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 785, in fn_
    result = cache_entry.computation_fn(*inps)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 712, in wrapped
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/executors/torchex.py", line 178, in no_autocast_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "thunder.computation_71", line 38, in computation
  File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 485, in __call__
    return fd.execute(args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/nvfuser/nvfuser/__init__.py", line 319, in execute
    results = self._execute(
              ^^^^^^^^^^^^^^
RuntimeError:  INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/device_lower/analysis/sync_information.cpp":812, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Inconsistent parallelization found between TV48 (T48_l___bfloat[iS390{1}, iUS392{1}, ithreadIdx.x393{8}_p, iV389{8}]) and TV2(T2_l___bfloat[iS160{1}, iUS162{1}, ithreadIdx.x163{8}_p, iS159{8}] ca_pos( 4 )). Producer is required to be in Global or Shared Memory based on parallelization strategy. RAW flags: (threadIdx.x)
Exception raised from SyncMap at /opt/pytorch/nvfuser/csrc/device_lower/analysis/sync_information.cpp:812 (stack frames identical to the traceback above).

How to reproduce:

import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training
import thunder
from thunder.tests.make_tensor import make_tensor


def main():
    batch_size, in_features, out_features = 16, 32, 64

    device = torch.device("cuda")
    dtype = torch.bfloat16
    bias = True

    model = nn.Sequential(
        nn.Linear(in_features, out_features, bias=bias),
        nn.ReLU(),
        nn.Linear(out_features, out_features, bias=bias),
    ).to(device=device, dtype=dtype)
    fp8_model = convert_to_float8_training(model)
    x = make_tensor((batch_size, in_features), device=device, dtype=dtype)

    jitted = thunder.jit(fp8_model, executors=[thunder.get_executor("torch"), thunder.get_executor("nvfuser")])
    actual = jitted(x)


if __name__ == "__main__":
    main()
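
As an untested isolation step (not part of the original report), one could run the same model with only the torch executor; if that run succeeds, the failure is confined to nvfuser's lowering of this fusion. The snippet below assumes it is placed inside main() after fp8_model and x are defined:

# Hypothetical sanity check: same model, torch executor only,
# to confirm the crash comes from the nvfuser path.
jitted_torch_only = thunder.jit(fp8_model, executors=[thunder.get_executor("torch")])
expected = jitted_torch_only(x)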