
Commit f1e930e

Re-factor distillation
1 parent b57c369 commit f1e930e

File tree

14 files changed, +223 -138 lines changed


examples/dev/debug.sh

Lines changed: 3 additions & 2 deletions
@@ -5,7 +5,7 @@ set -exo pipefail
 
 export NCCL_P2P_DISABLE=1
 export CUDA_DEVICE_ORDER=PCI_BUS_ID
-export CUDA_VISIBLE_DEVICES=5
+export CUDA_VISIBLE_DEVICES=0
 NUM_DEV=1
 export DATA_PATH=$PWD/../verlData
 export HF_HOME=$DATA_PATH
@@ -64,4 +64,5 @@ python -m verl.trainer.main_ppo \
     actor_rollout_ref.rollout.enforce_eager=True \
     actor_rollout_ref.ref.fsdp_config.use_torch_compile=False \
     actor_rollout_ref.rollout.agent.num_workers=1 \
-    trainer.use_legacy_worker_impl=disable
+    trainer.use_legacy_worker_impl=disable \
+    actor_rollout_ref.actor.distillation_config.enabled=True
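The new last line turns distillation on through an ordinary Hydra/OmegaConf dotted override. As a rough illustration only (the config tree below is a trimmed stand-in, not verl's full schema), this is how such a dotlist override merges into a nested config:

from omegaconf import OmegaConf

# Trimmed stand-in for the trainer config; only the path used in the script above is modeled.
base = OmegaConf.create(
    {"actor_rollout_ref": {"actor": {"distillation_config": {"enabled": False}}}}
)
# Same dotted syntax the launch script passes on the command line.
override = OmegaConf.from_dotlist(["actor_rollout_ref.actor.distillation_config.enabled=True"])
cfg = OmegaConf.merge(base, override)
print(cfg.actor_rollout_ref.actor.distillation_config.enabled)  # True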

verl/trainer/config/actor/dp_actor.yaml

Lines changed: 3 additions & 0 deletions
@@ -13,6 +13,9 @@ defaults:
   # fsdp engine config
   - ../engine@fsdp_config: fsdp
 
+  # fsdp distillation config
+  - ../distillation@distillation_config: dp_distillation
+
   # dp actor config, inheriting from trainer/config/actor/actor.yaml
   - actor
 
verl/trainer/config/ref/dp_ref.yaml

Lines changed: 3 additions & 0 deletions
@@ -13,6 +13,9 @@ defaults:
 # Target class for this configuration
 _target_: verl.workers.config.FSDPActorConfig
 
+# fsdp distillation config
+distillation_config: ${oc.select:actor_rollout_ref.actor.distillation_config}
+
 # fsdp config
 fsdp_config:
 
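The ${oc.select:...} line lets the ref worker reuse the actor's distillation settings instead of duplicating them. A minimal sketch of how OmegaConf's built-in oc.select resolver behaves (the fields inside distillation_config are assumed here purely for illustration):

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "actor_rollout_ref": {"actor": {"distillation_config": {"enabled": True, "topk": 2}}},  # assumed fields
    "ref": {"distillation_config": "${oc.select:actor_rollout_ref.actor.distillation_config}"},
})
# oc.select resolves the absolute dotted key lazily and yields None if the key is missing.
resolved = OmegaConf.to_container(cfg, resolve=True)
print(resolved["ref"]["distillation_config"])  # {'enabled': True, 'topk': 2}
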
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .utils import * # noqa: F401
+from .losses import * # noqa: F401

@@ -1,4 +1,4 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,118 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
-Contains utilities/classes for on-policy distillation
-"""
 
-from typing import Union, Optional, Callable, Any
-from enum import Enum
-from omegaconf import DictConfig
-from verl.workers.config import ActorConfig
 import torch
-import torch.nn.functional as F
-from tensordict import TensorDict
-from verl.utils import tensordict_utils as tu
-
-
-
-class Stage(Enum):
-    """
-    Stages for PPO training
-    """
-    OLD_LOG_PROB = "old_log_prob"
-    REF_LOG_PROB = "ref_log_prob"
-    ACTOR_UPDATE = "actor_update"
-
-    @classmethod
-    def get_topk_keys(cls, stage: Union[str, "Stage"]):
-        if isinstance(stage, str):
-            stage = cls(stage)
-        return f"{stage.value}_topk_log_probs", f"{stage.value}_topk_indices"
-
-
-def topk_logprobs_from_logits(logits: torch.Tensor, k: int, compute_both: bool, topk_indices: Optional[torch.Tensor] = None) -> tuple[torch.Tensor, torch.Tensor]:
-    logprobs = F.log_softmax(logits, dim=-1)
-
-    needs_dedupe = False
-    if compute_both:
-        if topk_indices is None or topk_indices.shape[-1] == k:
-            should_compute_topk = True
-        elif topk_indices.shape[-1] == 2 * k:
-            should_compute_topk = False
-        else:
-            raise ValueError(f"{topk_indices.shape=} is not expected with {k=}")
-    else:
-        if topk_indices is None:
-            should_compute_topk = True
-        elif topk_indices.shape[-1] == k:
-            should_compute_topk = False
-        else:
-            raise ValueError(f"{topk_indices.shape=} is not expected with {k=}")
-
-
-    topk_logprobs_ls = []
-    topk_logprobs_indices_ls = []
-
-    # Gather logits for provided indices.
-    if topk_indices is not None:
-        topk_logprobs = torch.gather(logprobs, dim=-1, index=topk_indices)
-        topk_logprobs_ls.append(topk_logprobs)
-        topk_logprobs_indices_ls.append(topk_indices)
-
-    # Compute top-k logprobs.
-    if should_compute_topk:
-        topk_logprobs, topk_indices = torch.topk(logprobs, k=k, dim=-1)
-        topk_logprobs_ls.append(topk_logprobs)
-        topk_logprobs_indices_ls.append(topk_indices)
-
-    topk_logprobs = torch.cat(topk_logprobs_ls, dim=-1)
-    topk_indices = torch.cat(topk_logprobs_indices_ls, dim=-1)
-
-    # If top-k have been provided AND new top-k have been computed, we need to deduplicate the indices and logprobs.
-    if needs_dedupe:
-
-        # Make sure indices are sorted so that we can identify duplicates.
-        topk_indices_diff = topk_indices.diff(dim=-1)
-        if topk_indices_diff.lt(0).any():
-            topk_indices, sort_indices = topk_indices.sort(dim=-1)
-            topk_logprobs = torch.gather(topk_logprobs, dim=-1, index=sort_indices)
-            topk_indices_diff = topk_indices.diff(dim=-1)
-
-        # Find duplicate indices and set their prob to ~0.
-        if topk_indices_diff.eq(0).any():
-            index_diffs = torch.nn.functional.pad(topk_indices_diff, (0, 1), value=1)
-            dupe_mask = index_diffs.eq(0)
-            topk_logprobs[dupe_mask] = -torch.inf
-
-    return topk_logprobs, topk_indices
-
-def compute_topk_outputs(logits: torch.Tensor, batch: TensorDict, cu_seqlens: torch.Tensor):
-    """
-    TODO: Docstring for compute_topk_outputs
-    """
-    stage = batch["stage"]
-    topk_logprobs, topk_indices = topk_logprobs_from_logits(logits=logits, k=2, compute_both=True, topk_indices=batch.get("topk_indices", None))
-    topk_logprobs_key, topk_indices_key = Stage.get_topk_keys(stage)
-    output = {
-        topk_logprobs_key: torch.nested.nested_tensor_from_jagged(topk_logprobs.squeeze(0), cu_seqlens),
-        topk_indices_key: torch.nested.nested_tensor_from_jagged(topk_indices.squeeze(0), cu_seqlens),
-    }
-    return output
-
-def gather_topk_outputs(stage: Stage, output: TensorDict):
-    """
-    TODO: Docstring for gather_topk_outputs
-    """
-    topk_logprobs_key, topk_indices_key = Stage.get_topk_keys(stage)
-    topk_logprobs = tu.get(output, topk_logprobs_key)
-    if topk_logprobs is not None:
-        return {
-            topk_logprobs_key: topk_logprobs.float(),
-            topk_indices_key: tu.get(output, topk_indices_key),
-        }
-    else:
-        return {}
+from typing import Callable, Optional, Any
+from omegaconf import DictConfig
+from verl.workers.config import DistillationConfig
 
 # TODO: Update args
 DistillationLossFn = Callable[
@@ -132,7 +25,7 @@ def gather_topk_outputs(stage: Stage, output: TensorDict):
         torch.Tensor, # advantages
         torch.Tensor, # response_mask
         str, # loss_agg_mode
-        Optional[DictConfig | ActorConfig], # config
+        Optional[DictConfig | DistillationConfig], # config
        torch.Tensor | None, # rollout_log_probs
     ],
     tuple[torch.Tensor, dict[str, Any]],
@@ -174,14 +67,14 @@ def get_distillation_loss_fn(name):
         )
     return DISTILLATION_LOSS_REGISTRY[loss_name]
 
-from verl.workers.config import DistillationConfig
-
 @register_distillation_loss("student_kl_topk") # type: ignore[arg-type]
 def compute_distillation_loss_student_kl_topk(
     teacher_log_probs: torch.Tensor,
     student_log_probs: torch.Tensor,
     teacher_topk_logprobs: torch.Tensor,
     student_topk_logprobs: torch.Tensor,
+    teacher_topk_indices: torch.Tensor,
+    student_topk_indices: torch.Tensor,
     response_mask: torch.Tensor,
     config: DistillationConfig,
     loss_agg_mode: str = "token-mean",
@@ -198,6 +91,12 @@ def compute_distillation_loss_student_kl_topk(
             Top-k log-probabilities of actions under the teacher policy, shape (batch_size, response_length, topk).
         student_topk_logprobs (torch.Tensor):
             Top-k log-probabilities of actions under the student policy, shape (batch_size, response_length, topk).
+        teacher_topk_indices (torch.Tensor):
+            Top-k action indices under the teacher policy, shape (batch_size, response_length, topk).
+        student_topk_indices (torch.Tensor):
+            Top-k action indices under the student policy, shape (batch_size, response_length, topk).
+        response_mask (torch.Tensor):
+            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
         config: `(verl.trainer.config.DistillationConfig)`:
             config for the actor.
         loss_agg_mode (str, optional):
@@ -207,6 +106,7 @@ def compute_distillation_loss_student_kl_topk(
         loss_agg_mode (str, optional):
             Aggregation mode for `agg_loss`. Defaults to "token-mean".
     """
+    breakpoint()
     assert config is not None
     topk = config.topk
     if teacher_topk_logprobs.shape[-1] != topk or student_topk_logprobs.shape[-1] != topk:
@@ -220,4 +120,6 @@ def compute_distillation_loss_student_kl_topk(
     # "actor/ppo_kl": ppo_kl.detach().item(),
     # "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
     # }
-    return distillation_loss, distillation_metrics
+    return distillation_loss, distillation_metrics
+
+

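The body of the student_kl_topk loss is largely elided by this hunk, so for orientation only, here is a minimal sketch (not verl's implementation) of the kind of top-k forward-KL term the signature and docstring suggest: KL(teacher || student) restricted to the retained top-k tokens, masked and averaged over response tokens.

import torch

def sketch_topk_forward_kl(teacher_topk_logprobs, student_topk_logprobs, response_mask):
    # teacher/student_topk_logprobs: (batch, response_len, k); response_mask: (batch, response_len)
    # Per-token KL over the top-k support: sum_k p_teacher * (log p_teacher - log p_student).
    per_token_kl = (teacher_topk_logprobs.exp() * (teacher_topk_logprobs - student_topk_logprobs)).sum(-1)
    # "token-mean" style aggregation over valid response tokens.
    return (per_token_kl * response_mask).sum() / response_mask.sum().clamp(min=1)
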
verl/trainer/distillation/utils.py

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Contains utilities/classes for on-policy distillation
+"""
+
+from typing import Union, Optional
+import torch
+import torch.nn.functional as F
+from tensordict import TensorDict
+from verl.utils import tensordict_utils as tu
+from enum import Enum
+
+class Stage(Enum):
+    """
+    Stages for PPO training
+    """
+    OLD_LOG_PROB = "old_log_prob"
+    REF_LOG_PROB = "ref_log_prob"
+    ACTOR_UPDATE = "actor_update"
+
+def get_topk_keys(stage: Union[str, Stage]):
+    """TODO: Docstring for get_topk_keys"""
+    if isinstance(stage, Stage):
+        stage = stage.value
+    return f"{stage}_topk_log_probs", f"{stage}_topk_indices"
+
+def topk_logprobs_from_logits(logits: torch.Tensor, k: int, compute_both: bool, topk_indices: Optional[torch.Tensor] = None) -> tuple[torch.Tensor, torch.Tensor]:
+    """TODO: Docstring for topk_logprobs_from_logits"""
+    logprobs = F.log_softmax(logits, dim=-1)
+
+    needs_dedupe = False
+    if compute_both:
+        if topk_indices is None or topk_indices.shape[-1] == k:
+            should_compute_topk = True
+        elif topk_indices.shape[-1] == 2 * k:
+            should_compute_topk = False
+        else:
+            raise ValueError(f"{topk_indices.shape=} is not expected with {k=}")
+    else:
+        if topk_indices is None:
+            should_compute_topk = True
+        elif topk_indices.shape[-1] == k:
+            should_compute_topk = False
+        else:
+            raise ValueError(f"{topk_indices.shape=} is not expected with {k=}")
+
+
+    topk_logprobs_ls = []
+    topk_logprobs_indices_ls = []
+
+    # Gather logits for provided indices.
+    if topk_indices is not None:
+        topk_logprobs = torch.gather(logprobs, dim=-1, index=topk_indices)
+        topk_logprobs_ls.append(topk_logprobs)
+        topk_logprobs_indices_ls.append(topk_indices)
+
+    # Compute top-k logprobs.
+    if should_compute_topk:
+        topk_logprobs, topk_indices = torch.topk(logprobs, k=k, dim=-1)
+        topk_logprobs_ls.append(topk_logprobs)
+        topk_logprobs_indices_ls.append(topk_indices)
+
+    topk_logprobs = torch.cat(topk_logprobs_ls, dim=-1)
+    topk_indices = torch.cat(topk_logprobs_indices_ls, dim=-1)
+
+    # If top-k have been provided AND new top-k have been computed, we need to deduplicate the indices and logprobs.
+    if needs_dedupe:
+
+        # Make sure indices are sorted so that we can identify duplicates.
+        topk_indices_diff = topk_indices.diff(dim=-1)
+        if topk_indices_diff.lt(0).any():
+            topk_indices, sort_indices = topk_indices.sort(dim=-1)
+            topk_logprobs = torch.gather(topk_logprobs, dim=-1, index=sort_indices)
+            topk_indices_diff = topk_indices.diff(dim=-1)
+
+        # Find duplicate indices and set their prob to ~0.
+        if topk_indices_diff.eq(0).any():
+            index_diffs = torch.nn.functional.pad(topk_indices_diff, (0, 1), value=1)
+            dupe_mask = index_diffs.eq(0)
+            topk_logprobs[dupe_mask] = -torch.inf
+
+    return topk_logprobs, topk_indices
+
+def compute_topk_outputs(logits: torch.Tensor, batch: TensorDict, cu_seqlens: torch.Tensor):
+    """
+    TODO: Docstring for compute_topk_outputs
+    """
+    stage = batch["stage"]
+    topk_logprobs, topk_indices = topk_logprobs_from_logits(logits=logits, k=2, compute_both=True, topk_indices=batch.get("topk_indices", None))
+    topk_logprobs_key, topk_indices_key = get_topk_keys(stage)
+    output = {
+        topk_logprobs_key: torch.nested.nested_tensor_from_jagged(topk_logprobs.squeeze(0), cu_seqlens),
+        topk_indices_key: torch.nested.nested_tensor_from_jagged(topk_indices.squeeze(0), cu_seqlens),
+    }
+    return output
+
+def gather_topk_outputs(stage: Stage, output: TensorDict):
+    """
+    TODO: Docstring for gather_topk_outputs
+    """
+    topk_logprobs_key, topk_indices_key = get_topk_keys(stage)
+    topk_logprobs = tu.get(output, topk_logprobs_key)
+    if topk_logprobs is not None:
+        return {
+            topk_logprobs_key: topk_logprobs.float(),
+            topk_indices_key: tu.get(output, topk_indices_key),
+        }
+    else:
+        return {}
+

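A quick usage sketch of the helpers above (assuming a checkout with this commit installed, so the new module path is importable; the vocab size is arbitrary):

import torch
from verl.trainer.distillation.utils import Stage, get_topk_keys, topk_logprobs_from_logits

print(get_topk_keys(Stage.OLD_LOG_PROB))  # ('old_log_prob_topk_log_probs', 'old_log_prob_topk_indices')

logits = torch.randn(1, 5, 128)  # (batch, seq_len, vocab)
topk_lp, topk_idx = topk_logprobs_from_logits(logits, k=2, compute_both=True)
print(topk_lp.shape, topk_idx.shape)  # torch.Size([1, 5, 2]) torch.Size([1, 5, 2])

# Passing indices back with compute_both=True appends freshly computed top-k to the gathered entries.
both_lp, both_idx = topk_logprobs_from_logits(logits, k=2, compute_both=True, topk_indices=topk_idx)
print(both_lp.shape, both_idx.shape)  # torch.Size([1, 5, 4]) torch.Size([1, 5, 4])
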
verl/trainer/ppo/ray_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -52,8 +52,8 @@
 )
 from verl.trainer.ppo.reward import compute_reward, compute_reward_async
 from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model
+from verl.trainer.distillation import Stage, gather_topk_outputs
 from verl.utils import tensordict_utils as tu
-from verl.utils.distillation import Stage, gather_topk_outputs
 from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi
 from verl.utils.config import omega_conf_to_dataclass
 from verl.utils.debug import marked_timer

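Since the new package __init__ re-exports everything from .utils and .losses, the import swap above resolves the same symbols that previously lived in verl.utils.distillation, e.g. (again assuming this commit is installed):

from verl.trainer.distillation import Stage, gather_topk_outputs

print(Stage.ACTOR_UPDATE.value)  # actor_update
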
verl/workers/config/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,7 @@
 from .optimizer import * # noqa: F401
 from .reward_model import * # noqa: F401
 from .rollout import * # noqa: F401
+from .distillation import * # noqa: F401
 
 __all__ = (
     actor.__all__
@@ -29,4 +30,5 @@
     + optimizer.__all__
     + rollout.__all__
     + model.__all__
+    + distillation.__all__
 )
