
Conversation

@JacobHelwig (Collaborator) commented Jan 12, 2026

What does this PR do?

Adds on-policy distillation to the FSDP engine, with top-k distillation losses and KL-estimator distillation losses.

  • Top-k distillation losses: a family spanning forward KL → Jensen–Shannon divergence → reverse KL, computed from top-k logits (a sketch of both loss families follows this list). Support for full logits will be added in a later PR.
  • KL-estimator distillation losses: estimate the KL using only the log prob of the sampled token, via the same estimators used for the reference-model KL (e.g., k1, k3).
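
For illustration, a minimal sketch of the two loss families, assuming per-token log probs already gathered (and renormalized) on a shared top-k support for the first family and only the sampled token's log probs for the second; the function names and the renormalization assumption are illustrative, not the PR's actual API:

```python
import torch


def topk_jsd_loss(student_logprobs_topk: torch.Tensor,
                  teacher_logprobs_topk: torch.Tensor,
                  beta: float = 0.5) -> torch.Tensor:
    """Generalized Jensen-Shannon divergence over a shared top-k support.

    Both inputs are (tokens, k) log probs gathered at the same top-k token ids
    and renormalized over that support. Up to scaling, beta -> 0 recovers
    forward KL(teacher || student) and beta -> 1 recovers reverse KL.
    """
    p = teacher_logprobs_topk.exp()                             # teacher probs
    q = student_logprobs_topk.exp()                             # student probs
    log_m = (beta * p + (1.0 - beta) * q).clamp_min(1e-12).log()
    kl_p_m = (p * (teacher_logprobs_topk - log_m)).sum(-1)      # KL(p || m)
    kl_q_m = (q * (student_logprobs_topk - log_m)).sum(-1)      # KL(q || m)
    return beta * kl_p_m + (1.0 - beta) * kl_q_m                # per-token loss


def k3_distillation_loss(student_logprob: torch.Tensor,
                         teacher_logprob: torch.Tensor) -> torch.Tensor:
    """k3 estimator of KL(student || teacher) from the sampled token only,
    mirroring the k3 estimator used for the reference-model KL penalty."""
    log_ratio = teacher_logprob - student_logprob               # log(p_teacher / p_student)
    return log_ratio.exp() - 1.0 - log_ratio                    # non-negative, low variance
```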

Test

Tested with examples/on_policy_distillation_trainer/run_qwen_gsmk8k.sh

Results are shown for the k3 loss estimator (computed only for the sampled token) and the top-k JSD loss. The top-k loss converged faster than k3 and generalized better, but training was more unstable.

GSM8K eval acc

[image: GSM8K eval accuracy plot]

GSM8K train acc

[image: GSM8K train accuracy plot]

Distillation loss

[image: distillation loss plot]

Design & Code Changes

  • The distillation config extends the reference-model config.
  • The distillation (teacher) model is initialized in place of the reference model.
  • For top-k losses, top-k log probs are gathered at different stages from the student (old-log-prob and actor-update passes) and the teacher (ref-log-prob pass), in a way that keeps the same number of model calls as GRPO with a reference model (see the sketch below).
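
As a rough sketch of that gathering (the function name, the choice of taking top-k from the student, and the cfg.topk parameter are illustrative assumptions, not the PR's actual interface):

```python
from typing import Optional

import torch
import torch.nn.functional as F


def gather_topk_logprobs(logits: torch.Tensor, k: int,
                         indices: Optional[torch.Tensor] = None):
    """Return (log_probs, indices) for k vocabulary entries per position.

    With indices=None the model reports its own top-k (e.g., the student during
    the old-log-prob pass); with indices given, log probs are gathered at token
    ids chosen elsewhere (e.g., the teacher gathering at the student's top-k),
    so student and teacher describe the same support.
    """
    logprobs = F.log_softmax(logits.float(), dim=-1)            # (tokens, vocab)
    if indices is None:
        topk_logprobs, indices = logprobs.topk(k, dim=-1)       # (tokens, k)
    else:
        topk_logprobs = torch.gather(logprobs, -1, indices)
    return topk_logprobs, indices


# Usage sketch -- one forward pass per role, matching GRPO with a ref model:
# student_old_lp, topk_ids = gather_topk_logprobs(student_logits, k=cfg.topk)
# teacher_lp, _ = gather_topk_logprobs(teacher_logits, k=cfg.topk, indices=topk_ids)
```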

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request appears to be a work-in-progress for adding on-policy distillation support. The only change present is a minor whitespace modification in the README.md file. As this is a stylistic change with no functional impact, and I am configured to only report issues of high or critical severity, I have no specific comments on the current state of the pull request. I look forward to reviewing more substantial changes as they are added.

@JacobHelwig mentioned this pull request Jan 12, 2026
@JacobHelwig force-pushed the jhelwig/onPolicyDistillation branch 2 times, most recently from 4245789 to 840aca3 (January 13, 2026 01:24)
@yubin1991 commented

Why is the reward so small (almost 0) before step 40?

@JacobHelwig (Collaborator, Author) commented

> Why is the reward so small (almost 0) before step 40?

Training only explicitly optimizes the distillation loss, not rewards:

Note that we can do OPD + RL rewards, although to isolate the effects of OPD, RL rewards aren't used in these experiments.

Any increase in the logged reward (i.e., GSM8K accuracy) is an indirect result of minimizing the distillation loss.

In this case, the base model has Pass@1 ≈ 0 because the default GSM8K answer format (#### 42) is OOD for the model. The base model answers the questions correctly but uses incorrect formatting, so none of the answers can be parsed. The base model can be evaluated with a reward function that is more lenient on formatting by adding the following to the script:

...
    reward_model.reward_manager=remote \
    custom_reward_function.path=tests/experimental/reward_loop/reward_fn.py \
    custom_reward_function.name=compute_score_math_verify \
    trainer.val_only=True

The results are:

(TaskRunner pid=904198) ("Initial validation metrics: {'val-aux/openai/gsm8k/reward/mean@1': "
(TaskRunner pid=904198)  "np.float64(0.31766489764973466), 'val-core/openai/gsm8k/acc/mean@1': "
(TaskRunner pid=904198)  "np.float64(0.31766489764973466), 'val-aux/num_turns/min': np.int32(2), "
(TaskRunner pid=904198)  "'val-aux/num_turns/max': np.int32(2), 'val-aux/num_turns/mean': "
(TaskRunner pid=904198)  'np.float64(2.0)}')

The formatting is only a few tokens, so it does not contribute much to the distillation loss. The distillation loss initially focuses on minimizing other discrepancies between the teacher and student distributions before targeting formatting, which is why early steps of training show 0% accuracy under the stricter parser.

@JacobHelwig force-pushed the jhelwig/onPolicyDistillation branch 2 times, most recently from 216f0a7 to eaca4e1 (January 16, 2026 03:39)
@JacobHelwig force-pushed the jhelwig/onPolicyDistillation branch 3 times, most recently from 9f3cb9d to d0d0d55 (January 17, 2026 18:40)
@JacobHelwig (Collaborator, Author) commented

/gemini review

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces on-policy distillation (OPD) as a new feature, which is a significant addition to the library. The implementation is comprehensive, adding new configurations, loss functions, and utility functions, and integrating them into the existing PPO trainer pipeline. My review has identified several critical and high-severity issues related to correctness, robustness, and efficiency that should be addressed. These include a bug that could cause a crash when not using use_remove_padding, an invalid default configuration value, and incorrect handling of zero-length prompts. Additionally, there are some type hint inaccuracies and unnecessary tensor cloning that impact code quality and performance.

def compute_topk_distillation_inputs(logits: torch.Tensor, batch: TensorDict, cu_seqlens: torch.Tensor, config: DistillationConfig):
    """TODO: Docstring"""
    # Gather inputs for top-k distillation losses.
    logits = logits.squeeze(0)

critical

The use of logits.squeeze(0) assumes that the logits tensor always has a batch dimension of size 1. This holds true when use_remove_padding is enabled, as the input is reshaped to (1, total_tokens, ...). However, when use_remove_padding is disabled, logits will have a shape of (batch_size, seq_len, vocab_size). If batch_size is greater than 1, squeeze(0) will raise an error. This will prevent distillation from working correctly in this configuration.

A more robust approach would be to reshape the tensor to flatten the batch and sequence dimensions if they exist, rather than squeezing a specific dimension.
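
A minimal sketch of that reshape (this only addresses the shape assumption; if padding is kept, any downstream cu_seqlens-based indexing would still need to account for pad positions):

```python
# Collapse batch and sequence dimensions into one token dimension, regardless
# of whether padding removal already reduced the batch dimension to 1.
logits = logits.reshape(-1, logits.size(-1))   # (batch_size * seq_len, vocab_size)
```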

sequence_lens = prompt_lens + response_lens
sequence_offsets = sequence_lens.cumsum(dim=0)
assert sequence_offsets[-1].item() == values.shape[0]
assert not prompt_lens.eq(0).any(), f"seq_offset - resp_len - 1 assumes prompt_len > 0. Got {prompt_lens}"

critical

The assertion assert not prompt_lens.eq(0).any() will cause a crash if any of the prompts in a batch have a length of 0. While this might be an uncommon case, the code should handle it gracefully instead of crashing. The slicing logic values[seq_offset - resp_len - 1 : seq_offset - 1] relies on prompt_len > 0.

@JacobHelwig changed the title from "[WIP] On-Policy Distillation" to "[fsdp,trainer,algo] feat: On-Policy Distillation" on Jan 18, 2026
@JacobHelwig (Collaborator, Author) commented

/gemini review

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces on-policy distillation capabilities, including top-k and KL estimator-based distillation losses. The changes are well-structured, with new configurations, loss functions, and utility modules. The integration into the existing PPO trainer and FSDP engine seems correct. I've identified one area for improvement regarding the use of __post_init__ in a class with multiple inheritance, which could be made more robust.
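
A sketch of the cooperative __post_init__ pattern that comment alludes to; the class and field names here are placeholders, not the repo's actual config classes:

```python
from dataclasses import dataclass


@dataclass
class BaseModelConfig:
    path: str = ""

    def __post_init__(self):
        # Cooperative chaining: call the next __post_init__ in the MRO only if
        # one exists, so every class in the hierarchy runs its hook exactly once.
        next_hook = getattr(super(), "__post_init__", None)
        if callable(next_hook):
            next_hook()


@dataclass
class TopKMixin:
    topk: int = 64

    def __post_init__(self):
        if self.topk <= 0:
            raise ValueError(f"topk must be positive, got {self.topk}")
        next_hook = getattr(super(), "__post_init__", None)
        if callable(next_hook):
            next_hook()


@dataclass
class DistillationConfig(TopKMixin, BaseModelConfig):
    loss_mode: str = "jsd_topk"

    def __post_init__(self):
        super().__post_init__()  # walks TopKMixin -> BaseModelConfig via the MRO
```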

@JacobHelwig (Collaborator, Author) commented

/gemini review

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a significant new feature: on-policy distillation. The implementation is well-designed, with a clear separation of concerns between configuration, data processing stages, and loss computation. The use of a registry for distillation losses is a good practice for extensibility. I appreciate the addition of unit tests for the new utility functions and the inclusion of an example script. The new validation checks in the configuration are also a great addition to prevent misuse. I have one high-severity finding in the example script where a variable is overwritten, which could lead to incorrect experiment execution.

Comment on lines 25 to 26
DISTILLATION_LOSS_MODE="jsd_topk"
DISTILLATION_LOSS_MODE="k3"

high

The DISTILLATION_LOSS_MODE variable is defined twice. The second definition on line 26 will always overwrite the first one on line 25, so the script will run with the "k3" loss regardless of the intention to test jsd_topk. To fix this, comment out or remove one of the definitions.

Suggested change
DISTILLATION_LOSS_MODE="jsd_topk"
DISTILLATION_LOSS_MODE="k3"
#DISTILLATION_LOSS_MODE="jsd_topk"
DISTILLATION_LOSS_MODE="k3"

@JacobHelwig (Collaborator, Author) commented

/gemini review

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces on-policy distillation, a significant new feature. The implementation includes new configuration options, distillation loss functions (top-k and KL estimators), and integrates them into the FSDP engine and PPO trainer. The changes are extensive and well-structured, with new modules for distillation logic and tests for utility functions. My review found a couple of issues in the configuration and example script that need to be addressed. Overall, this is a solid contribution that adds valuable new capabilities.

Comment on lines 25 to 27
DISTILLATION_LOSS_MODE="jsd_topk"
DISTILLATION_LOSS_MODE="k3"
DISTILLATION_LOSS_MODE="reverse_kl_topk+"

high

The DISTILLATION_LOSS_MODE variable is being redefined on consecutive lines. This means only the last value, "reverse_kl_topk+", will be effective when the script is run. This is likely not the intended behavior and can lead to confusion or incorrect experiments. To make it easier to switch between different loss modes, you should comment out the inactive options.

Suggested change
DISTILLATION_LOSS_MODE="jsd_topk"
DISTILLATION_LOSS_MODE="k3"
DISTILLATION_LOSS_MODE="reverse_kl_topk+"
# DISTILLATION_LOSS_MODE="jsd_topk"
# DISTILLATION_LOSS_MODE="k3"
DISTILLATION_LOSS_MODE="reverse_kl_topk+"

@JacobHelwig (Collaborator, Author) commented

/gemini review

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces on-policy distillation to the FSDP engine, adding support for top-k and KL estimator-based distillation losses. The changes are extensive, touching configuration, training loop logic, worker implementations, and loss calculations. A critical bug was found in the distillation loss calculation where a variable was used without being initialized in all code paths, which would lead to a runtime error. The rest of the implementation, including the complex data flow for distillation and the new validation checks, appears solid.

sequence_lens = prompt_lens + response_lens
sequence_offsets = sequence_lens.cumsum(dim=0)
assert sequence_offsets[-1].item() == values.shape[0]
assert not prompt_lens.eq(0).any(), f"seq_offset - resp_len - 1 assumes prompt_len > 0. Got {prompt_lens}"

high

This assertion correctly identifies that the slicing logic seq_offset - resp_len - 1 assumes prompt_len > 0. However, if a dataset contains empty prompts, this will crash the training. It would be more robust to handle the case of prompt_len == 0 gracefully within the slicing logic instead of asserting. If empty prompts are not expected and should be filtered, this check is fine, but handling it in the code would prevent unexpected failures with new datasets.
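
One way the slicing could tolerate prompt_len == 0 instead of asserting. This is only a sketch under the assumption that values is the packed per-token tensor and that the names mirror the snippet above; whether empty prompts should be supported at all is a design decision for the PR:

```python
import torch


def slice_response_predictions(values: torch.Tensor, seq_offset: int,
                               prompt_len: int, resp_len: int) -> torch.Tensor:
    """Slice the packed positions whose outputs predict the response tokens,
    tolerating prompt_len == 0 instead of asserting."""
    seq_start = seq_offset - (prompt_len + resp_len)   # first packed position of this sequence
    pred_start = seq_offset - resp_len - 1             # position predicting the first response token
    lo = max(pred_start, seq_start)                    # they differ only when prompt_len == 0
    chunk = values[lo : seq_offset - 1]
    if pred_start < seq_start:
        # Empty prompt: the first response token has no preceding position, so
        # left-pad one row; that row should be masked out of the loss downstream.
        chunk = torch.cat([torch.zeros_like(chunk[:1]), chunk], dim=0)
    return chunk
```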

@JacobHelwig (Collaborator, Author) commented

/gemini review

@JacobHelwig marked this pull request as ready for review on January 19, 2026 03:36
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces on-policy distillation capabilities, a significant and well-implemented feature. The changes are extensive, touching upon configuration, loss functions, worker implementations, and core training logic. The addition of top-k and KL estimator distillation losses is comprehensive. The code is well-structured, with new functionalities organized into a distillation module and supported by new tests and example scripts. My review identified one high-severity issue concerning a latent bug in a utility function that affects handling of empty prompts, which is currently mitigated by an assertion. Addressing this would improve the robustness of the implementation. Overall, this is a solid contribution.

sequence_lens = prompt_lens + response_lens
sequence_offsets = sequence_lens.cumsum(dim=0)
assert sequence_offsets[-1].item() == values.shape[0]
assert not prompt_lens.eq(0).any(), f"seq_offset - resp_len - 1 assumes prompt_len > 0. Got {prompt_lens}"

high

The assertion assert not prompt_lens.eq(0).any() correctly prevents a bug in the slicing logic for prompts of length 0. However, this also disallows what could be a valid use case (empty prompts).

The slicing logic on line 146, values[seq_offset - resp_len - 1 : seq_offset - 1], is incorrect when prompt_len is 0: the start index becomes -1, which Python interprets as counting from the end of the tensor and silently produces the wrong slice.

While the assertion prevents a crash, it would be more robust to:

  1. Change the assertion to a ValueError with a clear error message for users.
  2. Ideally, fix the underlying slicing logic to correctly handle empty prompts if they are expected to be supported.

A more robust check would be:

if torch.any(prompt_lens == 0):
    raise ValueError(
        "Prompts with length 0 are not supported by the current slicing logic. "
        "Please filter out empty prompts from your dataset or update the slicing logic."
    )

@JacobHelwig force-pushed the jhelwig/onPolicyDistillation branch from f14c77d to 58e1e28 (January 27, 2026 16:54)