
[Bug] relax.frontend.torch.from_exported_program aborts on sparse CSR buffer (layout_impl is only implemented for TensorImpl subclasses) #18648


Description

@tinywisdom

Summary

When the TVM Relax PyTorch frontend is used to import an ExportedProgram (produced by torch.export.export) whose module registers a torch.sparse_csr_tensor as a buffer, from_exported_program aborts with a C++ c10::Error:

layout_impl is only implemented for TensorImpl subclasses.

torch.export.export itself succeeds. The crash happens during from_exported_program(ep).


Environment

OS: Linux x86_64
PyTorch: 2.9.0+cu128
TVM: 0.22.0
Python: 3.10.x

Minimal Reproduction

import torch
import torch.nn as nn

import tvm
from tvm.relax.frontend.torch import from_exported_program

print("torch version:", torch.__version__)
print("tvm version:", getattr(tvm, "__version__", "unknown"))


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        crow_indices = torch.tensor([0, 1, 2], dtype=torch.int64)
        col_indices = torch.tensor([0, 1], dtype=torch.int64)
        values = torch.tensor([1.0, 1.0], dtype=torch.float32, requires_grad=True)

        csr_tensor = torch.sparse_csr_tensor(
            crow_indices, col_indices, values, dtype=torch.float32
        )
        # Register sparse CSR tensor as a buffer
        self.register_buffer("csr_tensor", csr_tensor)
        # Explicitly enable grad as well
        self.csr_tensor.requires_grad_(True)

    def forward(self, x):
        # Convert buffer to sparse CSR layout again
        csr2 = self.csr_tensor.to_sparse(layout=torch.sparse_csr)
        y = torch.matmul(csr2, x)
        return y.sum()


def GetInput():
    return torch.ones((2, 1), dtype=torch.float32)


def main():
    model = MyModel().to("cpu").eval()
    x = GetInput().to("cpu")

    print("Start torch.export.export ...")
    ep = torch.export.export(model, (x,))
    print("torch.export.export done.")

    print("Start from_exported_program ...")
    ir_mod = from_exported_program(ep)
    print("from_exported_program done.")
    print(ir_mod)


if __name__ == "__main__":
    main()
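
For contrast, here is a dense control variant of the same model (a sketch only, not part of the failing repro): it replaces the CSR buffer with its dense 2x2 equivalent and avoids sparse tensors entirely, so running it through the same export/import sequence isolates the sparse CSR buffer registration as the trigger of the abort.

# Control case (sketch, not part of the failing repro): the same matmul-and-sum
# model, but the buffer is the dense equivalent of the CSR tensor above.
class DenseControlModel(nn.Module):
    def __init__(self):
        super().__init__()
        # torch.eye(2) is the dense form of the 2x2 CSR identity used above
        self.register_buffer("dense_tensor", torch.eye(2, dtype=torch.float32))

    def forward(self, x):
        return torch.matmul(self.dense_tensor, x).sum()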

Actual Behavior

Output on my machine:

torch version: 2.9.0+cu128
tvm version: 0.22.0
.../SparseCsrTensorImpl.cpp:53: Sparse CSR tensor support is in beta state...
Start torch.export.export ...
torch.export.export done.
Start from_exported_program ...
terminate called after throwing an instance of 'c10::Error'
  what():  layout_impl is only implemented for TensorImpl subclasses.
Exception raised from layout_impl at /pytorch/c10/core/TensorImpl.h:1094 (most recent call first):
frame #0: c10::Error::Error(...)
frame #1: c10::detail::torchCheckFail(...)
frame #2: ...
frame #3: torch::autograd::InputMetadata::InputMetadata(at::Tensor const&) + ...
frame #4: ...
<libtorch / libtorch_python frames omitted>
...
Aborted (core dumped)

torch.export.export completes successfully; the abort occurs only when calling from_exported_program(ep).
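
As a quick sanity check that the sparse CSR buffer actually survives torch.export.export, one can inspect the layouts in ep.state_dict right before calling the frontend (a small diagnostic sketch, using the names from the repro above):

# Diagnostic sketch: list the tensors the Relax frontend will receive.
# Uses `ep` from the repro above; run this before from_exported_program(ep).
for name, tensor in ep.state_dict.items():
    print(name, tensor.layout, tuple(tensor.shape), tensor.requires_grad)
# The csr_tensor buffer should be listed with layout torch.sparse_csr.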

Triage

Please refer to the list of label tags here to find the relevant tags and add them below in a bullet format (example below).

  • needs-triage
  • bug
