
Add EXAONE 4.0 model support for Inference V2 #7853

Merged
tohtana merged 5 commits into deepspeedai:master from Bias92:add-exaone4-inference-v2 on Feb 17, 2026


Conversation

@Bias92
Contributor

@Bias92 Bias92 commented Feb 14, 2026

Summary

Add support for LG AI Research's EXAONE 4.0 model family in DeepSpeed Inference V2.

Closes #7453

Changes

  • New model implementation: deepspeed/inference/v2/model_implementations/exaone4/
    • container.py: Transformer and non-transformer parameter containers
    • model.py: Inference model with post-norm architecture and QK-Norm support
    • policy.py: Inference V2 policy
  • Register EXAONE 4.0 in engine_factory.py and __init__.py

Key architectural differences from Mistral/Llama

  • Post-norm: RMSNorm is applied after attention/MLP outputs (not before), followed by residual addition
  • QK-Norm: Per-head RMSNorm applied to Q and K projections after the QKV linear layer
  • Hybrid attention: 32B model uses 3:1 sliding window/full attention ratio (via layer_types config)
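
To make the post-norm and QK-Norm points concrete, here is a minimal PyTorch sketch of the two patterns (plain eager code, not the DeepSpeed kernels; function and argument names are illustrative):

import torch

def rms_norm(x, weight, eps=1e-6):
    # RMSNorm over the last dimension.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

def exaone4_block(hidden, attn, mlp, attn_norm_w, mlp_norm_w):
    # Post-norm: the sub-layer output is normalized first, then added to the
    # residual (Llama/Mistral instead normalize the sub-layer input).
    hidden = hidden + rms_norm(attn(hidden), attn_norm_w)
    hidden = hidden + rms_norm(mlp(hidden), mlp_norm_w)
    return hidden

def apply_qk_norm(q, k, q_norm_w, k_norm_w, head_dim):
    # QK-Norm: per-head RMSNorm on Q and K right after the fused QKV projection.
    # q: [tokens, n_heads_q * head_dim], k: [tokens, n_heads_kv * head_dim]
    q = rms_norm(q.view(q.shape[0], -1, head_dim), q_norm_w).reshape(q.shape)
    k = rms_norm(k.view(k.shape[0], -1, head_dim), k_norm_w).reshape(k.shape)
    return q, k

The hybrid attention choice is made per layer from the layer_types list in the HF config (sliding window vs. full attention in a 3:1 pattern) and is not shown above.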

Supported models

Requires transformers >= 4.54.0.


@Bias92 Bias92 force-pushed the add-exaone4-inference-v2 branch from bd52e9d to 400d05a Compare February 14, 2026 16:35

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 400d05a36a


Collaborator

@tohtana tohtana left a comment


@Bias92 Thank you for your contribution!
I left a comment about QKV partitioning. Can you check it?
It would be great if you could validate that the model implementation produces coherent text.

"""
tokens = hidden_states.shape[0]
local_n_heads = self.n_heads // max(self.tp_size, 1)
local_n_heads_kv = self.n_heads_kv // max(self.tp_size, 1)

Collaborator

As EXAONE4 has uneven Q/KV heads (GQA), I think this can produce incorrect results. Shouldn't we use these?

  • self.n_heads_q_local instead of self.n_heads // self.tp_size
  • self.n_heads_kv_local instead of self.n_heads_kv // self.tp_size
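
For context, a small sketch (hypothetical per-rank head counts, plain PyTorch, not DeepSpeed code) of why the fused QKV activation must be split with separate local Q and KV head counts under GQA:

import torch

tokens, head_dim = 4, 64
n_heads_q_local, n_heads_kv_local = 10, 2  # per-rank counts after TP sharding (illustrative)

# Fused QKV output for one rank: the Q block followed by the K and V blocks.
qkv = torch.randn(tokens, (n_heads_q_local + 2 * n_heads_kv_local) * head_dim)

# With uneven Q/KV heads, a single shared head count would misalign the K and V
# slices; each width has to come from its own local count.
q, k, v = torch.split(
    qkv,
    [n_heads_q_local * head_dim,
     n_heads_kv_local * head_dim,
     n_heads_kv_local * head_dim],
    dim=-1,
)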

@Bias92
Contributor Author

Bias92 commented Feb 16, 2026

@tohtana Thanks for the review! I've updated the code to use self.n_heads_q_local and self.n_heads_kv_local as suggested.

I'll validate the model with coherent text generation and share the results.

@PKUWZP PKUWZP self-requested a review February 16, 2026 18:08
Bias92 and others added 4 commits February 17, 2026 21:11
Use n_heads_q_local and n_heads_kv_local for GQA compatibility

Signed-off-by: Bias92 <pewpewplay315@gmail.com>
@Bias92 Bias92 force-pushed the add-exaone4-inference-v2 branch from fced31a to 8ece3c1 Compare February 17, 2026 12:12
@Bias92
Contributor Author

Bias92 commented Feb 17, 2026

Validation: coherent text generation with EXAONE-4.0-1.2B

Environment: RTX 4060 Ti, PyTorch 2.10, transformers 5.1.0, CUDA 12.8

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load EXAONE-4.0-1.2B in fp16 on the available GPU.
model = AutoModelForCausalLM.from_pretrained("LGAI-EXAONE/EXAONE-4.0-1.2B", dtype=torch.float16, device_map="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("LGAI-EXAONE/EXAONE-4.0-1.2B", trust_remote_code=True)

# Greedy decoding with the EXAONE chat-template markers.
prompt = '[|user|]\nExplain what DeepSpeed is in 2 sentences.\n[|assistant|]\n'
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=100, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

The model produces coherent, contextually relevant text. Full end-to-end validation with DeepSpeed Inference V2 would require a multi-GPU environment; happy to assist with further testing if needed.
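
For reference, a minimal sketch of what the multi-GPU Inference V2 check could look like through DeepSpeed-MII (untested here; the model id, tensor_parallel degree, and prompt are assumptions):

import mii

# Hypothetical end-to-end run; adjust the model id and tensor_parallel degree
# to the available hardware.
pipe = mii.pipeline("LGAI-EXAONE/EXAONE-4.0-32B", tensor_parallel=2)
response = pipe(["Explain what DeepSpeed is in 2 sentences."], max_new_tokens=100)
print(response[0].generated_text)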

Collaborator

@tohtana tohtana left a comment


@Bias92 Thank you for the update! This looks good to me.

@tohtana tohtana enabled auto-merge (squash) February 17, 2026 17:19
@tohtana tohtana merged commit f3a9819 into deepspeedai:master Feb 17, 2026
1 check passed
