update dsa and gdn definitions, op schema and unittests #164
@@ -75,7 +75,8 @@ FlashInfer-Bench supports the following op_types (corresponding to different Def
 | `gemm` | General Matrix Multiplication | `gemm_n_6144_k_4096` |
 | `gqa_ragged` | Group Query Attention (ragged) | `gqa_ragged_prefill_causal_h32_kv8_d128` |
 | `gqa_paged` | Group Query Attention (paged) | `gqa_paged_decode_h32_kv8_d128_ps1` |
-| `mla_paged` | Multi-Head Linear Attention | `mla_paged_decode_h16_ckv512_kpe64_ps1` |
+| `mla_paged` | Multi-Head Latent Attention (paged) | `mla_paged_decode_h16_ckv512_kpe64_ps1` |
+| `dsa_paged` | DeepSeek Sparse Attention (paged) | `dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1` |
Review comment: the documentation example doesn't match the actual definition name. The table shows `dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1`, but the definition is named `dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1`. Proposed fix:

-| `dsa_paged` | DeepSeek Sparse Attention (paged) | `dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1` |
+| `dsa_paged` | DeepSeek Sparse Attention (paged) | `dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1` |
 | `gdn` | Gated Delta Net (linear attention) | `gdn_decode_qk16_v32_d128_k_last` |
 | `moe` | Mixture of Experts | `moe_fp8_block_scale_ds_routing_topk8_ng8_kg4_e32_h7168_i2048` |
 | `sampling` | Sampling operations | - |
@@ -168,6 +169,7 @@ Associate each module with corresponding Definitions:
 - **Attention layers**:
   - GQA: `gqa_paged_decode_h{num_heads}_kv{kv_heads}_d{head_dim}_ps1`
   - MLA: `mla_paged_decode_h{num_heads}_ckv{ckv_dim}_kpe{kpe_dim}_ps1`
+  - DSA: `dsa_sparse_decode_h{num_heads}_ckv{ckv_dim}_kpe{kpe_dim}_topk{topk}_ps1` (sparse MLA)
   - GDN: `gdn_decode_qk{q_heads}_v{v_heads}_d{head_dim}` (linear attention)
 - **GEMM layers**: `gemm_n_{out_dim}_k_{in_dim}`
 - **MoE layers**: `moe_fp8_block_scale_ds_routing_topk{topk}_ng{num_groups}_kg{group_size}_e{num_experts}_h{hidden}_i{intermediate}`
@@ -0,0 +1,55 @@
# dsa_paged

DeepSeek Sparse Attention (DSA) with paged memory layout. DSA is a two-stage sparse attention mechanism: first an indexer selects the top-K relevant KV cache entries using ReLU scoring, then MLA-style attention is performed only on the selected entries. This reduces attention computation from O(n) to O(k), where k << n.

Variants:
- indexer
- sparse_attention
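Putting the two variants together, the intended dataflow is roughly the following. This is a minimal sketch using the tensor names documented below; `indexer` and `sparse_attention` stand in for whatever kernels implement the two variants and are not an actual API:

```python
# Stage 1: score every cached token, keep only the top-K indices per query.
topk_indices = indexer(q_index_fp8, k_index_cache_fp8, weights, seq_lens, block_table)

# Stage 2: MLA-style attention restricted to the selected entries.
output, lse = sparse_attention(q_nope, q_pe, ckv_cache, kpe_cache, topk_indices, sm_scale)
```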
## indexer

Computes sparse attention scores using ReLU activation and learned weights, then selects the top-K KV cache indices. Uses FP8 quantization with the deep_gemm format.

Axes (9 dimensions):
- `batch_size`, `max_num_pages`, `num_pages`: variable
- `num_index_heads`, `index_head_dim`, `page_size`, `topk`, `kv_cache_num_heads`, `head_dim_with_scale`: constant

Inputs (5 tensors):
- `q_index_fp8`: FP8 query for indexing [batch_size, num_index_heads, index_head_dim]
- `k_index_cache_fp8`: FP8 key index cache with scales [num_pages, page_size, kv_cache_num_heads, head_dim_with_scale]
- `weights`: learned head weights [batch_size, num_index_heads]
- `seq_lens`: sequence lengths [batch_size]
- `block_table`: page mapping [batch_size, max_num_pages]

Outputs (1 tensor):
- `topk_indices`: selected token indices [batch_size, topk], -1 indicates padding

Constraints:
- `topk <= max_num_pages * page_size`
- `num_index_heads == 64`, `index_head_dim == 128` (deep_gemm requirement)
- `head_dim_with_scale == 132` (128 + 4 scale bytes)
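FP8 packing and deep_gemm details aside, the scoring itself is straightforward. Below is a hedged sketch of one plausible reading of the description above: it assumes already-dequantized fp32 inputs, `kv_cache_num_heads == 1` with the 4 scale bytes stripped, and per-head ReLU scores combined by a weighted sum; the real kernel's quantized layout and exact weighting may differ.

```python
import torch


@torch.no_grad()
def indexer_sketch(q_index, k_index_cache, weights, seq_lens, block_table, topk):
    """Simplified (dequantized) top-K index selection.

    q_index:       [batch_size, num_index_heads, index_head_dim], float32
    k_index_cache: [num_pages, page_size, index_head_dim], float32
                   (kv_cache_num_heads == 1 assumed and squeezed out)
    weights:       [batch_size, num_index_heads]
    seq_lens:      [batch_size]; block_table: [batch_size, max_num_pages]
    """
    batch_size, num_index_heads, index_head_dim = q_index.shape
    page_size = k_index_cache.shape[1]
    topk_indices = torch.full(
        (batch_size, topk), -1, dtype=torch.int32, device=q_index.device
    )
    for b in range(batch_size):
        seq_len = int(seq_lens[b])
        num_pages_b = (seq_len + page_size - 1) // page_size
        pages = block_table[b, :num_pages_b].to(torch.long)
        # Gather this sequence's pages and flatten to token granularity.
        k = k_index_cache[pages].reshape(-1, index_head_dim)[:seq_len]
        # ReLU scoring per index head, then a weighted sum across heads.
        head_scores = torch.relu(q_index[b] @ k.T)  # [num_index_heads, seq_len]
        scores = (weights[b].unsqueeze(-1) * head_scores).sum(dim=0)  # [seq_len]
        k_sel = min(topk, seq_len)
        # Unfilled slots keep -1, matching the documented padding convention.
        topk_indices[b, :k_sel] = torch.topk(scores, k_sel).indices.to(torch.int32)
    return topk_indices
```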
## sparse_attention

Performs MLA-style attention on the top-K selected KV entries. Works for both prefill (multiple tokens) and decode (one token per sequence); the computation is identical, only the first dimension differs.

Axes (7 dimensions):
- `num_tokens`, `num_pages`: variable
- `num_qo_heads`, `head_dim_ckv`, `head_dim_kpe`, `page_size`, `topk`: constant

Inputs (5 tensors + 1 scalar):
- `q_nope`: query without positional encoding [num_tokens, num_qo_heads, head_dim_ckv]
- `q_pe`: query positional encoding [num_tokens, num_qo_heads, head_dim_kpe]
- `ckv_cache`: compressed KV cache [num_pages, page_size, head_dim_ckv]
- `kpe_cache`: key positional encoding cache [num_pages, page_size, head_dim_kpe]
- `sparse_indices`: top-K indices per token [num_tokens, topk], -1 indicates padding
- `sm_scale`: softmax scale (scalar)

Outputs (2 tensors):
- `output`: attention output [num_tokens, num_qo_heads, head_dim_ckv]
- `lse`: 2-based log-sum-exp [num_tokens, num_qo_heads]

Constraints:
- `sparse_indices.shape[0] == num_tokens`
- `sparse_indices.shape[-1] == topk`
- `ckv_cache.shape[1] == page_size`
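For a quick shape and dtype sanity check, the ps1 definition's reference can be exercised directly. Here `run` refers to the reference function embedded in the `dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1` JSON below, and the `sm_scale` value is an illustrative choice rather than one prescribed by the schema:

```python
import math

import torch

num_tokens, num_pages = 4, 1024
num_qo_heads, head_dim_ckv, head_dim_kpe, page_size, topk = 16, 512, 64, 1, 256

q_nope = torch.randn(num_tokens, num_qo_heads, head_dim_ckv, dtype=torch.bfloat16)
q_pe = torch.randn(num_tokens, num_qo_heads, head_dim_kpe, dtype=torch.bfloat16)
ckv_cache = torch.randn(num_pages, page_size, head_dim_ckv, dtype=torch.bfloat16)
kpe_cache = torch.randn(num_pages, page_size, head_dim_kpe, dtype=torch.bfloat16)
# Random but valid token indices; real callers take these from the indexer output.
sparse_indices = torch.randint(
    0, num_pages * page_size, (num_tokens, topk), dtype=torch.int32
)
sm_scale = 1.0 / math.sqrt(head_dim_ckv + head_dim_kpe)  # illustrative scale

output, lse = run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale)
assert output.shape == (num_tokens, num_qo_heads, head_dim_ckv)
assert lse.shape == (num_tokens, num_qo_heads)
```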
@@ -1,18 +1,16 @@
 {
-  "name": "dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1",
-  "description": "Batched Native Sparse Attention (DSA) decode with sparse TopK KV cache selection. Captured from DeepSeek-V3 with tensor parallel size 8. Uses sparse indexing to select only top-K KV cache entries for attention computation.",
+  "name": "dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1",
+  "description": "Batched Native Sparse Attention (DSA) with sparse TopK KV cache selection. Captured from DeepSeek-V3 with tensor parallel size 8. Uses sparse indexing to select only top-K KV cache entries for attention computation. Works for both prefill and decode stages.",
   "op_type": "dsa_paged",
   "tags": [
-    "stage:decode",
     "status:verified",
-    "model:deepseek-v3",
-    "model:deepseek-r1",
+    "model:deepseek-v3.2",
     "sparse:topk"
   ],
   "axes": {
-    "batch_size": {
+    "num_tokens": {
       "type": "var",
-      "description": "Batch size (number of sequences)."
+      "description": "Number of tokens (batch_size for decode, total_num_tokens for prefill)."
     },
     "num_qo_heads": {
       "type": "const",

@@ -45,13 +43,14 @@
     }
   },
   "constraints": [
+    "sparse_indices.shape[0] == num_tokens",
     "sparse_indices.shape[-1] == topk",
     "ckv_cache.shape[1] == page_size"
   ],
   "inputs": {
     "q_nope": {
       "shape": [
-        "batch_size",
+        "num_tokens",
         "num_qo_heads",
         "head_dim_ckv"
       ],

@@ -60,7 +59,7 @@
     },
     "q_pe": {
       "shape": [
-        "batch_size",
+        "num_tokens",
         "num_qo_heads",
         "head_dim_kpe"
       ],

@@ -87,11 +86,11 @@
     },
     "sparse_indices": {
       "shape": [
-        "batch_size",
+        "num_tokens",
         "topk"
       ],
       "dtype": "int32",
-      "description": "Sparse indices selecting top-K KV cache entries for each batch element. Values of -1 indicate padding (invalid indices)."
+      "description": "Sparse indices selecting top-K KV cache entries for each token. Values of -1 indicate padding (invalid indices)."
     },
     "sm_scale": {
       "shape": null,

@@ -102,20 +101,20 @@
   "outputs": {
     "output": {
       "shape": [
-        "batch_size",
+        "num_tokens",
         "num_qo_heads",
         "head_dim_ckv"
       ],
       "dtype": "bfloat16"
     },
     "lse": {
       "shape": [
-        "batch_size",
+        "num_tokens",
         "num_qo_heads"
       ],
       "dtype": "float32",
       "description": "The 2-based log-sum-exp of attention logits."
     }
   },
-  "reference": "import math\nimport torch\n\n\n@torch.no_grad()\ndef run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale):\n    batch_size, num_qo_heads, head_dim_ckv = q_nope.shape\n    head_dim_kpe = q_pe.shape[-1]\n    page_size = ckv_cache.shape[1]\n    topk = sparse_indices.shape[-1]\n\n    # Check constants\n    assert num_qo_heads == 16\n    assert head_dim_ckv == 512\n    assert head_dim_kpe == 64\n    assert page_size == 1\n    assert topk == 256\n\n    # Check constraints\n    assert sparse_indices.shape[-1] == topk\n    assert ckv_cache.shape[1] == page_size\n\n    device = q_nope.device\n\n    # Squeeze page dimension (page_size=1)\n    Kc_all = ckv_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_ckv]\n    Kp_all = kpe_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_kpe]\n\n    output = torch.zeros(\n        (batch_size, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device\n    )\n    lse = torch.full((batch_size, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=device)\n\n    for b in range(batch_size):\n        indices = sparse_indices[b] # [topk]\n\n        # Handle padding: -1 indicates invalid indices\n        valid_mask = indices != -1\n        valid_indices = indices[valid_mask]\n\n        if valid_indices.numel() == 0:\n            output[b].zero_()\n            continue\n\n        tok_idx = valid_indices.to(torch.long)\n\n        Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv]\n        Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe]\n        qn = q_nope[b].to(torch.float32) # [num_qo_heads, head_dim_ckv]\n        qp = q_pe[b].to(torch.float32) # [num_qo_heads, head_dim_kpe]\n\n        # Compute attention logits\n        logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid]\n        logits_scaled = logits * sm_scale\n\n        # Compute 2-base LSE\n        lse[b] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0)\n\n        # Compute attention output\n        attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid]\n        out = attn @ Kc # [num_qo_heads, head_dim_ckv]\n        output[b] = out.to(torch.bfloat16)\n\n    return output, lse"
+  "reference": "import math\nimport torch\n\n\n@torch.no_grad()\ndef run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale):\n    num_tokens, num_qo_heads, head_dim_ckv = q_nope.shape\n    head_dim_kpe = q_pe.shape[-1]\n    page_size = ckv_cache.shape[1]\n    topk = sparse_indices.shape[-1]\n\n    # Check constants\n    assert num_qo_heads == 16\n    assert head_dim_ckv == 512\n    assert head_dim_kpe == 64\n    assert page_size == 1\n    assert topk == 256\n\n    # Check constraints\n    assert sparse_indices.shape[0] == num_tokens\n    assert sparse_indices.shape[-1] == topk\n    assert ckv_cache.shape[1] == page_size\n\n    device = q_nope.device\n\n    # Squeeze page dimension (page_size=1)\n    Kc_all = ckv_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_ckv]\n    Kp_all = kpe_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_kpe]\n\n    output = torch.zeros(\n        (num_tokens, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device\n    )\n    lse = torch.full((num_tokens, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=device)\n\n    for t in range(num_tokens):\n        indices = sparse_indices[t] # [topk]\n\n        # Handle padding: -1 indicates invalid indices\n        valid_mask = indices != -1\n        valid_indices = indices[valid_mask]\n\n        if valid_indices.numel() == 0:\n            output[t].zero_()\n            continue\n\n        tok_idx = valid_indices.to(torch.long)\n\n        Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv]\n        Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe]\n        qn = q_nope[t].to(torch.float32) # [num_qo_heads, head_dim_ckv]\n        qp = q_pe[t].to(torch.float32) # [num_qo_heads, head_dim_kpe]\n\n        # Compute attention logits\n        logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid]\n        logits_scaled = logits * sm_scale\n\n        # Compute 2-base LSE\n        lse[t] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0)\n\n        # Compute attention output\n        attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid]\n        out = attn @ Kc # [num_qo_heads, head_dim_ckv]\n        output[t] = out.to(torch.bfloat16)\n\n    return output, lse"
 }
@@ -1,18 +1,16 @@
 {
-  "name": "dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps64",
-  "description": "Batched Native Sparse Attention (DSA) decode with sparse TopK KV cache selection. Captured from DeepSeek-V3 with tensor parallel size 8. Uses sparse indexing to select only top-K KV cache entries for attention computation. Page size 64 variant.",
+  "name": "dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps64",
+  "description": "Batched Native Sparse Attention (DSA) with sparse TopK KV cache selection. Captured from DeepSeek-V3 with tensor parallel size 8. Uses sparse indexing to select only top-K KV cache entries for attention computation. Page size 64 variant. Works for both prefill and decode stages.",
   "op_type": "dsa_paged",
   "tags": [
-    "stage:decode",
     "status:verified",
-    "model:deepseek-v3",
-    "model:deepseek-r1",
+    "model:deepseek-v3.2",
     "sparse:topk"
   ],
   "axes": {
-    "batch_size": {
+    "num_tokens": {
       "type": "var",
-      "description": "Batch size (number of sequences)."
+      "description": "Number of tokens (batch_size for decode, total_num_tokens for prefill)."
     },
     "num_qo_heads": {
       "type": "const",

@@ -45,13 +43,14 @@
     }
   },
   "constraints": [
+    "sparse_indices.shape[0] == num_tokens",
     "sparse_indices.shape[-1] == topk",
     "ckv_cache.shape[1] == page_size"
   ],
   "inputs": {
     "q_nope": {
       "shape": [
-        "batch_size",
+        "num_tokens",
         "num_qo_heads",
         "head_dim_ckv"
       ],

@@ -60,7 +59,7 @@
     },
     "q_pe": {
       "shape": [
-        "batch_size",
+        "num_tokens",
         "num_qo_heads",
         "head_dim_kpe"
       ],

@@ -87,11 +86,11 @@
     },
     "sparse_indices": {
       "shape": [
-        "batch_size",
+        "num_tokens",
         "topk"
       ],
       "dtype": "int32",
-      "description": "Sparse indices selecting top-K KV cache entries for each batch element. Values of -1 indicate padding (invalid indices). For page_size=64, indices encode (page_idx * 64 + offset)."
+      "description": "Sparse indices selecting top-K KV cache entries for each token. Values of -1 indicate padding (invalid indices). For page_size=64, indices encode (page_idx * 64 + offset)."
     },
     "sm_scale": {
       "shape": null,

@@ -102,20 +101,20 @@
   "outputs": {
     "output": {
       "shape": [
-        "batch_size",
+        "num_tokens",
         "num_qo_heads",
         "head_dim_ckv"
       ],
       "dtype": "bfloat16"
     },
     "lse": {
       "shape": [
-        "batch_size",
+        "num_tokens",
         "num_qo_heads"
       ],
       "dtype": "float32",
       "description": "The 2-based log-sum-exp of attention logits."
     }
   },
-  "reference": "import math\nimport torch\n\n\n@torch.no_grad()\ndef run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale):\n    batch_size, num_qo_heads, head_dim_ckv = q_nope.shape\n    head_dim_kpe = q_pe.shape[-1]\n    num_pages, page_size, _ = ckv_cache.shape\n    topk = sparse_indices.shape[-1]\n\n    # Check constants\n    assert num_qo_heads == 16\n    assert head_dim_ckv == 512\n    assert head_dim_kpe == 64\n    assert page_size == 64\n    assert topk == 256\n\n    # Check constraints\n    assert sparse_indices.shape[-1] == topk\n    assert ckv_cache.shape[1] == page_size\n\n    device = q_nope.device\n\n    # Flatten paged KV cache to token-level: [num_pages, page_size, dim] -> [num_pages * page_size, dim]\n    Kc_all = ckv_cache.reshape(-1, head_dim_ckv).to(torch.float32) # [total_kv_tokens, head_dim_ckv]\n    Kp_all = kpe_cache.reshape(-1, head_dim_kpe).to(torch.float32) # [total_kv_tokens, head_dim_kpe]\n\n    output = torch.zeros(\n        (batch_size, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device\n    )\n    lse = torch.full((batch_size, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=device)\n\n    for b in range(batch_size):\n        indices = sparse_indices[b] # [topk]\n\n        # Handle padding: -1 indicates invalid indices\n        valid_mask = indices != -1\n        valid_indices = indices[valid_mask]\n\n        if valid_indices.numel() == 0:\n            output[b].zero_()\n            continue\n\n        # For page_size=64, indices encode (page_idx * 64 + offset)\n        tok_idx = valid_indices.to(torch.long)\n\n        Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv]\n        Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe]\n        qn = q_nope[b].to(torch.float32) # [num_qo_heads, head_dim_ckv]\n        qp = q_pe[b].to(torch.float32) # [num_qo_heads, head_dim_kpe]\n\n        # Compute attention logits\n        logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid]\n        logits_scaled = logits * sm_scale\n\n        # Compute 2-base LSE\n        lse[b] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0)\n\n        # Compute attention output\n        attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid]\n        out = attn @ Kc # [num_qo_heads, head_dim_ckv]\n        output[b] = out.to(torch.bfloat16)\n\n    return {\"output\": output, \"lse\": lse}"
+  "reference": "import math\nimport torch\n\n\n@torch.no_grad()\ndef run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale):\n    num_tokens, num_qo_heads, head_dim_ckv = q_nope.shape\n    head_dim_kpe = q_pe.shape[-1]\n    num_pages, page_size, _ = ckv_cache.shape\n    topk = sparse_indices.shape[-1]\n\n    # Check constants\n    assert num_qo_heads == 16\n    assert head_dim_ckv == 512\n    assert head_dim_kpe == 64\n    assert page_size == 64\n    assert topk == 256\n\n    # Check constraints\n    assert sparse_indices.shape[0] == num_tokens\n    assert sparse_indices.shape[-1] == topk\n    assert ckv_cache.shape[1] == page_size\n\n    device = q_nope.device\n\n    # Flatten paged KV cache to token-level: [num_pages, page_size, dim] -> [num_pages * page_size, dim]\n    Kc_all = ckv_cache.reshape(-1, head_dim_ckv).to(torch.float32) # [total_kv_tokens, head_dim_ckv]\n    Kp_all = kpe_cache.reshape(-1, head_dim_kpe).to(torch.float32) # [total_kv_tokens, head_dim_kpe]\n\n    output = torch.zeros(\n        (num_tokens, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device\n    )\n    lse = torch.full((num_tokens, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=device)\n\n    for t in range(num_tokens):\n        indices = sparse_indices[t] # [topk]\n\n        # Handle padding: -1 indicates invalid indices\n        valid_mask = indices != -1\n        valid_indices = indices[valid_mask]\n\n        if valid_indices.numel() == 0:\n            output[t].zero_()\n            continue\n\n        # For page_size=64, indices encode (page_idx * 64 + offset)\n        tok_idx = valid_indices.to(torch.long)\n\n        Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv]\n        Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe]\n        qn = q_nope[t].to(torch.float32) # [num_qo_heads, head_dim_ckv]\n        qp = q_pe[t].to(torch.float32) # [num_qo_heads, head_dim_kpe]\n\n        # Compute attention logits\n        logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid]\n        logits_scaled = logits * sm_scale\n\n        # Compute 2-base LSE\n        lse[t] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0)\n\n        # Compute attention output\n        attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid]\n        out = attn @ Kc # [num_qo_heads, head_dim_ckv]\n        output[t] = out.to(torch.bfloat16)\n\n    return output, lse"
 }
Do we actually need to embed the exact op types in CLAUDE.md? I think we can just provide a general introduction here.
Good point. Those entries were added automatically by Claude Code itself; we should make CLAUDE.md simpler (not necessarily in this PR).