fix: GDN decode reference and test #163
Conversation
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Summary of Changes
Hello @yyihuang, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request addresses an issue with the Gated Delta Net (GDN) kernel test by updating its reference implementation to align with the FlashInfer kernel's behavior, ensuring accurate validation of the kernel.
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.
📝 Walkthrough
Updated GDN prefill/decode JSON specs to add a public run() reference implementation.
Sequence Diagram(s): omitted; the changes are spec and test updates without a multi-component runtime control flow requiring visualization.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~22 minutes
🚥 Pre-merge checks: ✅ Passed checks (3 passed)
Code Review
This pull request aims to fix the Gated Delta Net (GDN) kernel test by updating the reference implementation. The changes in the markdown files improve the documentation and shell scripts, which is great. However, I've found a critical issue in the updated reference implementation within gdn_decode_qk16_v32_d128_k_last.json. The decay logic appears to be flawed, causing state amplification instead of decay. This likely explains why the test tolerances in test_gdn_decode_qk16_v32_d128_k_last.py had to be increased so dramatically to 7e-2. My review includes a critical comment with a suggested fix for the reference implementation's decay logic. I've also added high-severity comments regarding the loose test tolerances, which should be tightened after the reference implementation is corrected. Addressing these points should lead to a more robust and numerically correct test.
| } | ||
| }, | ||
| "reference": "import math\nimport torch\nimport torch.nn.functional as F\n\n\ndef matmul(a: torch.Tensor, b: torch.Tensor):\n \"\"\"Float32 matmul for numerical stability.\"\"\"\n return a.float() @ b.float()\n\n\n@torch.no_grad()\ndef run(q, k, v, state, A_log, a, dt_bias, b, scale):\n \"\"\"\n Gated Delta Net decode reference implementation (k-last layout).\n \n State layout: [B, H, V, K] (k-last, K dimension at the end)\n \n Gate computation:\n g = exp(-exp(A_log) * softplus(a + dt_bias))\n beta = sigmoid(b)\n \n Delta rule update:\n state_new = g * state_old + k^T @ (beta * v + (1-beta) * k @ state_old) - k^T @ (k @ state_old)\n output = scale * q @ state_new\n \"\"\"\n B, T, num_q_heads, K = q.shape\n _, _, num_k_heads, _ = k.shape\n _, _, num_v_heads, V = v.shape\n num_heads = num_v_heads\n device = q.device\n \n assert num_q_heads == 16\n assert num_k_heads == 16\n assert num_v_heads == 32\n assert K == 128 and V == 128\n assert T == 1\n \n if scale is None or scale == 0.0:\n scale = 1.0 / math.sqrt(K)\n \n # Compute g and beta from raw parameters\n x = a.float() + dt_bias.float() # [B, 1, HV]\n g = torch.exp(-torch.exp(A_log.float()) * F.softplus(x)) # [B, 1, HV]\n beta = torch.sigmoid(b.float()) # [B, 1, HV]\n \n q_f32 = q.squeeze(1).float()\n k_f32 = k.squeeze(1).float()\n v_f32 = v.squeeze(1).float()\n g_f32 = g.squeeze(1).float()\n beta_f32 = beta.squeeze(1).float()\n \n if state is not None:\n state_f32 = state.float()\n else:\n state_f32 = torch.zeros(B, num_heads, V, K, dtype=torch.float32, device=device)\n \n q_exp = q_f32.repeat_interleave(num_v_heads // num_q_heads, dim=1)\n k_exp = k_f32.repeat_interleave(num_v_heads // num_k_heads, dim=1)\n \n new_state = torch.zeros_like(state_f32)\n output = torch.zeros(B, num_heads, V, dtype=torch.float32, device=device)\n \n for b_idx in range(B):\n for h_idx in range(num_heads):\n q_h = q_exp[b_idx, h_idx]\n k_h = k_exp[b_idx, h_idx]\n v_h = v_f32[b_idx, h_idx]\n h_state = state_f32[b_idx, h_idx].clone().transpose(-1, -2) # [V,K] -> [K,V]\n g_val = g_f32[b_idx, h_idx]\n beta_val = beta_f32[b_idx, h_idx]\n \n old_state = g_val * h_state\n old_v = k_h @ old_state\n new_v = beta_val * v_h + (1 - beta_val) * old_v\n state_remove = k_h.unsqueeze(1) @ old_v.unsqueeze(0)\n state_update = k_h.unsqueeze(1) @ new_v.unsqueeze(0)\n h_state = old_state - state_remove + state_update\n \n output[b_idx, h_idx] = scale * (q_h @ h_state)\n new_state[b_idx, h_idx] = h_state.transpose(-1, -2) # [K,V] -> [V,K]\n \n output = output.unsqueeze(1).to(torch.bfloat16)\n return {\"output\": output, \"new_state\": new_state}" | ||
| "reference": "import math\nimport torch\nimport torch.nn.functional as F\n\n\n@torch.no_grad()\ndef run(q, k, v, state, A_log, a, dt_bias, b, scale):\n \"\"\"\n Gated Delta Net decode reference implementation matching FlashInfer kernel.\n \n State layout: [B, H, V, K] (k-last, V dimension before K)\n \n Gate computation:\n g = exp(-exp(A_log) * softplus(a + dt_bias))\n beta = sigmoid(b)\n \n Delta rule update (following FlashInfer kernel exactly):\n 1. Apply L2 normalization to q and k\n 2. h *= exp(g) # Apply decay to state\n 3. v_new = v - k^T @ h # Delta rule\n 4. v_new *= beta # Apply update gate\n 5. h += k @ v_new^T # Update state (outer product)\n 6. o = scale * q^T @ h # Compute output\n \"\"\"\n B, T, num_q_heads, K = q.shape\n _, _, num_k_heads, _ = k.shape\n _, _, num_v_heads, V = v.shape\n num_heads = num_v_heads\n device = q.device\n dtype = torch.float32\n \n assert num_q_heads == 16\n assert num_k_heads == 16\n assert num_v_heads == 32\n assert K == 128 and V == 128\n assert T == 1\n \n # Default scale\n if scale is None or scale == 0.0:\n scale = 1.0 / math.sqrt(K)\n \n # Convert to float32 for computation\n A_log = A_log.to(dtype).to(device)\n a = a.to(dtype).to(device)\n dt_bias = dt_bias.to(dtype).to(device)\n b = b.to(dtype).to(device)\n \n # Compute gating values\n # g = exp(-exp(A_log) * softplus(a + dt_bias))\n x = a + dt_bias # [B, 1, HV]\n softplus_x = F.softplus(x)\n g = torch.exp(-torch.exp(A_log) * softplus_x) # [B, 1, HV]\n \n # beta = sigmoid(b)\n beta = torch.sigmoid(b) # [B, 1, HV]\n \n # Process tensors\n q_f32 = q.squeeze(1).float() # [B, num_q_heads, K]\n k_f32 = k.squeeze(1).float() # [B, num_k_heads, K]\n v_f32 = v.squeeze(1).float() # [B, num_v_heads, V]\n g_f32 = g.squeeze(1) # [B, num_v_heads]\n beta_f32 = beta.squeeze(1) # [B, num_v_heads]\n \n if state is not None:\n state_f32 = state.float() # [B, num_v_heads, V, K]\n else:\n state_f32 = torch.zeros(B, num_v_heads, V, K, dtype=dtype, device=device)\n \n # Expand heads for GVA (num_v_heads > num_q_heads)\n q_exp = q_f32.repeat_interleave(num_v_heads // num_q_heads, dim=1) # [B, num_v_heads, K]\n k_exp = k_f32.repeat_interleave(num_v_heads // num_k_heads, dim=1) # [B, num_v_heads, K]\n \n # Apply L2 normalization (matching FlashInfer kernel with use_qk_l2norm=True)\n q_exp = F.normalize(q_exp, p=2.0, dim=-1)\n k_exp = F.normalize(k_exp, p=2.0, dim=-1)\n \n # Apply scale to q\n q_exp = q_exp * scale\n \n # Initialize outputs\n new_state = torch.zeros_like(state_f32) # [B, num_v_heads, V, K]\n output = torch.zeros(B, num_v_heads, V, dtype=dtype, device=device)\n \n # Process each batch and head\n for b_idx in range(B):\n for h_idx in range(num_v_heads):\n q_h = q_exp[b_idx, h_idx] # [K]\n k_h = k_exp[b_idx, h_idx] # [K]\n v_h = v_f32[b_idx, h_idx] # [V]\n # State is [V, K] but we need [K, V] for computation\n h_state = state_f32[b_idx, h_idx].transpose(-1, -2).clone() # [V, K] -> [K, V]\n \n g_val = g_f32[b_idx, h_idx] # scalar\n beta_val = beta_f32[b_idx, h_idx] # scalar\n \n # Delta rule following FlashInfer kernel:\n # 1. Apply decay: h *= exp(g)\n h_state = h_state * torch.exp(g_val)\n \n # 2. Delta rule: v_new = v - k^T @ h\n # k is [K], h is [K, V]\n # k @ h = [K] @ [K, V] = [V]\n v_new = v_h - (k_h @ h_state)\n \n # 3. Apply update gate: v_new *= beta\n v_new = v_new * beta_val\n \n # 4. Update state: h += k @ v_new^T (outer product)\n # k.unsqueeze(1) is [K, 1], v_new.unsqueeze(0) is [1, V]\n # [K, 1] @ [1, V] = [K, V]\n h_state = h_state + k_h.unsqueeze(1) @ v_new.unsqueeze(0)\n \n # 5. 
Compute output: o = q^T @ h\n # q is [K], h is [K, V]\n # q @ h = [K] @ [K, V] = [V]\n output[b_idx, h_idx] = q_h @ h_state\n \n # Store updated state (transpose back to [V, K])\n new_state[b_idx, h_idx] = h_state.transpose(-1, -2)\n \n # Convert output to bfloat16 and add seq_len dimension\n output = output.unsqueeze(1).to(torch.bfloat16) # [B, 1, num_v_heads, V]\n \n return {\"output\": output, \"new_state\": new_state}" |
There appears to be a bug in the decay logic of this reference implementation. The state is updated with h_state = h_state * torch.exp(g_val), which implies g_val should be the log of the decay factor.
However, g is calculated as torch.exp(-torch.exp(A_log) * softplus_x), which is a value between 0 and 1. Applying torch.exp() to this results in a value between 1 and e, causing an amplification of the state instead of decay. This is likely incorrect and can lead to numerical instability.
The previous implementation used g directly as a multiplier, which was correct as a decay factor. To fix this while matching the new update rule h *= exp(g), the calculation of g should yield the log of the decay factor.
The fix is to remove the outer torch.exp in the calculation of g. I've also updated the docstring for clarity.
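To see the difference concretely, here is a minimal standalone sketch of the two pairings (the scalar values are made up for illustration and are not taken from the test data):

import torch
import torch.nn.functional as F

A_log = torch.tensor(0.5)
x = torch.tensor(1.0)  # stands in for a + dt_bias

# Buggy pairing: g is already the decay factor in (0, 1), so multiplying the
# state by exp(g) uses a factor in (1, e) and amplifies the state.
g = torch.exp(-torch.exp(A_log) * F.softplus(x))
print(torch.exp(g).item())      # roughly 1.12, i.e. > 1

# Fixed pairing: keep g in log space; exp(g) is then a genuine decay factor in (0, 1).
g_log = -torch.exp(A_log) * F.softplus(x)
print(torch.exp(g_log).item())  # roughly 0.11, i.e. < 1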
| "reference": "import math\nimport torch\nimport torch.nn.functional as F\n\n\n@torch.no_grad()\ndef run(q, k, v, state, A_log, a, dt_bias, b, scale):\n \"\"\"\n Gated Delta Net decode reference implementation matching FlashInfer kernel.\n \n State layout: [B, H, V, K] (k-last, V dimension before K)\n \n Gate computation:\n g = exp(-exp(A_log) * softplus(a + dt_bias))\n beta = sigmoid(b)\n \n Delta rule update (following FlashInfer kernel exactly):\n 1. Apply L2 normalization to q and k\n 2. h *= exp(g) # Apply decay to state\n 3. v_new = v - k^T @ h # Delta rule\n 4. v_new *= beta # Apply update gate\n 5. h += k @ v_new^T # Update state (outer product)\n 6. o = scale * q^T @ h # Compute output\n \"\"\"\n B, T, num_q_heads, K = q.shape\n _, _, num_k_heads, _ = k.shape\n _, _, num_v_heads, V = v.shape\n num_heads = num_v_heads\n device = q.device\n dtype = torch.float32\n \n assert num_q_heads == 16\n assert num_k_heads == 16\n assert num_v_heads == 32\n assert K == 128 and V == 128\n assert T == 1\n \n # Default scale\n if scale is None or scale == 0.0:\n scale = 1.0 / math.sqrt(K)\n \n # Convert to float32 for computation\n A_log = A_log.to(dtype).to(device)\n a = a.to(dtype).to(device)\n dt_bias = dt_bias.to(dtype).to(device)\n b = b.to(dtype).to(device)\n \n # Compute gating values\n # g = exp(-exp(A_log) * softplus(a + dt_bias))\n x = a + dt_bias # [B, 1, HV]\n softplus_x = F.softplus(x)\n g = torch.exp(-torch.exp(A_log) * softplus_x) # [B, 1, HV]\n \n # beta = sigmoid(b)\n beta = torch.sigmoid(b) # [B, 1, HV]\n \n # Process tensors\n q_f32 = q.squeeze(1).float() # [B, num_q_heads, K]\n k_f32 = k.squeeze(1).float() # [B, num_k_heads, K]\n v_f32 = v.squeeze(1).float() # [B, num_v_heads, V]\n g_f32 = g.squeeze(1) # [B, num_v_heads]\n beta_f32 = beta.squeeze(1) # [B, num_v_heads]\n \n if state is not None:\n state_f32 = state.float() # [B, num_v_heads, V, K]\n else:\n state_f32 = torch.zeros(B, num_v_heads, V, K, dtype=dtype, device=device)\n \n # Expand heads for GVA (num_v_heads > num_q_heads)\n q_exp = q_f32.repeat_interleave(num_v_heads // num_q_heads, dim=1) # [B, num_v_heads, K]\n k_exp = k_f32.repeat_interleave(num_v_heads // num_k_heads, dim=1) # [B, num_v_heads, K]\n \n # Apply L2 normalization (matching FlashInfer kernel with use_qk_l2norm=True)\n q_exp = F.normalize(q_exp, p=2.0, dim=-1)\n k_exp = F.normalize(k_exp, p=2.0, dim=-1)\n \n # Apply scale to q\n q_exp = q_exp * scale\n \n # Initialize outputs\n new_state = torch.zeros_like(state_f32) # [B, num_v_heads, V, K]\n output = torch.zeros(B, num_v_heads, V, dtype=dtype, device=device)\n \n # Process each batch and head\n for b_idx in range(B):\n for h_idx in range(num_v_heads):\n q_h = q_exp[b_idx, h_idx] # [K]\n k_h = k_exp[b_idx, h_idx] # [K]\n v_h = v_f32[b_idx, h_idx] # [V]\n # State is [V, K] but we need [K, V] for computation\n h_state = state_f32[b_idx, h_idx].transpose(-1, -2).clone() # [V, K] -> [K, V]\n \n g_val = g_f32[b_idx, h_idx] # scalar\n beta_val = beta_f32[b_idx, h_idx] # scalar\n \n # Delta rule following FlashInfer kernel:\n # 1. Apply decay: h *= exp(g)\n h_state = h_state * torch.exp(g_val)\n \n # 2. Delta rule: v_new = v - k^T @ h\n # k is [K], h is [K, V]\n # k @ h = [K] @ [K, V] = [V]\n v_new = v_h - (k_h @ h_state)\n \n # 3. Apply update gate: v_new *= beta\n v_new = v_new * beta_val\n \n # 4. Update state: h += k @ v_new^T (outer product)\n # k.unsqueeze(1) is [K, 1], v_new.unsqueeze(0) is [1, V]\n # [K, 1] @ [1, V] = [K, V]\n h_state = h_state + k_h.unsqueeze(1) @ v_new.unsqueeze(0)\n \n # 5. 
Compute output: o = q^T @ h\n # q is [K], h is [K, V]\n # q @ h = [K] @ [K, V] = [V]\n output[b_idx, h_idx] = q_h @ h_state\n \n # Store updated state (transpose back to [V, K])\n new_state[b_idx, h_idx] = h_state.transpose(-1, -2)\n \n # Convert output to bfloat16 and add seq_len dimension\n output = output.unsqueeze(1).to(torch.bfloat16) # [B, 1, num_v_heads, V]\n \n return {\"output\": output, \"new_state\": new_state}" | |
| "reference": "import math\nimport torch\nimport torch.nn.functional as F\n\n\n@torch.no_grad()\ndef run(q, k, v, state, A_log, a, dt_bias, b, scale):\n \"\"\"\n Gated Delta Net decode reference implementation matching FlashInfer kernel.\n \n State layout: [B, H, V, K] (k-last, V dimension before K)\n \n Gate computation:\n g_log = -exp(A_log) * softplus(a + dt_bias)\n beta = sigmoid(b)\n \n Delta rule update (following FlashInfer kernel exactly):\n 1. Apply L2 normalization to q and k\n 2. h *= exp(g_log) # Apply decay to state\n 3. v_new = v - k^T @ h # Delta rule\n 4. v_new *= beta # Apply update gate\n 5. h += k @ v_new^T # Update state (outer product)\n 6. o = scale * q^T @ h # Compute output\n \"\"\"\n B, T, num_q_heads, K = q.shape\n _, _, num_k_heads, _ = k.shape\n _, _, num_v_heads, V = v.shape\n num_heads = num_v_heads\n device = q.device\n dtype = torch.float32\n \n assert num_q_heads == 16\n assert num_k_heads == 16\n assert num_v_heads == 32\n assert K == 128 and V == 128\n assert T == 1\n \n # Default scale\n if scale is None or scale == 0.0:\n scale = 1.0 / math.sqrt(K)\n \n # Convert to float32 for computation\n A_log = A_log.to(dtype).to(device)\n a = a.to(dtype).to(device)\n dt_bias = dt_bias.to(dtype).to(device)\n b = b.to(dtype).to(device)\n \n # Compute gating values\n # g_log = -exp(A_log) * softplus(a + dt_bias)\n x = a + dt_bias # [B, 1, HV]\n softplus_x = F.softplus(x)\n g = -torch.exp(A_log) * softplus_x # [B, 1, HV]\n \n # beta = sigmoid(b)\n beta = torch.sigmoid(b) # [B, 1, HV]\n \n # Process tensors\n q_f32 = q.squeeze(1).float() # [B, num_q_heads, K]\n k_f32 = k.squeeze(1).float() # [B, num_k_heads, K]\n v_f32 = v.squeeze(1).float() # [B, num_v_heads, V]\n g_f32 = g.squeeze(1) # [B, num_v_heads]\n beta_f32 = beta.squeeze(1) # [B, num_v_heads]\n \n if state is not None:\n state_f32 = state.float() # [B, num_v_heads, V, K]\n else:\n state_f32 = torch.zeros(B, num_v_heads, V, K, dtype=dtype, device=device)\n \n # Expand heads for GVA (num_v_heads > num_q_heads)\n q_exp = q_f32.repeat_interleave(num_v_heads // num_q_heads, dim=1) # [B, num_v_heads, K]\n k_exp = k_f32.repeat_interleave(num_v_heads // num_k_heads, dim=1) # [B, num_v_heads, K]\n \n # Apply L2 normalization (matching FlashInfer kernel with use_qk_l2norm=True)\n q_exp = F.normalize(q_exp, p=2.0, dim=-1)\n k_exp = F.normalize(k_exp, p=2.0, dim=-1)\n \n # Apply scale to q\n q_exp = q_exp * scale\n \n # Initialize outputs\n new_state = torch.zeros_like(state_f32) # [B, num_v_heads, V, K]\n output = torch.zeros(B, num_v_heads, V, dtype=dtype, device=device)\n \n # Process each batch and head\n for b_idx in range(B):\n for h_idx in range(num_v_heads):\n q_h = q_exp[b_idx, h_idx] # [K]\n k_h = k_exp[b_idx, h_idx] # [K]\n v_h = v_f32[b_idx, h_idx] # [V]\n # State is [V, K] but we need [K, V] for computation\n h_state = state_f32[b_idx, h_idx].transpose(-1, -2).clone() # [V, K] -> [K, V]\n \n g_val = g_f32[b_idx, h_idx] # scalar\n beta_val = beta_f32[b_idx, h_idx] # scalar\n \n # Delta rule following FlashInfer kernel:\n # 1. Apply decay: h *= exp(g)\n h_state = h_state * torch.exp(g_val)\n \n # 2. Delta rule: v_new = v - k^T @ h\n # k is [K], h is [K, V]\n # k @ h = [K] @ [K, V] = [V]\n v_new = v_h - (k_h @ h_state)\n \n # 3. Apply update gate: v_new *= beta\n v_new = v_new * beta_val\n \n # 4. Update state: h += k @ v_new^T (outer product)\n # k.unsqueeze(1) is [K, 1], v_new.unsqueeze(0) is [1, V]\n # [K, 1] @ [1, V] = [K, V]\n h_state = h_state + k_h.unsqueeze(1) @ v_new.unsqueeze(0)\n \n # 5. 
Compute output: o = q^T @ h\n # q is [K], h is [K, V]\n # q @ h = [K] @ [K, V] = [V]\n output[b_idx, h_idx] = q_h @ h_state\n \n # Store updated state (transpose back to [V, K])\n new_state[b_idx, h_idx] = h_state.transpose(-1, -2)\n \n # Convert output to bfloat16 and add seq_len dimension\n output = output.unsqueeze(1).to(torch.bfloat16) # [B, 1, num_v_heads, V]\n \n return {\"output\": output, \"new_state\": new_state}" |
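Because the corrected reference ships as the "reference" string inside the definition JSON, it can be smoke-tested in isolation by exec-ing that string and calling run() directly. A rough sketch under the shape assumptions asserted inside run(); the random inputs, batch size, and working directory are illustrative only:

import json
import torch

with open("flashinfer_trace/definitions/gdn/gdn_decode_qk16_v32_d128_k_last.json") as f:
    definition = json.load(f)

namespace = {}
exec(definition["reference"], namespace)  # defines run(q, k, v, state, A_log, a, dt_bias, b, scale)
run = namespace["run"]

B = 2
q = torch.randn(B, 1, 16, 128, dtype=torch.bfloat16)  # [B, T=1, num_q_heads, K]
k = torch.randn(B, 1, 16, 128, dtype=torch.bfloat16)  # [B, T=1, num_k_heads, K]
v = torch.randn(B, 1, 32, 128, dtype=torch.bfloat16)  # [B, T=1, num_v_heads, V]
state = torch.randn(B, 32, 128, 128)                  # [B, num_v_heads, V, K], k-last
A_log, a, dt_bias, b = (torch.randn(B, 1, 32) for _ in range(4))

out = run(q, k, v, state, A_log, a, dt_bias, b, scale=None)
print(out["output"].shape, out["new_state"].shape)    # [B, 1, 32, 128] and [B, 32, 128, 128]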
Updated. Please review.
flashinfer_trace/tests/references/test_gdn_decode_qk16_v32_d128_k_last.py
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer_trace/tests/references/test_gdn_decode_qk16_v32_d128_k_last.py (1)
120-223: Make test_correctness fail when comparisons fail. Right now the test returns False but doesn't assert, so pytest will still pass even on mismatches. This undermines the test signal. Please add assertions (or use torch.testing.assert_close) so failures are surfaced.
🔧 Proposed fix
-    if output_close and state_close:
-        print(f"\n✓ PASSED (atol={atol}, rtol={rtol})")
-        return True
-    else:
-        print(f"\n✗ FAILED (atol={atol}, rtol={rtol})")
-        return False
+    if output_close and state_close:
+        print(f"\n✓ PASSED (atol={atol}, rtol={rtol})")
+    else:
+        print(f"\n✗ FAILED (atol={atol}, rtol={rtol})")
+    assert output_close, "Output mismatch exceeds tolerance"
+    assert state_close, "State mismatch exceeds tolerance"
🧹 Nitpick comments (1)
flashinfer_trace/tests/references/test_gdn_prefill_qk16_v32_d128_k_last.py (1)
171-173: Consider tightening or parameterizing the max-abs tolerance. A fixed atol = 1.0 on max error is quite permissive and could hide regressions when output magnitudes are modest. It may be safer to tie tolerances to expected scales (e.g., use relative thresholds or per-metric bounds) or centralize them for easier calibration.
Also applies to: 259-262, 339-341
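One way to avoid a hard-coded atol = 1.0 is to scale the allowed max-abs error by the reference tensor's magnitude. A small sketch with hypothetical names (check_close, rel, and floor are not from the test file):

import torch

def check_close(out: torch.Tensor, ref: torch.Tensor, rel: float = 5e-2, floor: float = 1e-3) -> None:
    # Bound the max absolute error relative to the reference's largest magnitude,
    # with a small absolute floor so near-zero references do not force tol to 0.
    tol = max(floor, rel * ref.abs().max().item())
    max_abs_diff = (out.float() - ref.float()).abs().max().item()
    assert max_abs_diff <= tol, f"max abs diff {max_abs_diff:.3e} exceeds tol {tol:.3e}"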
| print(f"\nWith initial state:") | ||
| print(f" Output max_err={output_max_err:.6f}") | ||
| print(f" State max_err={state_max_err:.6f}") | ||
| print( | ||
| f" Output: max_abs={max_abs_diff_o:.6e}, mean_abs={mean_abs_diff_o:.6e}, cos_sim={cos_sim_o:.6f}" | ||
| ) | ||
| print( | ||
| f" State: max_abs={max_abs_diff_s:.6e}, mean_abs={mean_abs_diff_s:.6e}, cos_sim={cos_sim_s:.6f}" | ||
| ) |
Remove unnecessary f-string prefix.
Ruff flags this as F541; it’s a no-op but can fail linting.
🔧 Proposed fix
- print(f"\nWith initial state:")
+ print("\nWith initial state:")Based on static analysis hints, consider applying this change.
🧰 Tools
🪛 Ruff (0.14.13)
251-251: f-string without any placeholders
Remove extraneous f prefix
(F541)
🤖 Prompt for AI Agents
In `@flashinfer_trace/tests/references/test_gdn_prefill_qk16_v32_d128_k_last.py`
around lines 251 - 257, The three print statements that use f-strings but pass
only a single preformatted string (e.g., the lines printing "With initial
state:" and the two prints that format max_abs_diff_o, mean_abs_diff_o,
cos_sim_o and max_abs_diff_s, mean_abs_diff_s, cos_sim_s) should drop the
unnecessary leading "f" to satisfy Ruff F541; find the print calls that
reference variables max_abs_diff_o, mean_abs_diff_o, cos_sim_o, max_abs_diff_s,
mean_abs_diff_s, cos_sim_s and remove the f-prefix so they are plain string
literals or regular formatted strings using .format if needed.
closed with #164
Fix GDN decode reference run() and test
Testing GDN Decode K-Last Reference Implementation
Loading definition from: flashinfer_trace/definitions/gdn/gdn_decode_qk16_v32_d128_k_last.json
============================================================
Testing GDN decode k-last, batch_size=1
Running reference implementation from definition...
Running FlashInfer kernel...
Comparing outputs...
Output comparison:
Max absolute difference: 5.035400e-02
Max relative difference: 8.360472e+02
Mean absolute difference: 9.030354e-03
Mean relative difference: 2.054162e+00
Cosine similarity: 0.965112
MSE: 1.296871e-04
State comparison:
Max absolute difference: 5.673413e-02
Max relative difference: 9.196285e+04
Mean absolute difference: 9.160103e-03
Mean relative difference: 2.089704e+00
Cosine similarity: 0.971113
MSE: 1.318440e-04
✓ PASSED (atol=0.07, rtol=0.07)
============================================================
Testing GDN decode k-last, batch_size=4
Running reference implementation from definition...
Running FlashInfer kernel...
Comparing outputs...
Output comparison:
Max absolute difference: 5.035400e-02
Max relative difference: 1.578677e+04
Mean absolute difference: 9.103831e-03
Mean relative difference: 3.202836e+00
Cosine similarity: 0.964996
MSE: 1.303982e-04
State comparison:
Max absolute difference: 5.673413e-02
Max relative difference: 1.463894e+05
Mean absolute difference: 9.160452e-03
Mean relative difference: 2.109787e+00
Cosine similarity: 0.970537
MSE: 1.318912e-04
✓ PASSED (atol=0.07, rtol=0.07)
============================================================
Testing GDN decode k-last, batch_size=16
Running reference implementation from definition...
Running FlashInfer kernel...
Comparing outputs...
Output comparison:
Max absolute difference: 5.297852e-02
Max relative difference: 2.751208e+04
Mean absolute difference: 9.158911e-03
Mean relative difference: 2.420628e+00
Cosine similarity: 0.973697
MSE: 1.315938e-04
State comparison:
Max absolute difference: 6.240548e-02
Max relative difference: 1.226528e+06
Mean absolute difference: 9.162382e-03
Mean relative difference: 2.398528e+00
Cosine similarity: 0.972964
MSE: 1.319060e-04
✓ PASSED (atol=0.07, rtol=0.07)
============================================================
Testing GDN decode k-last, batch_size=64
Running reference implementation from definition...
Running FlashInfer kernel...
Comparing outputs...
Output comparison:
Max absolute difference: 5.334473e-02
Max relative difference: 6.385589e+04
Mean absolute difference: 9.173501e-03
Mean relative difference: 2.101699e+00
Cosine similarity: 0.974166
MSE: 1.323684e-04
State comparison:
Max absolute difference: 6.595607e-02
Max relative difference: 1.226528e+06
Mean absolute difference: 9.167635e-03
Mean relative difference: 2.299596e+00
Cosine similarity: 0.973976
MSE: 1.320960e-04
✓ PASSED (atol=0.07, rtol=0.07)
============================================================
Testing GDN decode k-last, batch_size=256
Running reference implementation from definition...
Running FlashInfer kernel...
Comparing outputs...
Output comparison:
Max absolute difference: 5.480957e-02
Max relative difference: 7.578678e+04
Mean absolute difference: 9.174975e-03
Mean relative difference: 2.037896e+00
Cosine similarity: 0.974177
MSE: 1.323827e-04
State comparison:
Max absolute difference: 6.654610e-02
Max relative difference: 1.226528e+06
Mean absolute difference: 9.171450e-03
Mean relative difference: 2.308035e+00
Cosine similarity: 0.974194
MSE: 1.322146e-04
✓ PASSED (atol=0.07, rtol=0.07)
============================================================
Summary: 5/5 tests passed
✓ All tests passed!
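The comparison numbers in the log above (max/mean absolute difference, relative difference, cosine similarity, MSE) can be reproduced with a few tensor ops. A small sketch with illustrative names, not taken from the test file:

import torch
import torch.nn.functional as F

def compare(kernel_out: torch.Tensor, ref_out: torch.Tensor) -> dict:
    a, b = kernel_out.float().flatten(), ref_out.float().flatten()
    diff = (a - b).abs()
    return {
        "max_abs": diff.max().item(),
        "mean_abs": diff.mean().item(),
        # Relative difference blows up wherever the reference is near zero, which is
        # likely why the logged max relative differences reach ~1e6 even though the
        # absolute differences stay below 7e-2.
        "max_rel": (diff / b.abs().clamp_min(1e-12)).max().item(),
        "cos_sim": F.cosine_similarity(a, b, dim=0).item(),
        "mse": (diff ** 2).mean().item(),
    }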