Dynamic-Routing Transformer with Iterative Refinement
- 🏆 87% Grid Accuracy on Sudoku-Extreme! — Broadcast injection with transformer controller reaches new SOTA
  - Training notebook: notebooks/broadcast_transformer_last.ipynb
  - Config: d-model=512, broadcast injection, H-cycles=4, L-cycles=12, halt-max-steps=8
- 🎮 NEW: Interactive Demo — Try the Sudoku solver in your browser!
- 🧩 Interactive 9×9 grid — just click and type
- ⚡ Adjustable "thinking time" (reasoning depth, H/L cycles)
- 🤖 Compare with GPT-4o-mini side-by-side
- 📱 Mobile-friendly — share with anyone!
- 📄 Paper: BERT/GPT with Inner-Thinking Cycles: Iterative Refinement via Dynamic Head Routing
- 🆕 Feature Injection (Next Feature Prediction): Controllers can now inject knowledge into token embeddings
  - 6 modes: none, broadcast, film, depth_token, cross_attn, alpha_gated
  - Conceptually analogous to next-token prediction, but for feature space
- 🆕 New Controllers: Added Mamba (O(N) linear SSM) and Diffusion (iterative denoising) depth controllers
- Mamba: Selective State Space Models for efficient routing (Gu & Dao, 2024)
- Diffusion: Denoising-based routing inspired by DiT (Peebles & Xie, 2023)
- Sudoku Benchmark (Colab, A100) — train a master-level Sudoku solver
- Blog: BERT/GPT with inner thinking cycles (same parameter count): https://medium.com/@eranbt92/bert-gpt-with-inner-thinking-cycles-same-parameter-dc54dbdec61e
- New: HRM vs PoT comparison report — alignment and differences: archive/experiments/results_docs/HRM_VS_PoT_REPORT.md
PoH is a modular transformer architecture that adds head-wise routing and iterative refinement to standard transformers. It is designed for tasks requiring multi-step reasoning; in simple words, "it is a BERT/GPT architecture with inner thinking cycles while keeping the number of parameters the same."
This is one PoH Block — a single transformer layer with dynamic head routing. The HRM Controller produces weights α that determine how much each attention head contributes to the output.
```mermaid
flowchart TB
%% ==== Styles ====
classDef head fill:#ffe0c2,stroke:#333,stroke-width:2px,color:#111
classDef ctrlL fill:#d6f5ff,stroke:#1e88e5,stroke-width:2px,color:#111
classDef ctrlH fill:#ffe0e0,stroke:#e53935,stroke-width:2px,color:#111
classDef io fill:#e8f5e9,stroke:#2e7d32,stroke-width:2px,color:#111
classDef mix fill:#fff9c4,stroke:#f9a825,stroke-width:2px,color:#111
classDef state fill:#f5f5f5,stroke:#666,stroke-width:1px,stroke-dasharray:5 5,color:#111
classDef note fill:#fafafa,stroke:#bbb,stroke-width:1px,color:#333
classDef skip fill:#e1f5fe,stroke:#0277bd,stroke-width:2px,stroke-dasharray:3 3,color:#111
%% ==== I/O ====
X[Input tokens or hidden x]:::io
Y[Block output]:::io
%% ==== Heads ====
subgraph SA["Self-Attention Heads"]
direction LR
H1[Head 1]:::head
H2[Head 2]:::head
H3[Head 3]:::head
end
%% ==== HRM Controller ====
subgraph HRM["HRM Pointer Controller"]
direction TB
%% High-level (slow)
subgraph HMOD["High-Level Module f_H (slow)"]
direction TB
zH[(z_H state)]:::state
FH[GRUCell f_H]:::ctrlH
end
%% Low-level (fast)
subgraph LMOD["Low-Level Module f_L (fast)"]
direction TB
zL[(z_L state)]:::state
FL[GRUCell f_L]:::ctrlL
end
%% Router head
RT["Router: Linear(concat(z_L, z_H)) → logits"]:::ctrlL
SM["Softmax / temperature"]:::ctrlL
TK{{Top-k optional}}:::ctrlL
ALPHA["Routing weights α over heads"]:::ctrlL
%% Internal wiring
Xp[x → controller space]:::ctrlH --> FH --> zH
zH --> FL
Xc[x → controller space]:::ctrlL --> FL
FL --> zL
zL --> RT --> SM --> TK --> ALPHA
end
%% ==== Mixer & FFN ====
MIX[Weighted head mix: Σ α_i · head_i]:::mix
FFN[Feed-Forward Network]:::mix
%% ==== Skip Connections ====
SKIP1[Residual: x + attn]:::skip
SKIP2[Residual: x + ffn]:::skip
%% ==== Timing / Notes ====
NOTE1[[f_H updates every T steps; f_L updates each step; optional deep supervision]]:::note
%% ==== Main flow ====
X --> SA
X --> HRM
ALPHA --> MIX
H1 --> MIX
H2 --> MIX
H3 --> MIX
%% Residual path 1: attention
MIX --> SKIP1
X -.-> SKIP1
%% Residual path 2: FFN
SKIP1 --> FFN
FFN --> SKIP2
SKIP1 -.-> SKIP2
SKIP2 --> Y
%% ==== Recurrence across inner iterations ====
Y -. next inner iteration .-> X
zL -. carried each step .-> zL
zH -. updated when t mod T == 0 .-> zH
NOTE1 -.-> HRM
class H1,H2,H3 head
class MIX,FFN mix
class SKIP1,SKIP2 skip
```
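To make the diagram concrete, here is a minimal runnable sketch of one PoH Block in PyTorch. This is an illustration, not the repository's implementation (see src/pot/modules/ for that): the class name, projection sizes, and exact controller wiring are assumptions, but the flow follows the diagram, with GRUCell modules f_L (updated every step) and f_H (updated every T steps) producing routing weights α that mix the attention heads.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PoHBlock(nn.Module):
    """Illustrative sketch of one PoH Block (names are assumptions, not repo API)."""
    def __init__(self, d_model=64, n_heads=4, d_ctrl=32, T=4):
        super().__init__()
        self.n_heads, self.d_head, self.T = n_heads, d_model // n_heads, T
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)
        self.ffn = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(),
                                 nn.Linear(4 * d_model, d_model))
        self.to_ctrl = nn.Linear(d_model, d_ctrl)       # x -> controller space
        self.f_H = nn.GRUCell(d_ctrl, d_ctrl)           # slow: updates every T steps
        self.f_L = nn.GRUCell(2 * d_ctrl, d_ctrl)       # fast: updates every step
        self.router = nn.Linear(2 * d_ctrl, n_heads)    # concat(z_L, z_H) -> logits

    def forward(self, x, z_L, z_H, t, temp=1.0):
        B, S, D = x.shape
        u = self.to_ctrl(x).reshape(B * S, -1)          # per-token controller input
        if t % self.T == 0:                             # slow timescale f_H
            z_H = self.f_H(u, z_H)
        z_L = self.f_L(torch.cat([u, z_H], -1), z_L)    # fast timescale f_L
        alpha = F.softmax(self.router(torch.cat([z_L, z_H], -1)) / temp, -1)
        alpha = alpha.reshape(B, S, self.n_heads, 1)    # routing weights over heads

        q, k, v = self.qkv(x).chunk(3, -1)
        q, k, v = (z.reshape(B, S, self.n_heads, self.d_head).transpose(1, 2)
                   for z in (q, k, v))
        heads = F.scaled_dot_product_attention(q, k, v).transpose(1, 2)  # [B,S,H,dh]
        x = x + self.out((alpha * heads).reshape(B, S, D))  # residual 1: Σ α_i·head_i
        x = x + self.ffn(x)                                 # residual 2
        return x, z_L, z_H
```

Here z_L and z_H are per-token controller states of shape [B·S, d_ctrl], initialized to zeros and carried across inner iterations, matching the dashed recurrence edges above.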
For complex reasoning tasks, we wrap multiple PoH Blocks into a two-timescale architecture inspired by the HRM paper. Each yellow box below contains the PoH Block shown above:
```mermaid
flowchart TB
classDef input fill:#e8f5e9,stroke:#2e7d32,stroke-width:2px,color:#111
classDef state fill:#fff3e0,stroke:#ef6c00,stroke-width:2px,color:#111
classDef fast fill:#e0f7fa,stroke:#00838f,stroke-width:2px,color:#111
classDef slow fill:#ffebee,stroke:#c62828,stroke-width:2px,color:#111
classDef output fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px,color:#111
classDef poh fill:#fff9c4,stroke:#f9a825,stroke-width:2px,color:#111
INPUT[/"Input Embedding<br/>(scaled by √d_model)"/]:::input
subgraph STATES["Persistent Hidden States"]
ZH["z_H (slow)"]:::state
ZL["z_L (fast)"]:::state
end
INJECT["⊕ z_H + input_emb"]:::input
subgraph LLEVEL["L_level: ReasoningModule (FAST)"]
LPOH["PoH Block × n_layers<br/>(diagram above)"]:::poh
end
subgraph HLEVEL["H_level: ReasoningModule (SLOW)"]
HPOH["PoH Block × n_layers<br/>(diagram above)"]:::poh
end
INNER{{"Inner Loop<br/>L_cycles=8"}}:::fast
OUTER{{"Outer Loop<br/>H_cycles=2"}}:::slow
OUTPUT[/"Output Logits"/]:::output
INPUT --> INJECT
ZH --> INJECT
INJECT --> LLEVEL
ZL --> LLEVEL
LLEVEL --> |"updates"|ZL
ZL --> INNER
INNER --> |"repeat"|LLEVEL
INNER --> |"done"|HLEVEL
ZH --> HLEVEL
ZL --> HLEVEL
HLEVEL --> |"updates"|ZH
ZH --> OUTER
OUTER --> |"repeat"|INNER
OUTER --> |"done"|OUTPUT
```
Architecture Hierarchy:

```
┌─────────────────────────────────────────────────────────────────┐
│ HybridHRMBase (Sudoku Solver) │
│ ├── L_level: ReasoningModule (FAST, runs 8× per H_cycle) │
│ │ └── PoH Block × 2 layers ← (Diagram 1️⃣ above) │
│ │ ├── HRM Controller (GRU f_L + f_H → α) │
│ │ ├── Multi-Head Attention (weighted by α) │
│ │ └── SwiGLU FFN + RMSNorm │
│ │ │
│ └── H_level: ReasoningModule (SLOW, runs 2×) │
│ └── PoH Block × 2 layers ← (Diagram 1️⃣ above) │
│ ├── HRM Controller (GRU f_L + f_H → α) │
│ ├── Multi-Head Attention (weighted by α) │
│ └── SwiGLU FFN + RMSNorm │
└─────────────────────────────────────────────────────────────────┘
```

Total reasoning steps = H_cycles × L_cycles = 2 × 8 = 16
Each step uses PoH Block with dynamic head routing (α weights)
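A minimal sketch of how those cycles compose, assuming L_level and H_level are callables that consume and update their respective states (signatures here are illustrative, not the repo API):

```python
def two_timescale_forward(input_emb, L_level, H_level, z_L, z_H,
                          H_cycles=2, L_cycles=8):
    """Illustrative two-timescale loop: H_cycles * L_cycles = 16 reasoning steps."""
    for _ in range(H_cycles):              # OUTER loop (slow)
        x = z_H + input_emb                # inject slow state into the input
        for _ in range(L_cycles):          # INNER loop (fast)
            z_L = L_level(x, z_L)          # L_level updates the fast state
        z_H = H_level(z_L, z_H)            # H_level updates the slow state
    return z_H                             # decoded to output logits downstream
```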
| Component | What it does | Diagram |
|---|---|---|
| PoH Block | Single layer: HRM Controller → α → Weighted MHA → FFN | 1️⃣ above |
| ReasoningModule | Stack of PoH Blocks with shared controller state | Inside 2️⃣ |
| HybridHRMBase | Two-timescale loop: L_level (fast) + H_level (slow) | 2️⃣ above |
Key insight: Diagram 1️⃣ shows what happens at each step (head routing). Diagram 2️⃣ shows how steps are organized into fast/slow timescales for iterative reasoning.
PoT (Pointer-over-Heads Transformer) is built around a simple idea: instead of producing its output in one forward pass, the model thinks through its representations over several refinement steps.
At the start, every token has an initial embedding — a rough guess of what it means in context. PoT doesn’t stop there. It runs the same Transformer stack R times, updating those embeddings after each pass. At every step, the model looks at its current hidden states and asks:
“Given what I know now, how should I use my attention heads to refine this understanding?”
Each iteration slightly reshapes the embedding space. Tokens move, cluster, and separate as their meanings become sharper and more contextually grounded. This process is not about memorizing — it’s about progressive self-correction. By the final iteration, the embeddings encode a richer, more internally consistent view of the sequence.
What makes PoT different is the controller that guides this process. For every token and refinement step, the controller decides how strongly to use each attention head. Some heads specialize in local structure, others in global dependencies or positional cues. By adjusting their mixture across iterations, the model can “compose” reasoning stages — starting with local alignment, then moving toward abstract relations or long-range coherence.
The controller itself operates on two timescales:
A fast component that adapts on every refinement step — reacting immediately to the evolving state of each token.
A slow component that changes less frequently — maintaining a broader contextual plan that influences the fast dynamics.
Together, they form a kind of hierarchical reasoning loop inside the embedding space. Rather than running deeper networks, PoT deepens its thinking process — continuously refining the meaning of each token until the hidden representations stabilize.
In other words:
PoT doesn’t just compute token embeddings — it thinks within them, iteratively reorganizing its own representation space to reach a more coherent internal understanding, without collapsing to the one-dimensional space of discrete token IDs (which would discard much of the information contained in the embeddings themselves).
SAEs vs PoT
- SAEs: post-hoc interpretability (discover fixed features/circuits).
- PoT: online computation (learned controller modulates head usage during inference).
MoE vs PoT
| Aspect | MoE | PoT |
|---|---|---|
| Routing target | Experts (sub-networks) | Attention heads (within block) |
| Tokens processed | Sparse subset | All tokens |
| Computation | Sparse/efficient | Dense/iterative |
| Routing frequency | Once per forward | Every refinement step (R) |
| Controller | Shallow gate | Two-timescale (f_L fast, f_H slow) |
| Goal | Throughput/scale | Adaptive reasoning |
Reference: Tiny Recursive Models (TRM)
| Aspect | TRM (TinyRecursiveModels) | PoT (Pointer‑over‑Heads) |
|---|---|---|
| Motivation | Compress depth via recursive weight tying | Make attention adaptive via dynamic head routing |
| Iteration type | Reuse the same block output as next input (recurrence) | Iterative refinement with per‑token per‑head routing |
| Routing | None (uniform computation) | α[token, iter, head] changes every refinement step |
| Controller | None (deterministic recurrence) | Hierarchical controller: f_L (fast), f_H (slow, period T) |
| Granularity | Whole‑block | Attention‑head |
| Goal | Parameter efficiency (simulate deep nets) | Adaptive reasoning / dynamic information flow |
Summary: TRM repeats the same computation to act deeper; PoT refines the computation itself to act smarter. While both do multiple passes, TRM’s steps are uniform across tokens with tied weights, whereas PoT learns a two‑timescale controller to modulate each head’s contribution per token and per iteration.
```bash
git clone https://github.com/Eran-BA/PoT.git
cd PoT
python -m venv venv           # Create virtual environment (if not already present)
source venv/bin/activate      # Activate virtual environment
pip install pyyaml datasets   # For NLI benchmarks
```

Key Components:
- HRM Controller: Two-timescale recurrent modules (from HRM paper)
- f_L (HRM inner loop): Updates every refinement step - fast, reactive processing
- f_H (HRM outer loop): Updates every T steps (T=4) - slow, strategic planning
- Router: Produces per-token, per-head routing weights α from f_L state
- Weighted Mix: Combines attention heads based on α
- Skip Connections: Residual connections around attention and FFN
- Refinement: Model refines representation R times per forward pass (R=12 optimal)
```
HybridHRMBase  # Two-timescale reasoning wrapper
│
├── L_level: ReasoningModule # FAST (8 cycles per H_cycle)
│ └── PoH Block × 2 # See Diagram 1️⃣ above
│ ├─ HRM Controller (GRU f_L + f_H → α)
│ ├─ Multi-Head Attention (weighted by α)
│ └─ SwiGLU FFN + RMSNorm
│
└── H_level: ReasoningModule # SLOW (2 cycles total)
└── PoH Block × 2 # See Diagram 1️⃣ above
├─ HRM Controller (GRU f_L + f_H → α)
├─ Multi-Head Attention (weighted by α)
└─ SwiGLU FFN + RMSNorm
```

Total reasoning steps: H_cycles × L_cycles = 2 × 8 = 16
- Head-Wise Routing: Dynamically select or weight attention heads per token (sketch below)
- Soft routing: Differentiable softmax over heads
- Top-k routing: Sparse binary mask (select top-k heads)
- Controlled by HRM inner loop (f_L) - updates every refinement step
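For illustration, the two variants differ only in how router logits become head weights; a sketch, not the repository's exact code:

```python
import torch
import torch.nn.functional as F

def route(logits, temp=1.0, top_k=None):
    """logits: [B, S, H] router outputs; returns routing weights alpha."""
    if top_k is None:                        # soft routing: differentiable softmax
        return F.softmax(logits / temp, dim=-1)
    # top-k routing: keep the k largest heads, zero out the rest via masking
    vals, idx = logits.topk(top_k, dim=-1)
    masked = torch.full_like(logits, float("-inf")).scatter(-1, idx, vals)
    return F.softmax(masked / temp, dim=-1)
```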
- Iterative Refinement: Apply the stack R times for multi-step reasoning (see the sketch after this list)
- R=12 refinement steps optimal (from empirical analysis)
- Optional residual connections across refinement steps (ReZero-style)
- ACT halting for adaptive computation
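A compact sketch of that loop with a ReZero-style scalar and a stand-in halting test (beta and halt_eps are illustrative names; the repository uses learned ACT halting rather than this simple threshold):

```python
def refine(x, stack, R=12, beta=None, halt_eps=None):
    """Apply the same stack R times (weights shared across refinement steps)."""
    for _ in range(R):
        y = stack(x)
        delta = (y - x).norm() / (x.norm() + 1e-8)         # relative state change
        x = x + beta * (y - x) if beta is not None else y  # ReZero-style residual
        if halt_eps is not None and delta < halt_eps:
            break               # stand-in for learned ACT halting (Graves, 2016)
    return x
```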
- Positional Encoding: Config-switchable (none/absolute/rotary)
  - "none": Permutation-invariant tasks
  - "absolute": Learned embeddings (GPT-2 style)
  - "rotary": RoPE (LLaMA style, optional)
- Parameter Parity: 0.27% overhead vs baseline TransformerEncoder (parameter count check below)
  - Lightweight router: d_model → d_model/4 → n_heads
  - Optional bias stripping to maintain parity
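The overhead figure can be sanity-checked by counting the router's parameters for the bottleneck shape above (a quick check assuming d_model=512 and 8 heads):

```python
import torch.nn as nn

d_model, n_heads = 512, 8
router = nn.Sequential(                    # d_model -> d_model/4 -> n_heads
    nn.Linear(d_model, d_model // 4),
    nn.GELU(),
    nn.Linear(d_model // 4, n_heads),
)
print(sum(p.numel() for p in router.parameters()))  # 66696 weights + biases
```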
🧩 Input Puzzle → ✅ PoT Solution (demo screenshots omitted; see the live demo)
Features of the live demo:
- 🧩 Interactive 9×9 Sudoku grid — click cells to edit
- ⚡ Adjustable reasoning depth (see how thinking time affects accuracy)
- 🤖 Compare with GPT-4o-mini side-by-side
- 📋 Copy puzzle prompts to test on Claude, Gemini, ChatGPT
- 📱 Works on mobile!
```bash
# Download dataset and train
python experiments/sudoku_poh_benchmark.py --download --model hybrid
# Or run in Colab (A100 recommended)
```

Training details:
- ✅ 1000 extreme Sudoku puzzles with 1000 augmentations each
- ✅ HybridPoHHRM two-timescale reasoning (L_level fast + H_level slow)
- ✅ Constraint loss for Sudoku rule enforcement
- ✅ ~20.8M parameters, trains in ~10 hours on A100
- ✅ 87% grid accuracy on Sudoku-Extreme with broadcast injection (vs. 55% HRM baseline)
Pre-trained model: Available on HuggingFace Hub
See also: experiments/ for archived benchmarks (Maze, NLI, Connect Four)
See: notebooks/ for archived notebooks (Maze, NLI, GPT, Connect Four)
- docs/ - Complete documentation index
- docs/architecture/ - Architecture guides
- docs/guides/ - User & developer guides
- examples/poh_usage.py - 6 usage examples
- examples/synthetic/ - Synthetic task experiments
- Terminology Guide - ESSENTIAL: Official HRM-aligned terminology
- Architecture Summary - Comprehensive architecture guide
- Refinement Iteration Guide - Why R=12 refinement steps is optimal
- HRM vs Refinement - Three nested loops explained
- Quick Start - Copy-paste commands for NLI benchmarks
- Contributing Guide - Development guidelines
- Determinism Guide - Reproducibility best practices
- Running Benchmarks - Full benchmark guide
- Python 3.9+
- PyTorch 2.0+
- NumPy, Matplotlib, Seaborn, SciPy, pandas, pytest, PyYAML
Optional:
- rotary-embedding-torch (for RoPE support)
- datasets (for real NLI benchmarks, Hugging Face)
- maze-dataset (for maze generation benchmarks)
- transformers (for BERT baselines in A/B tests)
```
PoT/
├── src/
│ ├── pot/
│ │ ├── modules/ # PoHBlock, PoHStack, IterRefiner, Positional Encoding
│ │ ├── logging/ # Inner-loop CSV logger
│ │ ├── core/ # HRM controller, losses, metrics
│ │ ├── tasks/ # Task adapters (dependency parsing, NLI)
│ │ ├── utils/ # Training utilities
│ │ └── models/ # High-level models (PoHGPT, BERT baselines)
│ └── models/ # Legacy model definitions
├── scripts/
│ ├── train.py # Unified training entry point
│ ├── plot_results.py # Auto-plotting
│ ├── plot_inner_vs_outer.py # Inner-loop visualization
│ └── make_readme_tables.py # Table generation
├── tests/
│ ├── test_poh_modules.py # 17 tests (all passing)
│ └── test_core.py # Core component tests
├── examples/
│ ├── poh_usage.py # Usage examples
│ ├── poh_gpt_usage.py # GPT-style usage
│ └── synthetic/ # Synthetic tasks (sorting)
├── experiments/
│ ├── configs/ # YAML configs per task (parsing, nli, lm)
│ ├── results/ # Experiment CSVs
│ ├── quick_nli_test.py # 3-min NLI test
│ ├── fair_ab_nli.py # Full synthetic NLI benchmark
│ ├── real_nli_benchmark.py # Real SNLI/MultiNLI benchmark
│ ├── quick_ab_test.py # GPT quick test
│ ├── fair_ab_lm.py # Full GPT benchmark
│ ├── maze_ab_proper_generation.py # Maze solving A/B test (with maze-dataset)
│ ├── maze_scaling_benchmark.py # Maze scaling 8×8→30×30
│ └── connect_four_ab_test.py # Connect Four A/B test
└── docs/
├── architecture/ # Architecture documentation
├── guides/ # User guides
├── tasks/ # Task-specific docs
└── POH_ITERATION_GUIDE.md  # Iteration count guide
```
See CONTRIBUTING.md for development guidelines.
```bibtex
@article{benartzy2025pot,
  author    = {Eran Ben Artzy},
  title     = {BERT/GPT with Inner-Thinking Cycles: Iterative Refinement via Dynamic Head Routing},
  year      = {2025},
  publisher = {Zenodo},
  doi       = {10.5281/zenodo.17959628},
  url       = {https://doi.org/10.5281/zenodo.17959628}
}

@software{benartzy2025poh,
  author    = {Eran Ben Artzy},
  title     = {Pointer-over-Heads Transformer: Dynamic Multi-Head Attention with Adaptive Routing},
  year      = {2025},
  publisher = {Zenodo},
  doi       = {10.5281/zenodo.17959628},
  url       = {https://doi.org/10.5281/zenodo.17959628}
}
```

Or cite as:
Ben Artzy, E. (2025). BERT/GPT with Inner-Thinking Cycles: Iterative Refinement via Dynamic Head Routing. Zenodo. https://doi.org/10.5281/zenodo.17959628
Apache 2.0 - See LICENSE for details.
This work builds upon several foundational papers:
- Pointer Networks - Vinyals et al. (2015): https://arxiv.org/pdf/1506.03134
- Foundation for attention-based pointer mechanisms
- Hierarchical Reasoning Model (HRM) - Sapient Intelligence, Singapore (2025): https://arxiv.org/abs/2506.21734
- Two-timescale recurrent controller for routing
- Adaptive Computation Time (ACT) - Graves (2016): https://arxiv.org/abs/1603.08983
- Learned halting for variable computation
- Transformer - Vaswani et al. (2017): https://arxiv.org/abs/1706.03762
- Base architecture
- maze-dataset - Ivanitskiy et al. (2023): https://arxiv.org/abs/2309.10498
- High-quality maze generation library for ML benchmarking
- Built on PyTorch's MultiheadAttention
- Evaluation metrics from the Universal Dependencies project
- Maze generation using the maze-dataset library
v2.0.0 - HybridPoHHRM Focus 🎯
- HybridHRM two-timescale reasoning (L_level fast + H_level slow)
- PoH Block with dynamic head routing (HRM Controller → α weights)
- Modular code structure (src/pot/models/, src/data/, src/training/)
- Constraint loss for Sudoku rule enforcement
- 17/17 tests passing
- HybridPoHHRMSolver implementation (~25.8M params)
- Sudoku-Extreme dataset integration (1000 puzzles × 1000 augmentations)
- Colab notebook for A100 training
- Reaching HRM paper target (55% grid accuracy)
- NLI, GPT, Maze, Connect Four benchmarks
The above architecture uses GRU cells for the recurrent controller modules (f_L and f_H). Importantly, these GRUs operate across depth (refinement iterations), not across the input sequence length. Each token maintains its own independent controller state that evolves as the model iterates through reasoning steps.
This is not a fixed design choice — the GRU can be replaced with other recurrent units.
All controllers are accessible via the unified factory:
```python
from src.pot.core import create_controller, CONTROLLER_TYPES

print(CONTROLLER_TYPES)
# ['gru', 'lstm', 'xlstm', 'mingru', 'transformer', 'pot_transformer', 'swin', 'mamba', 'diffusion']

controller = create_controller("mamba", d_model=256, n_heads=8)
```

| Type | Name | Complexity | Description |
|---|---|---|---|
| gru | HRM GRU Controller | O(1) per step | Two-timescale GRU (f_L fast, f_H slow) from HRM paper |
| lstm | LSTM Depth Controller | O(1) per step | Standard LSTM with stronger gating than GRU |
| xlstm | xLSTM Depth Controller | O(1) per step | Extended LSTM with exponential gating (sLSTM variant) |
| mingru | minGRU Depth Controller | O(1) per step | Simplified GRU with single gate, fewer parameters |
| transformer | Causal Depth Transformer | O(t²) over depth | Transformer with causal attention over depth axis |
| pot_transformer | PoT Depth Transformer | O(t²) over depth | Nested PoT with gated MHA internally |
| swin | Swin Depth Controller | O(t²) over depth | Hierarchical controller with local window attention |
| mamba | Mamba Depth Controller | O(N) linear | Selective SSM with input-dependent transitions |
| diffusion | Diffusion Depth Controller | O(1) per step | Iterative denoising inspired by diffusion transformers |
- LSTM — Long Short-Term Memory for stronger gating
- xLSTM — Extended LSTM with exponential gating and matrix memory (Beck et al., 2024)
- Mamba — Selective State Space Models with O(N) linear complexity (Gu & Dao, 2024)
- minGRU — Simplified GRU variant for reduced overhead
The key insight is that any recurrent unit capable of maintaining state across depth (i.e., across iteration steps, not across tokens) can serve as the controller backbone.
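For example, a hypothetical drop-in controller built on nn.LSTMCell only needs to keep per-token state across depth steps and emit α; the interface below is illustrative (the repository exposes controllers through create_controller instead):

```python
import torch
import torch.nn as nn

class LSTMDepthController(nn.Module):
    """Sketch: swap the GRU backbone for an LSTM; state flows across depth steps."""
    def __init__(self, d_model, n_heads, d_ctrl=128):
        super().__init__()
        self.proj = nn.Linear(d_model, d_ctrl)
        self.cell = nn.LSTMCell(d_ctrl, d_ctrl)      # any recurrent cell fits here
        self.router = nn.Linear(d_ctrl, n_heads)

    def init_state(self, batch_tokens):
        z = torch.zeros(batch_tokens, self.cell.hidden_size)
        return (z, z.clone())                        # per-token (h, c)

    def step(self, x_flat, state):                   # x_flat: [B*S, d_model]
        h, c = self.cell(self.proj(x_flat), state)   # advance one depth step
        alpha = torch.softmax(self.router(h), dim=-1)
        return alpha, (h, c)
```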
The MambaDepthController uses Selective State Space Models (SSMs) for efficient depth-wise routing with O(N) linear time complexity:
```python
from src.pot.core import create_controller

# Mamba controller - O(N) complexity
mamba = create_controller("mamba", d_model=256, n_heads=8, d_state=16)
```

Key features:
- Input-dependent transitions: A, B, C, D matrices depend on input (selective scan)
- Linear complexity: O(N) vs O(N²) for attention-based controllers
- Efficient recurrent processing: Memory-efficient compared to Transformer controllers
Core SSM recurrence:

```
Δ = softplus(Linear(x))        # Input-dependent discretization
A_bar = exp(Δ * A)             # Discretized transition
B_bar = Δ * B(x)               # Input-dependent input matrix
h' = A_bar * h + B_bar * x     # State update
y = C(x) * h' + D * x          # Output
```
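A minimal PyTorch rendering of one such update (a sketch of the recurrence above; the actual MambaDepthController adds more machinery, so treat shapes and names as assumptions):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelectiveSSMStep(nn.Module):
    """One input-dependent (selective) SSM update, following the recurrence above."""
    def __init__(self, d_in, d_state=16):
        super().__init__()
        self.A = nn.Parameter(-torch.rand(d_in, d_state))  # negative for stability
        self.D = nn.Parameter(torch.ones(d_in))
        self.to_delta = nn.Linear(d_in, d_in)
        self.to_B = nn.Linear(d_in, d_state)
        self.to_C = nn.Linear(d_in, d_state)

    def forward(self, x, h):
        # x: [B, d_in], h: [B, d_in, d_state]
        delta = F.softplus(self.to_delta(x))                      # Δ = softplus(Linear(x))
        A_bar = torch.exp(delta.unsqueeze(-1) * self.A)           # A_bar = exp(Δ * A)
        B_bar = delta.unsqueeze(-1) * self.to_B(x).unsqueeze(1)   # B_bar = Δ * B(x)
        h = A_bar * h + B_bar * x.unsqueeze(-1)                   # h' = A_bar·h + B_bar·x
        y = (h * self.to_C(x).unsqueeze(1)).sum(-1) + self.D * x  # y = C(x)·h' + D·x
        return y, h
```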
The DiffusionDepthController uses iterative denoising inspired by diffusion transformers (Peebles & Xie, 2023):
```python
from src.pot.core import create_controller

# Diffusion controller - iterative denoising
diffusion = create_controller("diffusion", d_model=256, n_heads=8, noise_schedule="cosine")
```

Key features:
- Iterative denoising: Routing weights evolve through a learned denoising process
- Learned noise schedules: Linear, cosine, or sqrt schedules
- Adaptive LayerNorm (adaLN): Conditioning inspired by DiT (Diffusion Transformers)
- Smooth routing evolution: Temporally coherent across depth steps
Denoising process:

```
z^(t) = denoise(z^(t-1), sigma(t), x_ctrl)
alpha = softmax(router(z^(t)))
```
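A hedged sketch of one denoising update with a cosine schedule and adaLN-style conditioning (module and argument names are assumptions, not the repo API):

```python
import math
import torch
import torch.nn as nn

class DiffusionRoutingStep(nn.Module):
    """Sketch: refine a routing latent z by one denoising step, conditioned on x_ctrl."""
    def __init__(self, d_ctrl, n_heads, n_steps=8):
        super().__init__()
        self.n_steps = n_steps
        self.denoiser = nn.Sequential(nn.Linear(2 * d_ctrl + 1, d_ctrl), nn.GELU(),
                                      nn.Linear(d_ctrl, d_ctrl))
        self.adaln = nn.Linear(d_ctrl, 2 * d_ctrl)   # adaLN-style scale/shift
        self.norm = nn.LayerNorm(d_ctrl, elementwise_affine=False)
        self.router = nn.Linear(d_ctrl, n_heads)

    def sigma(self, t):
        return math.cos(0.5 * math.pi * t / self.n_steps)   # cosine noise schedule

    def forward(self, z, t, x_ctrl):
        scale, shift = self.adaln(x_ctrl).chunk(2, dim=-1)
        z_cond = self.norm(z) * (1 + scale) + shift          # condition on x_ctrl
        s = torch.full_like(z[..., :1], self.sigma(t))       # current noise level
        z = z + self.denoiser(torch.cat([z, z_cond, s], dim=-1))  # denoising residual
        alpha = torch.softmax(self.router(z), dim=-1)        # routing weights
        return alpha, z
```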
A more expressive alternative is to replace the GRU entirely with a causal Transformer operating over the depth axis. Unlike GRUs which only have implicit access to past states through compressed hidden states, a depth Transformer can explicitly attend to any relevant previous refinement step.
Core idea: At refinement step t, compute routing weights α⁽ᵗ⁾ using only past and current depth states {x⁽⁰⁾, ..., x⁽ᵗ⁾}, then use α⁽ᵗ⁾ to mix attention heads.
Architecture:

```
┌─────────────────────────────────────────────────────────────┐
│ Causal Depth Transformer Controller │
│ │
│ Input options: │
│ ├── (A) Token-wise: u_i^(t) = W_u·x_i^(t) + pos^(t) │
│ └── (B) Pooled: g^(t) = Pool(X^(t)), u^(t) = W_u·g + pos │
│ │
│ Depth sequence U^(0:t) → DepthTransformer (causal mask) │
│ └── 1-2 layers, d_ctrl = d_model/4, n_heads = 4 │
│ │
│ Output y^(t) → Router → α^(t) routing weights │
│ └── Token-conditioned: logits = W_r·[x_i | y^(t)] │
└─────────────────────────────────────────────────────────────┘
```

Detailed Architecture Diagram:

```
CAUSAL DEPTH TRANSFORMER CONTROLLER
═══════════════════════════════════════════════════════════════
INPUT (at refinement step t)
┌─────────────────────────────────────────────────────────────┐
│ X^(t) = [x₁, x₂, ..., xₛ] (S tokens, d_model each) │
└───────────────────────────┬─────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ POOLING + PROJECTION │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────┐ │
│ │ LayerNorm │ → │ Mean Pool │ → │ Linear+GELU+Lin │ │
│ │ (X^(t)) │ │ over S │ │ d_model → d_ctrl│ │
│ └─────────────┘ └─────────────┘ └────────┬────────┘ │
│ │ │
│ u^(t) = ctrl_input + pos^(t)
└──────────────────────────────────────────────────┬──────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ DEPTH CACHE (grows with t) │
│ │
│ U = [ u^(0), u^(1), u^(2), ..., u^(t) ] │
│ ↓ ↓ ↓ ↓ │
│ step 0 step 1 step 2 ... current │
│ │
└───────────────────────────┬─────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ CAUSAL DEPTH TRANSFORMER (over depth axis) │
│ │
│ ┌───────────────────────────────────────────────────┐ │
│ │ TransformerEncoder (n_layers=2, n_heads=4) │ │
│ │ │ │
│ │ Attention Mask (CAUSAL over depth): │ │
│ │ ┌─────────────────────────┐ │ │
│ │ │ step: 0 1 2 t │ │ │
│ │ │ 0 [✓] [✗] [✗] [✗] │ ✓ = can attend │ │
│ │ │ 1 [✓] [✓] [✗] [✗] │ ✗ = masked out │ │
│ │ │ 2 [✓] [✓] [✓] [✗] │ │ │
│ │ │ t [✓] [✓] [✓] [✓] │ ← current step │ │
│ │ └─────────────────────────┘ │ │
│ │ │ │
│ │ Each step can only see past steps (causal) │ │
│ └───────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ Y = Transformer(U, causal_mask) │
│ r^(t) = Y[-1] (last position = current step) │
│ │
└───────────────────────────┬─────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ TOKEN-CONDITIONED ROUTER │
│ │
│ For each token i: │
│ ┌────────────────────────────────────────────────────┐ │
│ │ r^(t) ──┐ │ │
│ │ ├──→ concat ──→ [xᵢ | r^(t)] │ │
│ │ xᵢ^(t) ─┘ │ │ │
│ │ ▼ │ │
│ │ ┌────────────────────┐ │ │
│ │ │ Router MLP │ │ │
│ │ │ LN → Linear+GELU │ │ │
│ │ │ → Linear → logits │ │ │
│ │ └─────────┬──────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ logitsᵢ = [h₁, h₂, ..., hₕ] │ │
│ │ │ │ │
│ │ ┌─────────▼──────────┐ │ │
│ │ │ Softmax(logits/τ) │ │ │
│ │ │ (temperature τ) │ │ │
│ │ └─────────┬──────────┘ │ │
│ │ │ │ │
│ │ ┌─────────▼──────────┐ │ │
│ │ │ Optional: Top-k │ │ │
│ │ │ sparsification │ │ │
│ │ └─────────┬──────────┘ │ │
│ │ ▼ │ │
│ │ αᵢ^(t) = [α₁, α₂, ..., αₕ] │ │
│ └────────────────────────────────────────────────────┘ │
│ │
└───────────────────────────┬─────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ OUTPUT │
│ │
│ α^(t) = [α₁^(t), α₂^(t), ..., αₛ^(t)] │
│ │
│ Shape: [Batch, Sequence, Heads] │
│ │
│ Each αᵢ^(t) sums to 1.0 (softmax over heads) │
│ Used to weight attention head outputs in PoH Block │
│ │
└─────────────────────────────────────────────────────────────┘
KEY DIFFERENCES FROM GRU CONTROLLER:
═══════════════════════════════════════════════════════════════
┌─────────────────────────┬───────────────────────────────────┐
│ GRU Controller │ Causal Depth Transformer │
├─────────────────────────┼───────────────────────────────────┤
│ h^(t) = GRU(x, h^(t-1))│ y^(t) = Attn(U^(0:t), causal) │
│ │ │
│ Compressed history │ Explicit attention to ALL │
│ in fixed-size hidden │ previous depth steps │
│ state h │ │
│ │ │
│ O(1) memory per step │ O(t) memory (cache grows) │
│ │ │
│ Implicit past access │ Explicit: step 10 can directly │
│ (through h) │ attend to step 3's features │
│ │ │
│ Sequential processing │ Parallel training possible │
│ (can't parallelize) │ (with causal mask) │
└─────────────────────────┴───────────────────────────────────┘
```

Empirical Results (Sudoku-Extreme benchmark):
| Epoch | Cell Acc (Transformer) | Cell Acc (GRU) | Grid Acc (Transformer) | Grid Acc (GRU) |
|---|---|---|---|---|
| 30 | 58.0% | ~55% | 0.0% | 0.0% |
| 47 | 61.7% | 62.7% | 0.0% | 0.1% |
| 62 | 63.6% | ~64% | 0.1% | ~0.8% |
| 70 | 64.6% | 65.2% | 0.6% | 1.3% |
Verdict: The Causal Depth Transformer achieves comparable performance to GRU, validating this alternative architecture for depth-wise control.
Advantages:
- Explicit attention over depth history — step 10 can directly reference step 3
- Parallel training — causal mask allows batched forward pass over all K steps
- Better gradient flow — residual connections avoid vanishing gradients
- Interpretability — attention weights show which past reasoning steps matter
Implementation choices:
- Recompute prefix (simple): At step t, run DepthTx on [0..t]. O(t²) across steps, fine for K ≤ 16.
- KV-cache (fast): Cache K/V for each layer over previous depth steps. O(t) per step.
Recommended starting point:
- Pooled controller input (Option B) for efficiency
- Token-conditioned routing for per-token expressivity
- 1-2 layer Transformer with d_ctrl = 128-256
- Recompute prefix first, add KV-cache later if needed
Integration with HRM two-timescale design:
- Keep f_H (slow) as GRU — updates rarely, doesn't need long-range depth attention
- Replace f_L (fast) with Depth Transformer — updates every step, benefits most
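Combining the recommendations above, here is a compact sketch of the pooled, recompute-prefix variant with token-conditioned routing (class and argument names are assumptions, not the repo API):

```python
import torch
import torch.nn as nn

class CausalDepthTxController(nn.Module):
    """Sketch: causal Transformer over the depth axis (recompute-prefix variant)."""
    def __init__(self, d_model, n_heads, d_ctrl=128, max_steps=16):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(nn.Linear(d_model, d_ctrl), nn.GELU(),
                                 nn.Linear(d_ctrl, d_ctrl))
        self.depth_pos = nn.Parameter(torch.zeros(max_steps, d_ctrl))
        layer = nn.TransformerEncoderLayer(d_model=d_ctrl, nhead=4, batch_first=True)
        self.depth_tx = nn.TransformerEncoder(layer, num_layers=2)
        self.router = nn.Sequential(nn.LayerNorm(d_model + d_ctrl),
                                    nn.Linear(d_model + d_ctrl, d_ctrl), nn.GELU(),
                                    nn.Linear(d_ctrl, n_heads))
        self.cache = []   # depth cache U = [u^(0), ..., u^(t)]; reset per sequence

    def forward(self, x, t, temp=1.0):                     # x: [B, S, d_model]
        u = self.mlp(self.norm(x).mean(dim=1)) + self.depth_pos[t]   # pooled u^(t)
        self.cache.append(u)
        U = torch.stack(self.cache, dim=1)                 # [B, t+1, d_ctrl]
        causal = nn.Transformer.generate_square_subsequent_mask(U.size(1))
        r = self.depth_tx(U, mask=causal)[:, -1]           # r^(t): last depth position
        r = r.unsqueeze(1).expand(-1, x.size(1), -1)       # share r^(t) across tokens
        logits = self.router(torch.cat([x, r], dim=-1))    # token-conditioned routing
        return torch.softmax(logits / temp, dim=-1)        # α^(t): [B, S, n_heads]
```

A KV-cache variant would store per-layer keys/values for previous depth steps instead of re-encoding the whole prefix, reducing the per-step cost from O(t²) to O(t).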
Different controller choices offer trade-offs in:
- Memory capacity and gradient flow
- Computational efficiency
- Expressiveness of the routing dynamics
Beyond routing-only controllers, PoT now supports feature injection — injecting controller knowledge back into token embeddings, not just into routing weights α.
Conceptual insight: This is next feature injection prediction. The controller observes the current reasoning state and "predicts" what features should be injected into tokens for the next reasoning step.
| Paradigm | Prediction Target |
|---|---|
| LLMs (GPT) | Next token |
| Diffusion | Next noise level state |
| Feature Injection | Next feature state to inject |
The controller produces a feature vector r that shapes what knowledge flows into the token embeddings before the next iteration — essentially predicting the feature context needed for continued reasoning.
```python
from src.pot.core import INJECTION_MODES

print(INJECTION_MODES)
# ['none', 'broadcast', 'film', 'depth_token', 'cross_attn', 'alpha_gated']

# Example: Solver with FiLM injection
model = HybridPoHHRMSolver(
    d_model=512, n_heads=8,
    injection_mode="film",  # NEW
)
```

| Mode | Formula | Description | Use Case |
|---|---|---|---|
| none | Pass-through | Routing-only (default, backward compatible) | Baseline, interpretability |
| broadcast | x + gate * broadcast(r·W) | Gated broadcast to all tokens | Quick improvement, global context |
| film | γ * x + β | FiLM modulation (scale/shift) | Stable conditioning, modulation |
| depth_token | Prepend z token | Attention-based knowledge sharing | Transformer-native, selective |
| cross_attn | x + CrossAttn(x, memory) | Cross-attention to depth memory bank | Most expressive, history access |
| alpha_gated | x + α_agg * (r·W) | Alpha-modulated broadcast | Coupled routing + injection |
When to use each mode:
- none — when you want pure head routing without feature injection (baseline, interpretability)
- broadcast — quick way to inject global depth context into all tokens (risk: can overpower tokens if not gated)
- film — stable modulation without adding new information (acts like "context conditioning")
- depth_token — Transformer-native approach where tokens can selectively attend to depth knowledge
- cross_attn — most expressive; different tokens can retrieve different past depth information
- alpha_gated — coupled approach where injection strength follows routing confidence (coherent α + injection)
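To make two of these concrete, here are illustrative sketches of the broadcast and film updates, using the formulas from the table above (module names are hypothetical):

```python
import torch
import torch.nn as nn

class BroadcastInjection(nn.Module):
    """x + gate * broadcast(r·W): gated global context added to every token."""
    def __init__(self, d_ctrl, d_model):
        super().__init__()
        self.W = nn.Linear(d_ctrl, d_model)
        self.gate = nn.Parameter(torch.zeros(1))     # starts closed; learned open

    def forward(self, x, r):                         # x: [B, S, D], r: [B, d_ctrl]
        return x + torch.tanh(self.gate) * self.W(r).unsqueeze(1)

class FiLMInjection(nn.Module):
    """γ * x + β: modulates existing features rather than adding new ones."""
    def __init__(self, d_ctrl, d_model):
        super().__init__()
        self.to_film = nn.Linear(d_ctrl, 2 * d_model)

    def forward(self, x, r):
        gamma, beta = self.to_film(r).unsqueeze(1).chunk(2, dim=-1)
        return (1 + gamma) * x + beta                # identity when γ = β = 0
```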
Example with alpha-gated injection:

```python
# Alpha-gated: injection strength modulated by routing weights
model = HybridPoHHRMSolver(
    d_model=512, n_heads=8,
    injection_mode="alpha_gated",
    injection_kwargs={
        "alpha_aggregation": "entropy",  # "mean", "max", or "entropy"
        "use_learned_gate": True,        # Combine with learned gate
    },
)
```

Example with cross-attention memory:

```python
model = HybridPoHHRMSolver(
    d_model=512, n_heads=8,
    injection_mode="cross_attn",
    injection_kwargs={"memory_size": 16, "n_heads": 4},
)
```

```mermaid
flowchart TB
%% ==== Styles ====
classDef input fill:#e8f5e9,stroke:#2e7d32,stroke-width:2px,color:#111
classDef pool fill:#e3f2fd,stroke:#1565c0,stroke-width:2px,color:#111
classDef cache fill:#fff3e0,stroke:#ef6c00,stroke-width:2px,color:#111
classDef transformer fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px,color:#111
classDef router fill:#fff9c4,stroke:#f9a825,stroke-width:2px,color:#111
classDef output fill:#ffebee,stroke:#c62828,stroke-width:2px,color:#111
classDef note fill:#fafafa,stroke:#bbb,stroke-width:1px,color:#333
%% ==== Input ====
X["X^(t) = Token representations<br/>[B, S, d_model]"]:::input
%% ==== Pooling Stage ====
subgraph POOL["Pooling + Projection"]
direction TB
LN["LayerNorm"]:::pool
MEAN["Mean Pool<br/>over S tokens"]:::pool
MLP["MLP: d_model → d_ctrl"]:::pool
POS["+ depth_pos^(t)"]:::pool
UT["u^(t) = controller input"]:::pool
LN --> MEAN --> MLP --> POS --> UT
end
%% ==== Depth Cache ====
subgraph CACHE["Depth Cache (grows with t)"]
direction LR
U0["u^(0)"]:::cache
U1["u^(1)"]:::cache
U2["u^(2)"]:::cache
DOTS["..."]:::cache
UCUR["u^(t)"]:::cache
end
%% ==== Causal Transformer ====
subgraph DEPTHTX["Causal Depth Transformer"]
direction TB
STACK["TransformerEncoder<br/>n_layers=2, n_heads=4"]:::transformer
MASK["Causal Mask:<br/>step t sees only 0..t"]:::note
LASTOUT["r^(t) = output at position t"]:::transformer
STACK --> LASTOUT
MASK -.-> STACK
end
%% ==== Router ====
subgraph ROUTER["Token-Conditioned Router"]
direction TB
CONCAT["concat(x_i, r^(t))<br/>for each token i"]:::router
RMLP["Router MLP:<br/>LN → Linear → GELU → Linear"]:::router
SOFTMAX["Softmax(logits / τ)"]:::router
TOPK{{"Optional: Top-k"}}:::router
ALPHA["α^(t) = routing weights<br/>[B, S, H]"]:::router
CONCAT --> RMLP --> SOFTMAX --> TOPK --> ALPHA
end
%% ==== Output ====
OUT["α^(t) used in PoH Block<br/>to weight attention heads"]:::output
%% ==== Connections ====
X --> POOL
UT --> CACHE
UCUR --> DEPTHTX
U0 --> DEPTHTX
U1 --> DEPTHTX
U2 --> DEPTHTX
LASTOUT --> ROUTER
X -.->|"token features"| CONCAT
ALPHA --> OUT
%% ==== Recurrence ====
OUT -. "next refinement step t+1" .-> X
```
How it works:
- Pool tokens → compress X^(t) to single vector u^(t)
- Append to cache → U = [u^(0), u^(1), ..., u^(t)]
- Causal attention → Transformer attends only to past/current steps
- Route per-token → combine r^(t) with each token x_i to produce α_i^(t)
- Apply to PoH Block → α weights mix attention head outputs
High-level directions:
- PoT as a single processing unit: other components can pull computations from it through a learnable shared embedding layer.
- Connect PoT to other units, treating it as a fixed computation unit that passes outputs or embeddings to them (like an encoder-decoder architecture, where PoT serves as the encoder itself).
- Connect PoT to cache memory units.
- Robotics in a closed, finite domain-and-range environment.
- Robotics in a closed, infinite domain-and-range environment.
- In general, robotics in open or closed, finite or infinite domain-and-range environments, where the robot can be a tool: a computer process or a physical unit in the world.
Questions? Open an issue or see QUICK_START.md for copy-paste commands!