
QAT SFT Training Fails with Qwen3 MOE #2305

@chaserogo

Description


Bug report

Quantization-aware training (QAT) does not work when running with an MoE model.

I have run QAT SFT training successfully for dense Qwen3 models, and I can run plain SFT on Qwen3 MoE models, but when I try to do both at once I get the following AssertionError:

  File "/home/chase-blagden/2025-09-07-19-03-42/maxtext_venv/lib/python3.12/site-packages/qwix/_src/core/pallas.py", line 105, in _update_block_spec
    assert v % bv == 0 and s % (v // bv) == 0, f"{v=} {bv=} {s=}"
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: v=1536 bv=1024 s=1536

which is raised from the qwix library.
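For what it's worth, the failing check is a plain divisibility constraint. Reading the message, v looks like the array dimension being blocked, bv the Pallas block size along it, and s the matching scale dimension (that's my reading of the qwix internals, not confirmed). A standalone sketch with the values from the error shows the first clause already fails:

# Standalone sketch of the check from qwix/_src/core/pallas.py, using the
# values reported in the error message (this snippet raises on purpose).
v, bv, s = 1536, 1024, 1536        # copied from the AssertionError
print(v % bv)                      # 512, so `v % bv == 0` is already false
assert v % bv == 0 and s % (v // bv) == 0, f"{v=} {bv=} {s=}"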

This is the training command:

python3 -m MaxText.train MaxText/configs/sft/qwen235b_sft.yml run_name=qwen235b-reasoning-sft-qat

And the relevant parts of the config:

use_sft: True
sft_train_on_completion_only: True
packing: True

ici_data_parallelism: 1
ici_fsdp_parallelism: -1

model_name: "qwen3-235b-a22b"
use_qwix_quantization: True
quantization: 'fp8'
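If I read the numbers right, the 1536 likely comes from this model's MoE expert intermediate width (moe_mlp_dim is 1536 for qwen3-235b-a22b), while 1024 appears to be the tile size the megablox gmm kernel uses along that axis, so the first clause of the qwix check can never hold for this model. A hypothetical helper (not part of MaxText or qwix) listing tile sizes that would divide 1536 evenly:

# Hypothetical helper: block sizes that satisfy `v % bv == 0` for a 1536-wide dim.
def compatible_block_sizes(v, candidates=(128, 256, 384, 512, 768, 1024, 1536)):
  return [bv for bv in candidates if v % bv == 0]

print(compatible_block_sizes(1536))  # [128, 256, 384, 512, 768, 1536] -- 1024 is not among them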

Logs/Output

E0907 23:12:48.842686 139811109874816 packing.py:200] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead.
I0907 23:12:48.843091 139811109874816 data_loader.py:214] Adding CopyNumPyArrayToSharedMemory MapTransform.
E0907 23:12:49.419215 139811109874816 packing.py:200] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead.
I0907 23:12:49.419427 139811109874816 data_loader.py:214] Adding CopyNumPyArrayToSharedMemory MapTransform.
I0907 23:12:49.489779 139811109874816 qconfig.py:195] [QWIX] module='decoder/layers/self_attention/query' op=dot_general0 rule=0
I0907 23:12:49.513342 139811109874816 qconfig.py:195] [QWIX] module='decoder/layers/self_attention/key' op=dot_general1 rule=0
I0907 23:12:49.525885 139811109874816 qconfig.py:195] [QWIX] module='decoder/layers/self_attention/value' op=dot_general2 rule=0
[t=425.19, MAIN COMMAND] Completed 0/32, slice 0 worker 0 still working...
[t=426.19, MAIN COMMAND] Completed 0/32, slice 0 worker 0 still working...
I0907 23:12:51.115714 139811109874816 qconfig.py:195] [QWIX] module='decoder/layers/self_attention/out' op=dot_general3 rule=0
I0907 23:12:51.164183 139811109874816 qconfig.py:195] [QWIX] module='decoder/layers/moe_block/gate' op=dot_general0 rule=0
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/train.py", line 512, in <module>
    app.run(main)
  File "/home/chase-blagden/2025-09-07-19-03-42/maxtext_venv/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/home/chase-blagden/2025-09-07-19-03-42/maxtext_venv/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/train.py", line 508, in main
    run(config, recorder, diagnostic_config)
  File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/train.py", line 503, in run
    train_loop(config, recorder)
  File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/train.py", line 366, in train_loop
    ) = train_utils.setup_train_loop(config, recorder)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/train_utils.py", line 223, in setup_train_loop
    maxtext_utils.setup_training_state(
  File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/maxtext_utils.py", line 952, in setup_training_state
    return setup_initial_state(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/maxtext_utils.py", line 991, in setup_initial_state
    unboxed_abstract_state, state_mesh_annotations, state_mesh_shardings = get_abstract_state(
                                                                           ^^^^^^^^^^^^^^^^^^^
  File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/maxtext_utils.py", line 1048, in get_abstract_state
    abstract_state = jax.eval_shape(init_state_partial)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/maxtext_utils.py", line 901, in init_initial_state
    model_vars = model.init(
                 ^^^^^^^^^^^
  File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/layers/models.py", line 61, in init
    return nn.Module.init(module, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chase-blagden/2025-09-07-19-03-42/maxtext_venv/lib/python3.12/site-packages/qwix/_src/interception.py", line 155, in wrapper
    output = func(*args, **kwargs)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/layers/models.py", line 136, in __call__
    logits, hidden_state = self.decoder(
                           ^^^^^^^^^^^^^
  File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/layers/decoders.py", line 759, in __call__
    y, _ = self.scan_decoder_layers(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chase-blagden/2025-09-07-19-03-42/maxtext_venv/lib/python3.12/site-packages/flax/core/axes_scan.py", line 185, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chase-blagden/2025-09-07-19-03-42/maxtext_venv/lib/python3.12/site-packages/flax/core/axes_scan.py", line 156, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
                           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/layers/qwen3.py", line 216, in __call__
    mlp_output, load_balance_loss = moe.get_routed_moe(
                                    ^^^^^^^^^^^^^^^^^^^
  File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/layers/nnx_wrappers.py", line 435, in __call__
    out = method_fn(module, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/layers/moe.py", line 1615, in __call__
    return self.sparse_matmul(
           ^^^^^^^^^^^^^^^^^^^
  File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/layers/moe.py", line 1039, in sparse_matmul
    return wrapper(inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/layers/moe.py", line 947, in wrapper
    layer_w0 = gmm(x, w0, group_sizes, selected_experts)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/layers/moe.py", line 774, in gmm
    output = mblx.gmm(
             ^^^^^^^^^
  File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/kernels/megablox/gmm.py", line 622, in gmm
    out = call_gmm(
          ^^^^^^^^^
  File "/home/chase-blagden/2025-09-07-19-03-42/maxtext_venv/lib/python3.12/site-packages/qwix/_src/core/pallas.py", line 64, in wrapper
    in_specs = _update_block_specs_for_qarray(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chase-blagden/2025-09-07-19-03-42/maxtext_venv/lib/python3.12/site-packages/qwix/_src/core/pallas.py", line 119, in _update_block_specs_for_qarray
    return jax.tree.map(_update_block_spec, block_specs, args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chase-blagden/2025-09-07-19-03-42/maxtext_venv/lib/python3.12/site-packages/jax/_src/tree.py", line 155, in map
    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chase-blagden/2025-09-07-19-03-42/maxtext_venv/lib/python3.12/site-packages/qwix/_src/core/pallas.py", line 105, in _update_block_spec
    assert v % bv == 0 and s % (v // bv) == 0, f"{v=} {bv=} {s=}"
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: v=1536 bv=1024 s=1536
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Environment Information

No response

Additional Context

No response
