[fsdp,trainer,algo] feat: On-Policy Distillation #4897
Conversation
Code Review
This pull request appears to be a work-in-progress for adding on-policy distillation support. The only change present is a minor whitespace modification in the README.md file. As this is a stylistic change with no functional impact, and I am configured to only report issues of high or critical severity, I have no specific comments on the current state of the pull request. I look forward to reviewing more substantial changes as they are added.
Why is the reward so small (almost 0) before step 40?
Training only explicitly optimizes the distillation loss, not rewards:
Any increase in the logged reward (GSM8K accuracy) is an indirect result of minimizing the distillation loss. In this case, the reason the base model has Pass@1 ≈ 0 is that the default GSM8K answer formatting ( ...
reward_model.reward_manager=remote \
custom_reward_function.path=tests/experimental/reward_loop/reward_fn.py \
custom_reward_function.name=compute_score_math_verify \
trainer.val_only=True

The results are:

(TaskRunner pid=904198) ("Initial validation metrics: {'val-aux/openai/gsm8k/reward/mean@1': "
(TaskRunner pid=904198) "np.float64(0.31766489764973466), 'val-core/openai/gsm8k/acc/mean@1': "
(TaskRunner pid=904198) "np.float64(0.31766489764973466), 'val-aux/num_turns/min': np.int32(2), "
(TaskRunner pid=904198) "'val-aux/num_turns/max': np.int32(2), 'val-aux/num_turns/mean': "
(TaskRunner pid=904198) 'np.float64(2.0)}')

The formatting is only a few tokens, so it does not contribute much to the distillation loss. The distillation loss initially focuses on minimizing other discrepancies between the teacher and student distributions before targeting formatting, which is why early steps of training show 0% accuracy under the stricter parser.
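To make the first point concrete, here is a minimal sketch of a sampled-token distillation loss of the kind described (a k3-style KL estimator; the function name, tensor names, and masking convention are illustrative assumptions, not the PR's exact implementation). Note that no reward term appears anywhere in the objective:

import torch

def sampled_token_distillation_loss(student_logprobs: torch.Tensor,
                                    teacher_logprobs: torch.Tensor,
                                    response_mask: torch.Tensor) -> torch.Tensor:
    # Per-token k3 estimator of KL(student || teacher), computed only from the
    # log-probs of the tokens the student actually sampled.
    log_ratio = teacher_logprobs - student_logprobs
    kl = torch.exp(log_ratio) - log_ratio - 1.0
    # Average over response tokens; the reward never enters this objective.
    return (kl * response_mask).sum() / response_mask.sum().clamp(min=1)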
/gemini review
Code Review
This pull request introduces on-policy distillation (OPD) as a new feature, which is a significant addition to the library. The implementation is comprehensive, adding new configurations, loss functions, and utility functions, and integrating them into the existing PPO trainer pipeline. My review has identified several critical and high-severity issues related to correctness, robustness, and efficiency that should be addressed. These include a bug that could cause a crash when not using use_remove_padding, an invalid default configuration value, and incorrect handling of zero-length prompts. Additionally, there are some type hint inaccuracies and unnecessary tensor cloning that impact code quality and performance.
verl/trainer/distillation/utils.py
def compute_topk_distillation_inputs(logits: torch.Tensor, batch: TensorDict, cu_seqlens: torch.Tensor, config: DistillationConfig):
    """TODO: Docstring"""
    # Gather inputs for top-k distillation losses.
    logits = logits.squeeze(0)
The use of logits.squeeze(0) assumes that the logits tensor always has a batch dimension of size 1. This holds true when use_remove_padding is enabled, as the input is reshaped to (1, total_tokens, ...). However, when use_remove_padding is disabled, logits will have a shape of (batch_size, seq_len, vocab_size). If batch_size is greater than 1, squeeze(0) will raise an error. This will prevent distillation from working correctly in this configuration.
A more robust approach would be to reshape the tensor to flatten the batch and sequence dimensions if they exist, rather than squeezing a specific dimension.
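A minimal sketch of that suggestion, under the assumption that downstream code expects a flat (tokens, vocab) layout:

# Works for both layouts instead of assuming a leading dimension of size 1:
#   with use_remove_padding:    (1, total_tokens, vocab) -> (total_tokens, vocab)
#   without use_remove_padding: (batch, seq_len, vocab)  -> (batch * seq_len, vocab)
logits = logits.reshape(-1, logits.shape[-1])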
sequence_lens = prompt_lens + response_lens
sequence_offsets = sequence_lens.cumsum(dim=0)
assert sequence_offsets[-1].item() == values.shape[0]
assert not prompt_lens.eq(0).any(), f"seq_offset - resp_len - 1 assumes prompt_len > 0. Got {prompt_lens}"
The assertion assert not prompt_lens.eq(0).any() will cause a crash if any of the prompts in a batch have a length of 0. While this might be an uncommon case, the code should handle it gracefully instead of crashing. The slicing logic values[seq_offset - resp_len - 1 : seq_offset - 1] relies on prompt_len > 0.
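One possible shape of that graceful handling, sketched under the assumption of a per-sequence loop over the flat values tensor (seq_start, the zero-padding of the first response position, and the variable names are illustrative choices, not the PR's code):

# Slice the values that predict this sequence's response tokens.
seq_start = seq_offset - (prompt_len + resp_len)
start = max(seq_offset - resp_len - 1, seq_start)
response_values = values[start : seq_offset - 1]
if prompt_len == 0:
    # The first response token has no preceding position inside this sequence,
    # so pad with zeros to keep downstream shapes aligned.
    pad = response_values.new_zeros((1, *response_values.shape[1:]))
    response_values = torch.cat([pad, response_values], dim=0)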
/gemini review
Code Review
This pull request introduces on-policy distillation capabilities, including top-k and KL estimator-based distillation losses. The changes are well-structured, with new configurations, loss functions, and utility modules. The integration into the existing PPO trainer and FSDP engine seems correct. I've identified one area for improvement regarding the use of __post_init__ in a class with multiple inheritance, which could be made more robust.
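For reference, a common way to make __post_init__ robust under multiple inheritance is cooperative chaining through the MRO; a generic sketch follows (class names are placeholders, not the PR's actual config classes):

from dataclasses import dataclass

@dataclass
class MixinA:
    def __post_init__(self):
        # Chain to the next __post_init__ in the MRO, if any.
        parent_post_init = getattr(super(), "__post_init__", None)
        if parent_post_init is not None:
            parent_post_init()
        # ... MixinA-specific validation ...

@dataclass
class MixinB:
    def __post_init__(self):
        parent_post_init = getattr(super(), "__post_init__", None)
        if parent_post_init is not None:
            parent_post_init()
        # ... MixinB-specific validation ...

@dataclass
class CombinedConfig(MixinA, MixinB):
    def __post_init__(self):
        # Runs MixinA.__post_init__, which chains to MixinB via the MRO.
        super().__post_init__()
        # ... CombinedConfig-specific validation ...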
/gemini review
Code Review
This pull request introduces a significant new feature: on-policy distillation. The implementation is well-designed, with a clear separation of concerns between configuration, data processing stages, and loss computation. The use of a registry for distillation losses is a good practice for extensibility. I appreciate the addition of unit tests for the new utility functions and the inclusion of an example script. The new validation checks in the configuration are also a great addition to prevent misuse. I have one high-severity finding in the example script where a variable is overwritten, which could lead to incorrect experiment execution.
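As a generic illustration of the registry pattern praised here (a sketch only; the PR's actual registry API and loss signatures may differ):

from typing import Callable

DISTILLATION_LOSS_REGISTRY: dict[str, Callable] = {}

def register_distillation_loss(name: str) -> Callable:
    """Decorator that registers a distillation loss under a string key."""
    def decorator(fn: Callable) -> Callable:
        DISTILLATION_LOSS_REGISTRY[name] = fn
        return fn
    return decorator

@register_distillation_loss("jsd_topk")
def jsd_topk_loss(student_logprobs, teacher_logprobs, **kwargs):
    raise NotImplementedError  # placeholder body for the sketch

def get_distillation_loss(name: str) -> Callable:
    # Resolving by name keeps new loss modes pluggable from config strings.
    try:
        return DISTILLATION_LOSS_REGISTRY[name]
    except KeyError as e:
        raise KeyError(f"Unknown distillation loss mode: {name}") from e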
DISTILLATION_LOSS_MODE="jsd_topk"
DISTILLATION_LOSS_MODE="k3"
The DISTILLATION_LOSS_MODE variable is defined twice. The second definition on line 26 will always overwrite the first one on line 25. This will cause the script to run with the "k3" loss regardless of the intention to test "jsd_topk". To fix this, you should comment out or remove one of the definitions.
Suggested change:
#DISTILLATION_LOSS_MODE="jsd_topk"
DISTILLATION_LOSS_MODE="k3"
/gemini review
Code Review
This pull request introduces on-policy distillation, a significant new feature. The implementation includes new configuration options, distillation loss functions (top-k and KL estimators), and integrates them into the FSDP engine and PPO trainer. The changes are extensive and well-structured, with new modules for distillation logic and tests for utility functions. My review found a couple of issues in the configuration and example script that need to be addressed. Overall, this is a solid contribution that adds valuable new capabilities.
DISTILLATION_LOSS_MODE="jsd_topk"
DISTILLATION_LOSS_MODE="k3"
DISTILLATION_LOSS_MODE="reverse_kl_topk+"
The DISTILLATION_LOSS_MODE variable is being redefined on consecutive lines. This means only the last value, "reverse_kl_topk+", will be effective when the script is run. This is likely not the intended behavior and can lead to confusion or incorrect experiments. To make it easier to switch between different loss modes, you should comment out the inactive options.
Suggested change:
# DISTILLATION_LOSS_MODE="jsd_topk"
# DISTILLATION_LOSS_MODE="k3"
DISTILLATION_LOSS_MODE="reverse_kl_topk+"
/gemini review
Code Review
This pull request introduces on-policy distillation to the FSDP engine, adding support for top-k and KL estimator-based distillation losses. The changes are extensive, touching configuration, training loop logic, worker implementations, and loss calculations. A critical bug was found in the distillation loss calculation where a variable was used without being initialized in all code paths, which would lead to a runtime error. The rest of the implementation, including the complex data flow for distillation and the new validation checks, appears solid.
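For readers unfamiliar with this class of bug, it typically has the following shape (purely illustrative code, not taken from the PR); the fix is to bind the variable on every path before branching:

def collect_loss_metrics(loss_mode: str, loss_value: float) -> dict:
    # Buggy shape (commented out): `metrics` is only bound when the branch runs,
    # so the final assignment raises NameError for other loss modes.
    #
    #     if loss_mode.endswith("_topk"):
    #         metrics = {"loss_mode": loss_mode}
    #     metrics["loss"] = loss_value
    #
    # Robust shape: initialize on every code path first.
    metrics: dict = {}
    if loss_mode.endswith("_topk"):
        metrics["loss_mode"] = loss_mode
    metrics["loss"] = loss_value
    return metrics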
sequence_lens = prompt_lens + response_lens
sequence_offsets = sequence_lens.cumsum(dim=0)
assert sequence_offsets[-1].item() == values.shape[0]
assert not prompt_lens.eq(0).any(), f"seq_offset - resp_len - 1 assumes prompt_len > 0. Got {prompt_lens}"
This assertion correctly identifies that the slicing logic seq_offset - resp_len - 1 assumes prompt_len > 0. However, if a dataset contains empty prompts, this will crash the training. It would be more robust to handle the case of prompt_len == 0 gracefully within the slicing logic instead of asserting. If empty prompts are not expected and should be filtered, this check is fine, but handling it in the code would prevent unexpected failures with new datasets.
/gemini review
Code Review
This pull request introduces on-policy distillation capabilities, a significant and well-implemented feature. The changes are extensive, touching upon configuration, loss functions, worker implementations, and core training logic. The addition of top-k and KL estimator distillation losses is comprehensive. The code is well-structured, with new functionalities organized into a distillation module and supported by new tests and example scripts. My review identified one high-severity issue concerning a latent bug in a utility function that affects handling of empty prompts, which is currently mitigated by an assertion. Addressing this would improve the robustness of the implementation. Overall, this is a solid contribution.
sequence_lens = prompt_lens + response_lens
sequence_offsets = sequence_lens.cumsum(dim=0)
assert sequence_offsets[-1].item() == values.shape[0]
assert not prompt_lens.eq(0).any(), f"seq_offset - resp_len - 1 assumes prompt_len > 0. Got {prompt_lens}"
The assertion assert not prompt_lens.eq(0).any() correctly prevents a bug in the slicing logic for prompts of length 0. However, this also disallows what could be a valid use case (empty prompts).
The slicing logic on line 146, values[seq_offset - resp_len - 1 : seq_offset - 1], is incorrect when prompt_len is 0, as the start index becomes -1, which leads to incorrect behavior.
While the assertion prevents a crash, it would be more robust to:
- Change the assertion to a ValueError with a clear error message for users.
- Ideally, fix the underlying slicing logic to correctly handle empty prompts if they are expected to be supported.
A more robust check would be:

if torch.any(prompt_lens == 0):
    raise ValueError(
        "Prompts with length 0 are not supported by the current slicing logic. "
        "Please filter out empty prompts from your dataset or update the slicing logic."
    )
What does this PR do?
Adds on-policy distillation to FSDP engine with top-k distillation losses and KL estimator distillation losses.
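For orientation, here is a rough sketch of what a top-k distillation loss such as the JSD variant mentioned below can look like (the choice of k, the renormalization over the teacher's top-k support, and all names are assumptions rather than this PR's exact implementation):

import math
import torch
import torch.nn.functional as F

def jsd_topk_loss(student_logits: torch.Tensor,   # (num_tokens, vocab_size)
                  teacher_logits: torch.Tensor,   # (num_tokens, vocab_size)
                  k: int = 64) -> torch.Tensor:
    # Keep only the teacher's top-k tokens and gather the student's logits there.
    teacher_topk, topk_idx = teacher_logits.topk(k, dim=-1)
    student_topk = torch.gather(student_logits, dim=-1, index=topk_idx)
    # Renormalize both distributions over the shared top-k support.
    log_p = F.log_softmax(teacher_topk, dim=-1)
    log_q = F.log_softmax(student_topk, dim=-1)
    # Jensen-Shannon divergence: 0.5 * KL(p || m) + 0.5 * KL(q || m), m = (p + q) / 2.
    log_m = torch.logaddexp(log_p, log_q) - math.log(2.0)
    jsd = 0.5 * (log_p.exp() * (log_p - log_m)).sum(-1) \
        + 0.5 * (log_q.exp() * (log_q - log_m)).sum(-1)
    return jsd.mean()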
Test
Tested with examples/on_policy_distillation_trainer/run_qwen_gsmk8k.sh.
Results are for the k3 loss estimator (only computed for the sampled token) and the top-k JSD loss. The top-k loss converged faster than k3 and generalized better, but training was more unstable.
[Plots: GSM8K eval acc, GSM8K train acc, Distillation loss]
Design & Code Changes