@@ -1,4 +1,4 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,118 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
14- """
15- Contains utilities/classes for on-policy distillation
16- """
1714
18- from typing import Union , Optional , Callable , Any
19- from enum import Enum
20- from omegaconf import DictConfig
21- from verl .workers .config import ActorConfig
2215import torch
23- import torch .nn .functional as F
24- from tensordict import TensorDict
25- from verl .utils import tensordict_utils as tu
26-
27-
28-
29- class Stage (Enum ):
30- """
31- Stages for PPO training
32- """
33- OLD_LOG_PROB = "old_log_prob"
34- REF_LOG_PROB = "ref_log_prob"
35- ACTOR_UPDATE = "actor_update"
36-
37- @classmethod
38- def get_topk_keys (cls , stage : Union [str , "Stage" ]):
39- if isinstance (stage , str ):
40- stage = cls (stage )
41- return f"{ stage .value } _topk_log_probs" , f"{ stage .value } _topk_indices"
42-
43-
44- def topk_logprobs_from_logits (logits : torch .Tensor , k : int , compute_both : bool , topk_indices : Optional [torch .Tensor ] = None ) -> tuple [torch .Tensor , torch .Tensor ]:
45- logprobs = F .log_softmax (logits , dim = - 1 )
46-
47- needs_dedupe = False
48- if compute_both :
49- if topk_indices is None or topk_indices .shape [- 1 ] == k :
50- should_compute_topk = True
51- elif topk_indices .shape [- 1 ] == 2 * k :
52- should_compute_topk = False
53- else :
54- raise ValueError (f"{ topk_indices .shape = } is not expected with { k = } " )
55- else :
56- if topk_indices is None :
57- should_compute_topk = True
58- elif topk_indices .shape [- 1 ] == k :
59- should_compute_topk = False
60- else :
61- raise ValueError (f"{ topk_indices .shape = } is not expected with { k = } " )
62-
63-
64- topk_logprobs_ls = []
65- topk_logprobs_indices_ls = []
66-
67- # Gather logits for provided indices.
68- if topk_indices is not None :
69- topk_logprobs = torch .gather (logprobs , dim = - 1 , index = topk_indices )
70- topk_logprobs_ls .append (topk_logprobs )
71- topk_logprobs_indices_ls .append (topk_indices )
72-
73- # Compute top-k logprobs.
74- if should_compute_topk :
75- topk_logprobs , topk_indices = torch .topk (logprobs , k = k , dim = - 1 )
76- topk_logprobs_ls .append (topk_logprobs )
77- topk_logprobs_indices_ls .append (topk_indices )
78-
79- topk_logprobs = torch .cat (topk_logprobs_ls , dim = - 1 )
80- topk_indices = torch .cat (topk_logprobs_indices_ls , dim = - 1 )
81-
82- # If top-k have been provided AND new top-k have been computed, we need to deduplicate the indices and logprobs.
83- if needs_dedupe :
84-
85- # Make sure indices are sorted so that we can identify duplicates.
86- topk_indices_diff = topk_indices .diff (dim = - 1 )
87- if topk_indices_diff .lt (0 ).any ():
88- topk_indices , sort_indices = topk_indices .sort (dim = - 1 )
89- topk_logprobs = torch .gather (topk_logprobs , dim = - 1 , index = sort_indices )
90- topk_indices_diff = topk_indices .diff (dim = - 1 )
91-
92- # Find duplicate indices and set their prob to ~0.
93- if topk_indices_diff .eq (0 ).any ():
94- index_diffs = torch .nn .functional .pad (topk_indices_diff , (0 , 1 ), value = 1 )
95- dupe_mask = index_diffs .eq (0 )
96- topk_logprobs [dupe_mask ] = - torch .inf
97-
98- return topk_logprobs , topk_indices
99-
100- def compute_topk_outputs (logits : torch .Tensor , batch : TensorDict , cu_seqlens : torch .Tensor ):
101- """
102- TODO: Docstring for compute_topk_outputs
103- """
104- stage = batch ["stage" ]
105- topk_logprobs , topk_indices = topk_logprobs_from_logits (logits = logits , k = 2 , compute_both = True , topk_indices = batch .get ("topk_indices" , None ))
106- topk_logprobs_key , topk_indices_key = Stage .get_topk_keys (stage )
107- output = {
108- topk_logprobs_key : torch .nested .nested_tensor_from_jagged (topk_logprobs .squeeze (0 ), cu_seqlens ),
109- topk_indices_key : torch .nested .nested_tensor_from_jagged (topk_indices .squeeze (0 ), cu_seqlens ),
110- }
111- return output
112-
113- def gather_topk_outputs (stage : Stage , output : TensorDict ):
114- """
115- TODO: Docstring for gather_topk_outputs
116- """
117- topk_logprobs_key , topk_indices_key = Stage .get_topk_keys (stage )
118- topk_logprobs = tu .get (output , topk_logprobs_key )
119- if topk_logprobs is not None :
120- return {
121- topk_logprobs_key : topk_logprobs .float (),
122- topk_indices_key : tu .get (output , topk_indices_key ),
123- }
124- else :
125- return {}
16+ from typing import Callable , Optional , Any
17+ from omegaconf import DictConfig
18+ from verl .workers .config import DistillationConfig
12619
12720# TODO: Update args
12821DistillationLossFn = Callable [
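
For context on the code deleted in the hunk above: `topk_logprobs_from_logits` builds the union of externally supplied top-k indices and a freshly computed top-k over the logits, then sorts by index and pushes duplicated entries to log-probability -inf (note that in the removed version the dedupe branch is guarded by `needs_dedupe`, which is initialized to False and never set). Below is a minimal, self-contained sketch of that idea; the function name, the toy shapes, and the unconditional deduplication are assumptions of the sketch, not part of verl's API.

# Illustrative sketch only, not the removed implementation.
import torch
import torch.nn.functional as F

def topk_union_logprobs(logits, k, provided_indices=None):
    logprobs = F.log_softmax(logits, dim=-1)
    lp_chunks, idx_chunks = [], []
    if provided_indices is not None:
        # Reuse indices chosen elsewhere (e.g. by the teacher policy).
        lp_chunks.append(torch.gather(logprobs, -1, provided_indices))
        idx_chunks.append(provided_indices)
    own_lp, own_idx = torch.topk(logprobs, k=k, dim=-1)
    lp_chunks.append(own_lp)
    idx_chunks.append(own_idx)
    lp = torch.cat(lp_chunks, dim=-1)
    idx = torch.cat(idx_chunks, dim=-1)
    # Deduplicate: sort by index, then mask repeated entries to prob ~0.
    idx, order = idx.sort(dim=-1)
    lp = torch.gather(lp, -1, order)
    dupes = F.pad(idx.diff(dim=-1), (0, 1), value=1).eq(0)
    return lp.masked_fill(dupes, float("-inf")), idx

logits = torch.randn(1, 5, 32)                 # (batch, seq_len, vocab)
teacher_idx = torch.randint(0, 32, (1, 5, 2))  # indices provided by another stage
lp, idx = topk_union_logprobs(logits, k=2, provided_indices=teacher_idx)
print(lp.shape, idx.shape)                     # both torch.Size([1, 5, 4])
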
@@ -132,7 +25,7 @@ def gather_topk_outputs(stage: Stage, output: TensorDict):
         torch.Tensor,  # advantages
         torch.Tensor,  # response_mask
         str,  # loss_agg_mode
-        Optional[DictConfig | ActorConfig],  # config
+        Optional[DictConfig | DistillationConfig],  # config
         torch.Tensor | None,  # rollout_log_probs
     ],
     tuple[torch.Tensor, dict[str, Any]],
@@ -174,14 +67,14 @@ def get_distillation_loss_fn(name):
         )
     return DISTILLATION_LOSS_REGISTRY[loss_name]

-from verl.workers.config import DistillationConfig
-
 @register_distillation_loss("student_kl_topk")  # type: ignore[arg-type]
 def compute_distillation_loss_student_kl_topk(
     teacher_log_probs: torch.Tensor,
     student_log_probs: torch.Tensor,
     teacher_topk_logprobs: torch.Tensor,
     student_topk_logprobs: torch.Tensor,
+    teacher_topk_indices: torch.Tensor,
+    student_topk_indices: torch.Tensor,
     response_mask: torch.Tensor,
     config: DistillationConfig,
     loss_agg_mode: str = "token-mean",
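
The hunk above registers the loss via `register_distillation_loss` and looks it up through `get_distillation_loss_fn`, which returns `DISTILLATION_LOSS_REGISTRY[loss_name]`; the registry itself sits outside the shown hunks. A minimal sketch of the registry pattern these names imply, assuming the decorator simply records the function in a module-level dict (only the registry, decorator, and lookup names come from the diff, the rest is illustrative):

from typing import Any, Callable

DISTILLATION_LOSS_REGISTRY: dict[str, Callable[..., Any]] = {}

def register_distillation_loss(name: str):
    """Register a distillation loss function under `name`."""
    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        DISTILLATION_LOSS_REGISTRY[name] = fn
        return fn
    return decorator

def get_distillation_loss_fn(loss_name: str):
    """Look up a registered loss, failing loudly on unknown names."""
    if loss_name not in DISTILLATION_LOSS_REGISTRY:
        raise ValueError(
            f"Unknown distillation loss {loss_name!r}; registered: {list(DISTILLATION_LOSS_REGISTRY)}"
        )
    return DISTILLATION_LOSS_REGISTRY[loss_name]
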
@@ -198,6 +91,12 @@ def compute_distillation_loss_student_kl_topk(
             Top-k log-probabilities of actions under the teacher policy, shape (batch_size, response_length, topk).
         student_topk_logprobs (torch.Tensor):
             Top-k log-probabilities of actions under the student policy, shape (batch_size, response_length, topk).
+        teacher_topk_indices (torch.Tensor):
+            Top-k action indices under the teacher policy, shape (batch_size, response_length, topk).
+        student_topk_indices (torch.Tensor):
+            Top-k action indices under the student policy, shape (batch_size, response_length, topk).
+        response_mask (torch.Tensor):
+            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
         config: `(verl.trainer.config.DistillationConfig)`:
             config for the actor.
         loss_agg_mode (str, optional):
@@ -207,6 +106,7 @@ def compute_distillation_loss_student_kl_topk(
         loss_agg_mode (str, optional):
             Aggregation mode for `agg_loss`. Defaults to "token-mean".
     """
+    breakpoint()
     assert config is not None
     topk = config.topk
     if teacher_topk_logprobs.shape[-1] != topk or student_topk_logprobs.shape[-1] != topk:
@@ -220,4 +120,6 @@ def compute_distillation_loss_student_kl_topk(
     # "actor/ppo_kl": ppo_kl.detach().item(),
     # "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
     # }
-    return distillation_loss, distillation_metrics
+    return distillation_loss, distillation_metrics
+
+
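
The body of `compute_distillation_loss_student_kl_topk` lies mostly outside the hunks above, so only the signature and docstring changes are visible: the commit threads `teacher_topk_indices` and `student_topk_indices` through, presumably so each policy's log-probabilities can be gathered on the other's top-k support. As a reading aid, here is an illustrative sketch of what a "student KL over top-k" loss can look like, under the assumptions that both log-prob tensors are already aligned on a shared top-k support, that the divergence is KL(student || teacher), and that aggregation is "token-mean"; the helper name and these semantics are assumptions of the sketch, not taken from the commit.

import torch

def student_kl_topk_sketch(
    teacher_topk_logprobs: torch.Tensor,  # (B, T, k), teacher log-probs on the shared top-k tokens
    student_topk_logprobs: torch.Tensor,  # (B, T, k), student log-probs on the same tokens
    response_mask: torch.Tensor,          # (B, T), 1.0 for response tokens that count
) -> tuple[torch.Tensor, dict]:
    # Per-token reverse KL restricted to the top-k support:
    #   sum_k p_student * (log p_student - log p_teacher)
    student_probs = student_topk_logprobs.exp()
    per_token_kl = (student_probs * (student_topk_logprobs - teacher_topk_logprobs)).sum(dim=-1)
    # "token-mean" aggregation over valid response tokens.
    loss = (per_token_kl * response_mask).sum() / response_mask.sum().clamp(min=1.0)
    return loss, {"distill/student_kl_topk": loss.detach().item()}

# Toy usage with random distributions truncated to their first k entries.
B, T, k = 2, 4, 2
teacher = torch.log_softmax(torch.randn(B, T, 8), dim=-1)[..., :k]
student = torch.log_softmax(torch.randn(B, T, 8), dim=-1)[..., :k]
loss, metrics = student_kl_topk_sketch(teacher, student, torch.ones(B, T))
print(loss.item(), metrics)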