The Segmenter Logic Latency increases 10x between 20 and 30 Fusions #3510

Open
kevinstephano opened this issue Dec 2, 2024 · 0 comments
Labels: Host Latency, Segmentation (Issues related to nvFuser Segmentation), Thunder

Thunder-based repro. Set NVFUSER_DUMP=python_definition to print nvFuser's FusionDefinition; it is not pasted here because it is large.

import torch
import thunder

class MySimpleModel(torch.nn.Module):
    def __init__(self, n_layers=10):
        super().__init__()
        self.fcs = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(n_layers)])

    def forward(self, x):
        for fc in self.fcs:
            x = torch.nn.functional.relu(fc(x))
        
        return x

def get_model_and_args():
    device = 'cuda'
    model = MySimpleModel(n_layers=30).to(device)
    args = (torch.randn(128, 16, device=device),)
    kwargs = {}
    return model, args, kwargs

model, args, kwargs = get_model_and_args()

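# Time the first call of the `thunder.jit` model; this includes nvFuser
# segmentation and NVRTC compilation.
import time

jfun = thunder.jit(model, nv_enable_linear=True)
st = time.perf_counter()
expected = jfun(*args, **kwargs)
torch.cuda.synchronize()  # wait for GPU work to finish before stopping the timer
print("time:", time.perf_counter() - st)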

There are a couple of issues here:

1. Segmenter time grows from 300 ms at 20 layers to 300s at 30 layers, for reasons not yet understood. [The Largest Issue]
2. NVRTC compilation runs more than once for the same activation kernel.
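
To see where the latency jumps, the first-call timing can be repeated across several layer counts. The following is a minimal sketch, not part of the original repro: `time_first_call`, `growth_ratios`, and `fake_build` are hypothetical helper names, and `fake_build` is a pure-Python stand-in workload so the script runs without a GPU.

```python
import time


def time_first_call(build, sizes):
    """Time the first call of build(n) for each n in sizes.

    `build` is a hypothetical factory: it takes a layer count and returns a
    zero-argument callable whose first call does the expensive work.
    Construction itself is left outside the timed region.
    """
    timings = {}
    for n in sizes:
        fn = build(n)
        start = time.perf_counter()
        fn()
        timings[n] = time.perf_counter() - start
    return timings


def growth_ratios(timings):
    """Ratio of each timing to the one at the next smaller size."""
    sizes = sorted(timings)
    return {
        b: timings[b] / timings[a]
        for a, b in zip(sizes, sizes[1:])
        if timings[a] > 0
    }


def fake_build(n):
    # Stand-in workload (an assumption, not nvFuser): work quadratic in n.
    return lambda: sum(i * j for i in range(n * 10) for j in range(n * 10))


if __name__ == "__main__":
    timings = time_first_call(fake_build, [10, 20, 30])
    for n, seconds in timings.items():
        print(f"n_layers={n}: {seconds:.4f}s")
    print("growth:", growth_ratios(timings))
```

For the Thunder case, `build` could construct `MySimpleModel(n_layers=n)`, wrap it with `thunder.jit(model, nv_enable_linear=True)`, and return a callable that runs it on the input tensor. A ratio far larger than the increase in layer count would point at the step where the cost stops scaling smoothly.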
