Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ rst/schema
rst/apply
rst/tracing
rst/compile
rst/agents
```
25 changes: 25 additions & 0 deletions docs/api/rst/agents.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# flashinfer_bench.agents

`flashinfer_bench.agents` provides tools for LLM-based kernel development and debugging.
This module provides:

1. **Profiling Tools**: Run NVIDIA Nsight Compute, Compute Sanitizer, etc. on solutions
2. **Schema Generation**: Generate JSON Schema from function signatures for tool calling
3. **FFI Prompts**: Provide context about the FlashInfer Bench API for LLM agents

```{eval-rst}
.. currentmodule:: flashinfer_bench.agents
.. autofunction:: flashinfer_bench_run_ncu
.. autofunction:: flashinfer_bench_list_ncu_options
.. autofunction:: function_to_schema
.. autofunction:: get_all_tool_schemas
.. automodule:: flashinfer_bench.agents.ffi_prompt
:members:
:no-value:
```
15 changes: 13 additions & 2 deletions flashinfer_bench/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
from flashinfer_bench.agents.ffi_prompt import FFI_PROMPT, FFI_PROMPT_SIMPLE
"""Agent tools for LLM-based kernel development and debugging."""

__all__ = ["FFI_PROMPT_SIMPLE", "FFI_PROMPT"]
from .ffi_prompt import FFI_PROMPT, FFI_PROMPT_SIMPLE
from .ncu import flashinfer_bench_list_ncu_options, flashinfer_bench_run_ncu
from .schema import function_to_schema, get_all_tool_schemas

__all__ = [
"flashinfer_bench_list_ncu_options",
"flashinfer_bench_run_ncu",
"function_to_schema",
"get_all_tool_schemas",
"FFI_PROMPT_SIMPLE",
"FFI_PROMPT",
]
58 changes: 58 additions & 0 deletions flashinfer_bench/agents/_solution_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""Solution runner - standalone script for profiling solutions."""

import argparse
from pathlib import Path

import torch

from flashinfer_bench.bench.evaluators.utils import allocate_outputs
from flashinfer_bench.bench.utils import gen_inputs, load_safetensors
from flashinfer_bench.compile import BuilderRegistry
from flashinfer_bench.data import Definition, Solution, Workload


def main():
    """Entry point: load a (definition, solution, workload) triple from disk,
    build the solution, and run it once inside an NVTX range so NCU can
    profile exactly that invocation.

    Command-line arguments:
        --data-dir        directory containing definition.json / solution.json /
                          workload.json written by the parent process
        --device          CUDA device string to run on
        --trace-set-path  optional path to the trace set (for safetensors inputs)
    """
    parser = argparse.ArgumentParser(description="Run a solution for NCU profiling")
    parser.add_argument("--data-dir", required=True, help="Path to data directory")
    parser.add_argument("--device", required=True, help="CUDA device to run on")
    parser.add_argument("--trace-set-path", help="Path to trace set")
    args = parser.parse_args()

    root = Path(args.data_dir)
    target_device = args.device
    trace_root = Path(args.trace_set_path) if args.trace_set_path else None

    # Deserialize the benchmark artifacts serialized by the parent process.
    definition = Definition.model_validate_json((root / "definition.json").read_text())
    solution = Solution.model_validate_json((root / "solution.json").read_text())
    workload = Workload.model_validate_json((root / "workload.json").read_text())

    # Compile/assemble the solution into a callable runnable.
    runnable = BuilderRegistry.get_instance().build(definition, solution)

    # Safetensors are loaded only when the workload actually references them.
    tensors = None
    if any(spec.type == "safetensors" for spec in workload.inputs.values()):
        tensors = load_safetensors(definition, workload, trace_root)

    inputs = gen_inputs(definition, workload, target_device, tensors)
    outputs = allocate_outputs(definition, inputs, target_device)

    # Warmup call: triggers JIT compilation so it does not pollute the profile.
    with torch.no_grad():
        runnable.call_destination_passing(*inputs, *outputs)
    torch.cuda.synchronize()

    # Profiled call: wrapped in an NVTX range so NCU can filter on its name.
    with torch.cuda.nvtx.range("flashinfer_bench_ncu_profile"):
        with torch.no_grad():
            runnable.call_destination_passing(*inputs, *outputs)
    torch.cuda.synchronize()


if __name__ == "__main__":
    main()
4 changes: 4 additions & 0 deletions flashinfer_bench/agents/ffi_prompt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Prompt templates for TVM FFI API documentation used by agents."""

FFI_PROMPT_SIMPLE = """
Use TVM FFI format for your generated kernel host function and bindings

Expand Down Expand Up @@ -159,6 +161,7 @@
} // namespace my_kernels
```
"""
"""Simplified TVM FFI API documentation with essential methods and a basic example."""

FFI_PROMPT = """
Use TVM FFI format for your generated kernel host function and bindings
Expand Down Expand Up @@ -585,3 +588,4 @@
} // namespace my_kernels
```
"""
"""Comprehensive TVM FFI API documentation with full method signatures and multiple examples."""
Loading