Status: Open
Labels: bug (Something isn't working)
Description
Bug report
Quantization-aware training (QAT) does not work when running with an MoE model.
I have run QAT SFT training successfully for dense Qwen3 models, and I can run SFT on Qwen MoEs. But when I try to do both at once, I get this AssertionError:
File "/home/chase-blagden/2025-09-07-19-03-42/maxtext_venv/lib/python3.12/site-packages/qwix/_src/core/pallas.py", line 105, in _update_block_spec
assert v % bv == 0 and s % (v // bv) == 0, f"{v=} {bv=} {s=}"
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: v=1536 bv=1024 s=1536
which is raised from the qwix library.
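The failing assertion is a divisibility check on the Pallas block spec: `v` is the tensor dimension being tiled, `bv` the block size along that dimension, and `s` the corresponding scale dimension (names taken from the assertion message). A minimal sketch of the check, using the values from the failing run:

```python
# Reproduces the divisibility check from qwix/_src/core/pallas.py
# (_update_block_spec): the dimension must be a multiple of the block
# size, and the scale dimension a multiple of the resulting block count.
def check_block_spec(v: int, bv: int, s: int) -> bool:
    return v % bv == 0 and s % (v // bv) == 0

# Values from the traceback: the MoE dimension of 1536 is not a
# multiple of the 1024 block size, so the check fails.
print(check_block_spec(v=1536, bv=1024, s=1536))  # False
# A block size that divides 1536 would pass (1536 // 512 = 3 blocks):
print(check_block_spec(v=1536, bv=512, s=1536))   # True
```

So the error is not specific to SFT: the quantized gmm kernel's default tiling simply does not divide this MoE's expert dimension.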
This is the training command:
python3 -m MaxText.train MaxText/configs/sft/qwen235b_sft.yml run_name=qwen235b-reasoning-sft-qat
And the relevant parts of the config:
use_sft: True
sft_train_on_completion_only: True
packing: True
ici_data_parallelism: 1
ici_fsdp_parallelism: -1
model_name: "qwen3-235b-a22b"
use_qwix_quantization: True
quantization: 'fp8'
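For context: qwen3-235b-a22b's MoE intermediate size is 1536, which is not a multiple of the 1024 block size appearing in the assertion. A possible workaround direction (a hypothetical helper, not an existing MaxText or qwix option) is to pick the largest power-of-two tile that divides the dimension:

```python
def largest_compatible_tile(dim: int, max_tile: int = 1024) -> int:
    """Largest power-of-two tile <= max_tile that evenly divides dim."""
    tile = max_tile
    while tile > 1 and dim % tile != 0:
        tile //= 2
    return tile

# 1536 % 1024 != 0, but 1536 % 512 == 0, so 512 would be usable.
print(largest_compatible_tile(1536))  # 512
```

Whether the megablox gmm tiling can actually be overridden this way in MaxText is something the maintainers would need to confirm.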
Logs/Output
E0907 23:12:48.842686 139811109874816 packing.py:200] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead.
I0907 23:12:48.843091 139811109874816 data_loader.py:214] Adding CopyNumPyArrayToSharedMemory MapTransform.
E0907 23:12:49.419215 139811109874816 packing.py:200] PackAndBatchOperation is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead.
I0907 23:12:49.419427 139811109874816 data_loader.py:214] Adding CopyNumPyArrayToSharedMemory MapTransform.
I0907 23:12:49.489779 139811109874816 qconfig.py:195] [QWIX] module='decoder/layers/self_attention/query' op=dot_general0 rule=0
I0907 23:12:49.513342 139811109874816 qconfig.py:195] [QWIX] module='decoder/layers/self_attention/key' op=dot_general1 rule=0
I0907 23:12:49.525885 139811109874816 qconfig.py:195] [QWIX] module='decoder/layers/self_attention/value' op=dot_general2 rule=0
[t=425.19, MAIN COMMAND] Completed 0/32, slice 0 worker 0 still working...
[t=426.19, MAIN COMMAND] Completed 0/32, slice 0 worker 0 still working...
I0907 23:12:51.115714 139811109874816 qconfig.py:195] [QWIX] module='decoder/layers/self_attention/out' op=dot_general3 rule=0
I0907 23:12:51.164183 139811109874816 qconfig.py:195] [QWIX] module='decoder/layers/moe_block/gate' op=dot_general0 rule=0
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/train.py", line 512, in <module>
app.run(main)
File "/home/chase-blagden/2025-09-07-19-03-42/maxtext_venv/lib/python3.12/site-packages/absl/app.py", line 316, in run
_run_main(main, args)
File "/home/chase-blagden/2025-09-07-19-03-42/maxtext_venv/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
sys.exit(main(argv))
^^^^^^^^^^
File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/train.py", line 508, in main
run(config, recorder, diagnostic_config)
File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/train.py", line 503, in run
train_loop(config, recorder)
File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/train.py", line 366, in train_loop
) = train_utils.setup_train_loop(config, recorder)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/train_utils.py", line 223, in setup_train_loop
maxtext_utils.setup_training_state(
File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/maxtext_utils.py", line 952, in setup_training_state
return setup_initial_state(
^^^^^^^^^^^^^^^^^^^^
File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/maxtext_utils.py", line 991, in setup_initial_state
unboxed_abstract_state, state_mesh_annotations, state_mesh_shardings = get_abstract_state(
^^^^^^^^^^^^^^^^^^^
File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/maxtext_utils.py", line 1048, in get_abstract_state
abstract_state = jax.eval_shape(init_state_partial)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/maxtext_utils.py", line 901, in init_initial_state
model_vars = model.init(
^^^^^^^^^^^
File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/layers/models.py", line 61, in init
return nn.Module.init(module, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chase-blagden/2025-09-07-19-03-42/maxtext_venv/lib/python3.12/site-packages/qwix/_src/interception.py", line 155, in wrapper
output = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/layers/models.py", line 136, in __call__
logits, hidden_state = self.decoder(
^^^^^^^^^^^^^
File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/layers/decoders.py", line 759, in __call__
y, _ = self.scan_decoder_layers(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chase-blagden/2025-09-07-19-03-42/maxtext_venv/lib/python3.12/site-packages/flax/core/axes_scan.py", line 185, in scan_fn
_, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chase-blagden/2025-09-07-19-03-42/maxtext_venv/lib/python3.12/site-packages/flax/core/axes_scan.py", line 156, in body_fn
broadcast_out, c, ys = fn(broadcast_in, c, *xs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/layers/qwen3.py", line 216, in __call__
mlp_output, load_balance_loss = moe.get_routed_moe(
^^^^^^^^^^^^^^^^^^^
File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/layers/nnx_wrappers.py", line 435, in __call__
out = method_fn(module, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/layers/moe.py", line 1615, in __call__
return self.sparse_matmul(
^^^^^^^^^^^^^^^^^^^
File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/layers/moe.py", line 1039, in sparse_matmul
return wrapper(inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/layers/moe.py", line 947, in wrapper
layer_w0 = gmm(x, w0, group_sizes, selected_experts)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/layers/moe.py", line 774, in gmm
output = mblx.gmm(
^^^^^^^^^
File "/home/chase-blagden/2025-09-07-19-03-42/src/MaxText/kernels/megablox/gmm.py", line 622, in gmm
out = call_gmm(
^^^^^^^^^
File "/home/chase-blagden/2025-09-07-19-03-42/maxtext_venv/lib/python3.12/site-packages/qwix/_src/core/pallas.py", line 64, in wrapper
in_specs = _update_block_specs_for_qarray(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chase-blagden/2025-09-07-19-03-42/maxtext_venv/lib/python3.12/site-packages/qwix/_src/core/pallas.py", line 119, in _update_block_specs_for_qarray
return jax.tree.map(_update_block_spec, block_specs, args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chase-blagden/2025-09-07-19-03-42/maxtext_venv/lib/python3.12/site-packages/jax/_src/tree.py", line 155, in map
return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chase-blagden/2025-09-07-19-03-42/maxtext_venv/lib/python3.12/site-packages/qwix/_src/core/pallas.py", line 105, in _update_block_spec
assert v % bv == 0 and s % (v // bv) == 0, f"{v=} {bv=} {s=}"
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: v=1536 bv=1024 s=1536
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
Environment Information
No response
Additional Context
No response