Skip to content

[DRAFT] SFT distillation with teacher endpoint#1905

Open
willccbb wants to merge 1 commit intomainfrom
will/sft-distillation
Open

[DRAFT] SFT distillation with teacher endpoint#1905
willccbb wants to merge 1 commit intomainfrom
will/sft-distillation

Conversation

@willccbb
Copy link
Member

@willccbb willccbb commented Feb 26, 2026

Note

Medium Risk
Adds a new external-rollout execution path that bypasses local inference/weight broadcast and reconstructs tokens from messages, which can affect rollout-token alignment and training dynamics. The changes touch orchestrator scheduling, sampling args, and loss computation, so misconfiguration or subtle tokenization differences could break training correctness.

Overview
Enables text-only teacher rollout distillation by letting the orchestrator generate rollouts from an external OpenAI-compatible endpoint (orchestrator.rollout_model) and training the student with an SFT-style masked NLL objective (trainer.loss.type = "sft").

The orchestrator now supports an external rollout mode that disables policy weight updates/broadcasting, enforces inference omission and use_token_client = false, and adjusts scheduler checkpoint/policy-update behavior accordingly. When rollout token IDs/logprobs aren’t returned, interleave_rollout can reconstruct per-step prompt_ids/completion_ids and masks from trajectory messages using the student tokenizer.

Adds config validation + tests for the new mode, updates sampling arg construction to only request token/logprob fields when using the token client, introduces the SFTLossConfig/sft_loss_fn path in the RL trainer, and documents the new workflow in docs/on_policy_distillation.md and the TOML config skill guide.

Written by Cursor Bugbot for commit ca97c02. This will update automatically on new commits. Configure here.

Copy link

@cursor cursor bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes and found 1 potential issue.

Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.

"orchestrator.use_token_client must be false when orchestrator.rollout_model is configured."
)

return self
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing validation: rollout_model requires SFT loss

Medium Severity

When orchestrator.rollout_model is configured, validate_external_rollout_mode enforces inference=None and use_token_client=False, but does not require trainer.loss.type = "sft". With the default loss, reconstructed rollouts have completion_logprobs = [0.0], so inference_logprobs are all zero. The default loss treats these as the rollout policy logprobs and computes importance ratios from them, producing incorrect gradients and meaningless training.

Fix in Cursor Fix in Web

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant