Open
Labels: needs-triage, type: bug
Description
Summary
When the TVM Relax PyTorch frontend imports a program produced by torch.export.export, and the module registers a torch.sparse_csr_tensor as a buffer, from_exported_program aborts the process with an uncaught C++ c10::Error:
layout_impl is only implemented for TensorImpl subclasses.
torch.export.export itself succeeds. The crash happens during from_exported_program(ep).
Environment
OS: Linux x86_64
PyTorch: 2.9.0+cu128
TVM: 0.22.0
Python: 3.10.x
Minimal Reproduction
```python
import torch
import torch.nn as nn
import tvm
from tvm.relax.frontend.torch import from_exported_program

print("torch version:", torch.__version__)
print("tvm version:", getattr(tvm, "__version__", "unknown"))


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        crow_indices = torch.tensor([0, 1, 2], dtype=torch.int64)
        col_indices = torch.tensor([0, 1], dtype=torch.int64)
        values = torch.tensor([1.0, 1.0], dtype=torch.float32, requires_grad=True)
        csr_tensor = torch.sparse_csr_tensor(
            crow_indices, col_indices, values, dtype=torch.float32
        )
        # Register sparse CSR tensor as a buffer
        self.register_buffer("csr_tensor", csr_tensor)
        # Explicitly enable grad as well
        self.csr_tensor.requires_grad_(True)

    def forward(self, x):
        # Convert buffer to sparse CSR layout again
        csr2 = self.csr_tensor.to_sparse(layout=torch.sparse_csr)
        y = torch.matmul(csr2, x)
        return y.sum()


def GetInput():
    return torch.ones((2, 1), dtype=torch.float32)


def main():
    model = MyModel().to("cpu").eval()
    x = GetInput().to("cpu")
    print("Start torch.export.export ...")
    ep = torch.export.export(model, (x,))
    print("torch.export.export done.")
    print("Start from_exported_program ...")
    ir_mod = from_exported_program(ep)
    print("from_exported_program done.")
    print(ir_mod)


if __name__ == "__main__":
    main()
```

Actual Behavior
Output on my machine:
```
torch version: 2.9.0+cu128
tvm version: 0.22.0
.../SparseCsrTensorImpl.cpp:53: Sparse CSR tensor support is in beta state...
Start torch.export.export ...
torch.export.export done.
Start from_exported_program ...
terminate called after throwing an instance of 'c10::Error'
  what():  layout_impl is only implemented for TensorImpl subclasses.
Exception raised from layout_impl at /pytorch/c10/core/TensorImpl.h:1094 (most recent call first):
frame #0: c10::Error::Error(...)
frame #1: c10::detail::torchCheckFail(...)
frame #2: ...
frame #3: torch::autograd::InputMetadata::InputMetadata(at::Tensor const&) + ...
frame #4: ...
<libtorch / libtorch_python frames omitted>
...
Aborted (core dumped)
```
torch.export.export completes successfully; the abort occurs only when calling from_exported_program(ep).
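Because the failure is a C++ `std::terminate` rather than a Python exception, wrapping `from_exported_program(ep)` in `try`/`except` cannot recover from it. A minimal pre-flight sketch, assuming the sparse layout of the buffer is what triggers the abort (the reproduction points that way, but this has not been confirmed against TVM's source), is to scan the module's parameters and buffers for non-strided layouts before exporting. The helper `find_non_strided_tensors` and the demo module `Tiny` are hypothetical illustrations, not part of TVM or PyTorch.

```python
import torch
import torch.nn as nn


def find_non_strided_tensors(module: nn.Module) -> list:
    """Return names of parameters/buffers whose layout is not torch.strided.

    Hypothetical pre-flight helper (not a TVM or PyTorch API). Sparse layouts
    such as torch.sparse_csr are assumed to be what triggers the abort.
    """
    names = []
    for name, tensor in module.named_parameters():
        if tensor.layout != torch.strided:
            names.append(name)
    for name, tensor in module.named_buffers():
        if tensor is not None and tensor.layout != torch.strided:
            names.append(name)
    return names


class Tiny(nn.Module):
    """Hypothetical demo module: one sparse CSR buffer and one dense buffer."""

    def __init__(self):
        super().__init__()
        crow = torch.tensor([0, 1, 2])
        col = torch.tensor([0, 1])
        vals = torch.tensor([1.0, 1.0])
        self.register_buffer("csr_tensor", torch.sparse_csr_tensor(crow, col, vals))
        self.register_buffer("dense", torch.ones(2))


if __name__ == "__main__":
    offending = find_non_strided_tensors(Tiny())
    if offending:
        # Fail with a catchable Python error instead of a process abort.
        print("non-strided tensors found:", offending)
```

Applied to `MyModel` from the reproduction, the helper reports `["csr_tensor"]`, so a script could raise an ordinary Python error before reaching `from_exported_program`.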
Triage
Please refer to the list of label tags here to find the relevant tags and add them below in a bullet format (example below).
- needs-triage
- bug