@@ -45,43 +45,75 @@
],
"inputs": {
"q": {
"shape": ["batch_size", "seq_len", "num_q_heads", "head_size"],
"shape": [
"batch_size",
"seq_len",
"num_q_heads",
"head_size"
],
"dtype": "bfloat16",
"description": "Query tensor for single token decode."
},
"k": {
"shape": ["batch_size", "seq_len", "num_k_heads", "head_size"],
"shape": [
"batch_size",
"seq_len",
"num_k_heads",
"head_size"
],
"dtype": "bfloat16",
"description": "Key tensor for single token decode."
},
"v": {
"shape": ["batch_size", "seq_len", "num_v_heads", "head_size"],
"shape": [
"batch_size",
"seq_len",
"num_v_heads",
"head_size"
],
"dtype": "bfloat16",
"description": "Value tensor for single token decode."
},
"state": {
"shape": ["batch_size", "num_v_heads", "head_size", "head_size"],
"shape": [
"batch_size",
"num_v_heads",
"head_size",
"head_size"
],
"dtype": "float32",
"description": "Recurrent state in k-last layout [B, H, V, K].",
"optional": true
},
"A_log": {
"shape": ["num_v_heads"],
"shape": [
"num_v_heads"
],
"dtype": "float32",
"description": "Log decay parameter (learnable). Used to compute g = exp(-exp(A_log) * softplus(a + dt_bias))."
},
"a": {
"shape": ["batch_size", "seq_len", "num_v_heads"],
"shape": [
"batch_size",
"seq_len",
"num_v_heads"
],
"dtype": "bfloat16",
"description": "Input-dependent decay from projection."
},
"dt_bias": {
"shape": ["num_v_heads"],
"shape": [
"num_v_heads"
],
"dtype": "float32",
"description": "Decay bias (learnable). Added to 'a' before softplus."
},
"b": {
"shape": ["batch_size", "seq_len", "num_v_heads"],
"shape": [
"batch_size",
"seq_len",
"num_v_heads"
],
"dtype": "bfloat16",
"description": "Update gate input from projection. beta = sigmoid(b)."
},
@@ -93,15 +125,25 @@
},
"outputs": {
"output": {
"shape": ["batch_size", "seq_len", "num_v_heads", "head_size"],
"shape": [
"batch_size",
"seq_len",
"num_v_heads",
"head_size"
],
"dtype": "bfloat16",
"description": "Attention output. Shape follows num_v_heads in GVA mode."
},
"new_state": {
"shape": ["batch_size", "num_v_heads", "head_size", "head_size"],
"shape": [
"batch_size",
"num_v_heads",
"head_size",
"head_size"
],
"dtype": "float32",
"description": "Updated recurrent state in k-last layout [B, H, V, K]."
}
},
"reference": "import math\nimport torch\nimport torch.nn.functional as F\n\n\ndef matmul(a: torch.Tensor, b: torch.Tensor):\n \"\"\"Float32 matmul for numerical stability.\"\"\"\n return a.float() @ b.float()\n\n\[email protected]_grad()\ndef run(q, k, v, state, A_log, a, dt_bias, b, scale):\n \"\"\"\n Gated Delta Net decode reference implementation (k-last layout).\n \n State layout: [B, H, V, K] (k-last, K dimension at the end)\n \n Gate computation:\n g = exp(-exp(A_log) * softplus(a + dt_bias))\n beta = sigmoid(b)\n \n Delta rule update:\n state_new = g * state_old + k^T @ (beta * v + (1-beta) * k @ state_old) - k^T @ (k @ state_old)\n output = scale * q @ state_new\n \"\"\"\n B, T, num_q_heads, K = q.shape\n _, _, num_k_heads, _ = k.shape\n _, _, num_v_heads, V = v.shape\n num_heads = num_v_heads\n device = q.device\n \n assert num_q_heads == 16\n assert num_k_heads == 16\n assert num_v_heads == 32\n assert K == 128 and V == 128\n assert T == 1\n \n if scale is None or scale == 0.0:\n scale = 1.0 / math.sqrt(K)\n \n # Compute g and beta from raw parameters\n x = a.float() + dt_bias.float() # [B, 1, HV]\n g = torch.exp(-torch.exp(A_log.float()) * F.softplus(x)) # [B, 1, HV]\n beta = torch.sigmoid(b.float()) # [B, 1, HV]\n \n q_f32 = q.squeeze(1).float()\n k_f32 = k.squeeze(1).float()\n v_f32 = v.squeeze(1).float()\n g_f32 = g.squeeze(1).float()\n beta_f32 = beta.squeeze(1).float()\n \n if state is not None:\n state_f32 = state.float()\n else:\n state_f32 = torch.zeros(B, num_heads, V, K, dtype=torch.float32, device=device)\n \n q_exp = q_f32.repeat_interleave(num_v_heads // num_q_heads, dim=1)\n k_exp = k_f32.repeat_interleave(num_v_heads // num_k_heads, dim=1)\n \n new_state = torch.zeros_like(state_f32)\n output = torch.zeros(B, num_heads, V, dtype=torch.float32, device=device)\n \n for b_idx in range(B):\n for h_idx in range(num_heads):\n q_h = q_exp[b_idx, h_idx]\n k_h = k_exp[b_idx, h_idx]\n v_h = v_f32[b_idx, h_idx]\n h_state = state_f32[b_idx, h_idx].clone().transpose(-1, -2) # [V,K] -> [K,V]\n g_val = g_f32[b_idx, h_idx]\n beta_val = beta_f32[b_idx, h_idx]\n \n old_state = g_val * h_state\n old_v = k_h @ old_state\n new_v = beta_val * v_h + (1 - beta_val) * old_v\n state_remove = k_h.unsqueeze(1) @ old_v.unsqueeze(0)\n state_update = k_h.unsqueeze(1) @ new_v.unsqueeze(0)\n h_state = old_state - state_remove + state_update\n \n output[b_idx, h_idx] = scale * (q_h @ h_state)\n new_state[b_idx, h_idx] = h_state.transpose(-1, -2) # [K,V] -> [V,K]\n \n output = output.unsqueeze(1).to(torch.bfloat16)\n return {\"output\": output, \"new_state\": new_state}"
"reference": "import math\nimport torch\nimport torch.nn.functional as F\n\n\[email protected]_grad()\ndef run(q, k, v, state, A_log, a, dt_bias, b, scale):\n \"\"\"\n Gated Delta Net decode reference implementation matching FlashInfer kernel.\n \n State layout: [B, H, V, K] (k-last, V dimension before K)\n \n Gate computation:\n g_log = -exp(A_log) * softplus(a + dt_bias)\n beta = sigmoid(b)\n \n Delta rule update (following FlashInfer kernel exactly):\n 1. Apply L2 normalization to q and k\n 2. h *= exp(g_log) # Apply decay to state\n 3. v_new = v - k^T @ h # Delta rule\n 4. v_new *= beta # Apply update gate\n 5. h += k @ v_new^T # Update state (outer product)\n 6. o = scale * q^T @ h # Compute output\n \"\"\"\n B, T, num_q_heads, K = q.shape\n _, _, num_k_heads, _ = k.shape\n _, _, num_v_heads, V = v.shape\n num_heads = num_v_heads\n device = q.device\n dtype = torch.float32\n \n assert num_q_heads == 16\n assert num_k_heads == 16\n assert num_v_heads == 32\n assert K == 128 and V == 128\n assert T == 1\n \n # Default scale\n if scale is None or scale == 0.0:\n scale = 1.0 / math.sqrt(K)\n \n # Convert to float32 for computation\n A_log = A_log.to(dtype).to(device)\n a = a.to(dtype).to(device)\n dt_bias = dt_bias.to(dtype).to(device)\n b = b.to(dtype).to(device)\n \n # Compute gating values\n # g_log = -exp(A_log) * softplus(a + dt_bias)\n x = a + dt_bias # [B, 1, HV]\n softplus_x = F.softplus(x)\n g = -torch.exp(A_log) * softplus_x # [B, 1, HV]\n \n # beta = sigmoid(b)\n beta = torch.sigmoid(b) # [B, 1, HV]\n \n # Process tensors\n q_f32 = q.squeeze(1).float() # [B, num_q_heads, K]\n k_f32 = k.squeeze(1).float() # [B, num_k_heads, K]\n v_f32 = v.squeeze(1).float() # [B, num_v_heads, V]\n g_f32 = g.squeeze(1) # [B, num_v_heads]\n beta_f32 = beta.squeeze(1) # [B, num_v_heads]\n \n if state is not None:\n state_f32 = state.float() # [B, num_v_heads, V, K]\n else:\n state_f32 = torch.zeros(B, num_v_heads, V, K, dtype=dtype, device=device)\n \n # Expand heads for GVA (num_v_heads > num_q_heads)\n q_exp = q_f32.repeat_interleave(num_v_heads // num_q_heads, dim=1) # [B, num_v_heads, K]\n k_exp = k_f32.repeat_interleave(num_v_heads // num_k_heads, dim=1) # [B, num_v_heads, K]\n \n # Apply L2 normalization (matching FlashInfer kernel with use_qk_l2norm=True)\n q_exp = F.normalize(q_exp, p=2.0, dim=-1)\n k_exp = F.normalize(k_exp, p=2.0, dim=-1)\n \n # Apply scale to q\n q_exp = q_exp * scale\n \n # Initialize outputs\n new_state = torch.zeros_like(state_f32) # [B, num_v_heads, V, K]\n output = torch.zeros(B, num_v_heads, V, dtype=dtype, device=device)\n \n # Process each batch and head\n for b_idx in range(B):\n for h_idx in range(num_v_heads):\n q_h = q_exp[b_idx, h_idx] # [K]\n k_h = k_exp[b_idx, h_idx] # [K]\n v_h = v_f32[b_idx, h_idx] # [V]\n # State is [V, K] but we need [K, V] for computation\n h_state = state_f32[b_idx, h_idx].transpose(-1, -2).clone() # [V, K] -> [K, V]\n \n g_val = g_f32[b_idx, h_idx] # scalar\n beta_val = beta_f32[b_idx, h_idx] # scalar\n \n # Delta rule following FlashInfer kernel:\n # 1. Apply decay: h *= exp(g)\n h_state = h_state * torch.exp(g_val)\n \n # 2. Delta rule: v_new = v - k^T @ h\n # k is [K], h is [K, V]\n # k @ h = [K] @ [K, V] = [V]\n v_new = v_h - (k_h @ h_state)\n \n # 3. Apply update gate: v_new *= beta\n v_new = v_new * beta_val\n \n # 4. Update state: h += k @ v_new^T (outer product)\n # k.unsqueeze(1) is [K, 1], v_new.unsqueeze(0) is [1, V]\n # [K, 1] @ [1, V] = [K, V]\n h_state = h_state + k_h.unsqueeze(1) @ v_new.unsqueeze(0)\n \n # 5. 
Compute output: o = q^T @ h\n # q is [K], h is [K, V]\n # q @ h = [K] @ [K, V] = [V]\n output[b_idx, h_idx] = q_h @ h_state\n \n # Store updated state (transpose back to [V, K])\n new_state[b_idx, h_idx] = h_state.transpose(-1, -2)\n \n # Convert output to bfloat16 and add seq_len dimension\n output = output.unsqueeze(1).to(torch.bfloat16) # [B, 1, num_v_heads, V]\n \n return {\"output\": output, \"new_state\": new_state}"
}
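A quick way to sanity-check the updated reference against the shapes and dtypes declared in the spec above is a small driver script. The sketch below is a hypothetical reviewer aid, not part of this PR: it assumes the `run` function from the new "reference" string is in scope, fills the inputs with random data at the spec's fixed sizes (16 q/k heads, 32 v heads, head_size 128, seq_len 1), and checks the declared output shapes and dtypes.

```python
import torch

# Hypothetical smoke test for the reference `run` above; sizes follow the spec's asserts.
B, T, HQ, HK, HV, D = 2, 1, 16, 16, 32, 128

q = torch.randn(B, T, HQ, D, dtype=torch.bfloat16)
k = torch.randn(B, T, HK, D, dtype=torch.bfloat16)
v = torch.randn(B, T, HV, D, dtype=torch.bfloat16)
state = torch.randn(B, HV, D, D, dtype=torch.float32)  # k-last layout [B, H, V, K]
A_log = torch.randn(HV, dtype=torch.float32)
a = torch.randn(B, T, HV, dtype=torch.bfloat16)
dt_bias = torch.randn(HV, dtype=torch.float32)
b = torch.randn(B, T, HV, dtype=torch.bfloat16)

out = run(q, k, v, state, A_log, a, dt_bias, b, scale=None)  # None falls back to 1/sqrt(K)
assert out["output"].shape == (B, T, HV, D) and out["output"].dtype == torch.bfloat16
assert out["new_state"].shape == (B, HV, D, D) and out["new_state"].dtype == torch.float32
```

Passing `state=None` instead exercises the zero-initialized state path, which is the first-decode-step behavior the spec marks `state` as optional for.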