
Sharded Llama missing cache update in exported MLIR #271

Open
sogartar opened this issue Oct 11, 2024 · 3 comments
@sogartar
Contributor

The sharded variant of the Llama model exported from this test is missing the IR that updates the paged cache.
Note that the equivalent unsharded model does not exhibit this problem.

My hypothesis is that the update was erroneously removed by dead code elimination, because we do not generate the cache update as a truly in-place operation. It is probably an interaction with flow.tensor.transfer, which introduces a new tensor value. When not exporting, we just return the same tensor, so the in-place semantics work fine.

Here are both the sharded and unsharded exported MLIR programs: programs-mlir.zip.
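
To illustrate the aliasing point with plain Python (a minimal sketch with hypothetical helper names, not the actual sharktank code): when the "transfer" returns the same tensor object, the in-place write mutates the caller's cache and is observable; when it returns a fresh tensor value that is never returned from the function, the write only hits a temporary, so a dead-code-elimination pass is free to drop it.

import torch

def transfer_eager(t: torch.Tensor) -> torch.Tensor:
    # Eager / non-exported path: no real transfer, the same tensor object
    # comes back, so the in-place update below aliases the caller's cache.
    return t

def transfer_exported(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for flow.tensor.transfer: a *new* tensor value is introduced,
    # breaking the aliasing with the caller's cache.
    return t.clone()

def update_cache(cache: torch.Tensor, new_row: torch.Tensor, pos: int, transfer) -> None:
    page = transfer(cache)
    # In-place write. If `page` aliases `cache`, the caller sees the update.
    # If `page` is a fresh value that is never returned, the write is dead
    # from the compiler's point of view and can be eliminated.
    page[pos].copy_(new_row)

cache = torch.zeros(4, 8)
row = torch.ones(8)

update_cache(cache, row, 0, transfer_eager)
print(cache[0].sum())  # 8.0 -- update visible through aliasing

cache.zero_()
update_cache(cache, row, 0, transfer_exported)
print(cache[0].sum())  # 0.0 -- update landed on a temporary and is lost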

@sogartar
Copy link
Contributor Author

sogartar commented Oct 16, 2024

Here is a PR that helps address this. It uses the approach of inserting an in-place device-placement torch FX op that does not get materialized as an op, but instead sets the affinity for the corresponding function argument. This seems a bit brittle.
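
For illustration only, a rough sketch (hypothetical op name, not the code from the PR) of how an identity-like placement-tagging op could be declared with torch.library.custom_op (PyTorch 2.4+); the importer would have to recognize the call, record the affinity for the traced argument, and drop the op rather than materialize it.

import torch

# Hypothetical namespace and op name, for illustration only.
@torch.library.custom_op("example::tag_device_placement", mutates_args=())
def tag_device_placement(t: torch.Tensor, ordinal: int) -> torch.Tensor:
    # Functionally an identity; only the `ordinal` metadata matters to the importer.
    return t.clone()

@tag_device_placement.register_fake
def _(t: torch.Tensor, ordinal: int) -> torch.Tensor:
    # Fake (meta) implementation so torch.export can trace through the op.
    return torch.empty_like(t)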

There is another approach, where we insert attributes into the Torch FX graph after it has been generated.
E.g. (thanks to @iamakanshab)

import torch
from torch.fx import Graph, Node

def create_fx_function_with_custom_attributes():
    graph = Graph()
    
    # Create input Node
    input_node = graph.placeholder('x', type_expr=torch.Tensor)
    
    # Set custom attributes on the input node
    input_node.custom_attr = "input_custom_value"
    
    # Create a parameter Node
    param_node = graph.placeholder('weight', type_expr=torch.Tensor)
    param_node.custom_attr = "weight_custom_value"
    
    # Create an operation Node
    output_node = graph.call_function(torch.matmul, args=(input_node, param_node))
    
    # Set the output
    graph.output(output_node)
    
    # Create a GraphModule from this Graph
    module = torch.fx.GraphModule(torch.nn.Module(), graph)
    
    return module

# Create the FX function
fx_module = create_fx_function_with_custom_attributes()

# Print the generated code
print(fx_module.code)

# Access custom attributes
for node in fx_module.graph.nodes:
    if hasattr(node, 'custom_attr'):
        print(f"Node: {node.name}, Custom Attribute: {node.custom_attr}")

Then we convert these attributes to MLIR function argument attributes after the main conversion has occurred.
This approach would require adding device placement information to our custom tensor type and then, in the pytree node registration, adding this data to the corresponding argument context.
This context would then be available in torch.export.exported_program.ExportedProgram.call_spec. From there we copy the attribute to the torch.fx.graph.Graph, as shown in the example above.
This approach would not work when passing a plain torch.Tensor as an argument, since we cannot influence its context in the call_spec. That is probably OK, as we control the input.
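
A minimal sketch of the pytree part, assuming a hypothetical wrapper type that carries a device ordinal (the real sharktank tensor types differ); the placement travels in the pytree context, where it could later be recovered and turned into an argument attribute:

import torch
from torch.utils._pytree import register_pytree_node, tree_flatten

class PlacedTensor:
    """Hypothetical wrapper: a tensor plus the logical device it should live on."""
    def __init__(self, data: torch.Tensor, device_ordinal: int):
        self.data = data
        self.device_ordinal = device_ordinal

def _flatten(t: PlacedTensor):
    # Children are the real tensors; the context carries the placement metadata.
    return [t.data], {"device_ordinal": t.device_ordinal}

def _unflatten(children, context):
    return PlacedTensor(list(children)[0], context["device_ordinal"])

register_pytree_node(
    PlacedTensor,
    _flatten,
    _unflatten,
    serialized_type_name="example.PlacedTensor",
)

# The context dict is now part of the tree spec for any input containing a
# PlacedTensor, which is the information the idea above would read back out
# of the exported program's call spec.
leaves, spec = tree_flatten(PlacedTensor(torch.zeros(2), 1))
print(spec)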

@stellaraccident, what other approach is there?

@stellaraccident
Contributor

I think that there are some disconnected pieces to make this work. Can we confirm that we really do need these function annotations? One thing is that those DefaultPrimitiveTensor types are not real torch Tensors, so I'm not sure how that connects as stated.

This might need to be done in our call to import, passing device placements for some arguments, rather than going through the export goo.

@sogartar
Contributor Author

I think that there are some disconnected pieces to make this work. Can we confirm that we really do need these function annotations?

If we don't set the iree.abi.affinity attribute for each argument, they should get the default device placement. I think that is @benvanik's intention.
