
Commit 46edfe3

fix unittest and optimizer config
1 parent faed5ef commit 46edfe3

File tree

5 files changed (+42, -24 lines)


tests/common/config_test.py

Lines changed: 18 additions & 7 deletions
@@ -6,6 +6,8 @@
 import shutil
 import unittest

+import torch
+
 from tests.tools import get_template_config, get_unittest_dataset_config
 from trinity.common.config import InferenceModelConfig, load_config

@@ -143,10 +145,9 @@ def test_optimizer_config_propagation(self):
         config.algorithm.optimizer.lr = 1e-4
         config.algorithm.optimizer.weight_decay = 0.05
         config.algorithm.optimizer.clip_grad = 2.0
-        config.algorithm.optimizer.lr_decay_steps = 1000
-        config.algorithm.optimizer.lr_decay_style = "cosine"
-        config.algorithm.optimizer.lr_warmup_init = 1e-7
-        config.algorithm.optimizer.min_lr = 1e-6
+        config.trainer.total_steps = 1000
+        config.algorithm.optimizer.lr_scheduler_type = "cosine"
+        config.algorithm.optimizer.min_lr_ratio = 1e-2
         config.check_and_update()
         self.assertEqual(config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr, 1e-4)
         self.assertEqual(

@@ -159,10 +160,20 @@ def test_optimizer_config_propagation(self):
         self.assertEqual(
             config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr_decay_style, "cosine"
         )
-        self.assertEqual(
-            config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr_warmup_init, 1e-7
+        self.assertTrue(
+            torch.allclose(
+                torch.tensor(
+                    config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr_warmup_init
+                ),
+                torch.tensor(1e-6),
+            )
+        )
+        self.assertTrue(
+            torch.allclose(
+                torch.tensor(config.trainer.trainer_config.actor_rollout_ref.actor.optim.min_lr),
+                torch.tensor(1e-6),
+            )
         )
-        self.assertEqual(config.trainer.trainer_config.actor_rollout_ref.actor.optim.min_lr, 1e-6)
         # critic optimizer should not be affected
         self.assertEqual(config.trainer.trainer_config.critic.optim.lr, 1e-5)
         self.assertEqual(config.trainer.trainer_config.critic.optim.weight_decay, 0.01)
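Note on the switch from assertEqual to torch.allclose: lr_warmup_init and min_lr are now derived as min_lr_ratio * lr (see the verl_config.py hunk below), and a product of two floats need not be bit-identical to the literal 1e-6. A minimal sketch of the comparison issue, using only values that appear in this test:

    import torch

    derived = 1e-2 * 1e-4   # how the config pipeline computes min_lr_ratio * lr
    expected = 1e-6         # the value the test asserts

    print(derived == expected)  # may be False: the product can differ by one ulp
    print(torch.allclose(torch.tensor(derived), torch.tensor(expected)))  # True within tolerance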

tests/trainer/trainer_test.py

Lines changed: 1 addition & 1 deletion
@@ -1437,8 +1437,8 @@ def tearDown(self):
         shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True)


+@unittest.skipIf("TINKER_API_KEY" not in os.environ, "TINKER_API_KEY is not set")
 class TestTinkerTrainer(BaseTrainerCase):
-    @unittest.skipIf("TINKER_API_KEY" not in os.environ, "TINKER_API_KEY is not set")
     def test_trainer(self):
         """Test GSM8K on tinker."""
         # test both mode
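Hoisting the skipIf decorator from the test method to the class guards every test in TestTinkerTrainer with the same environment check, so test methods added later are skipped as well when the key is absent. A minimal sketch with hypothetical class and method names:

    import os
    import unittest


    @unittest.skipIf("TINKER_API_KEY" not in os.environ, "TINKER_API_KEY is not set")
    class TinkerDependentTests(unittest.TestCase):
        def test_first(self):
            pass  # skipped when the key is missing

        def test_second(self):
            pass  # also skipped; no per-method decorator needed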

trinity/common/config.py

Lines changed: 1 addition & 5 deletions
@@ -93,17 +93,13 @@ class OptimizerConfig:
     lr: float = 1e-6
     lr_warmup_steps: int = -1
     lr_warmup_steps_ratio: float = 0.0
-    min_lr_ratio: Optional[float] = 0.0
+    min_lr_ratio: float = 0.0
     warmup_style: Optional[str] = None  # deprecated !
     lr_scheduler_type: str = "constant"
     optimizer_type: str = "adam"
     betas: List[float] = field(default_factory=lambda: [0.9, 0.999])
     weight_decay: float = 0.01
     clip_grad: float = 1.0
-    lr_warmup_init: float = 0.0  # used in megatron
-    lr_decay_steps: Optional[int] = None  # used in megatron
-    lr_decay_style: str = "constant"  # used in megatron, duplicated with lr_scheduler_type in veRL
-    min_lr: float = 0.0


 @dataclass
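With the Megatron-specific fields removed, callers describe the schedule through the remaining backend-agnostic fields and let check_and_update() derive the rest. A minimal sketch mirroring the updated unit test above:

    from tests.tools import get_template_config

    config = get_template_config()
    config.trainer.total_steps = 1000                        # drives lr_decay_steps
    config.algorithm.optimizer.lr = 1e-4
    config.algorithm.optimizer.lr_scheduler_type = "cosine"  # replaces lr_decay_style
    config.algorithm.optimizer.min_lr_ratio = 1e-2           # min_lr becomes 1e-2 * 1e-4
    config.check_and_update()                                # propagates into the veRL trainer config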

trinity/common/models/tinker_model.py

Lines changed: 2 additions & 1 deletion
@@ -95,7 +95,8 @@ async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]:
         if with_chat_completion:
             create_time = int(time.time())
         output = await self._generate_internal(prompt={"prompt_token_ids": token_ids}, **kwargs)
-        return_logprobs = kwargs.get("logprobs", self.config.logprobs is not None)
+        logprobs = kwargs.get("logprobs", self.config.logprobs)
+        return_logprobs = logprobs is not None and logprobs is not False
         experiences = [
             Experience(
                 tokens=torch.tensor(token_ids + sequence.tokens, dtype=torch.int32),
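The replaced one-liner conflated two questions: did the caller pass a logprobs value, and is that value enabled. In particular, a config value of False satisfied the "is not None" default and wrongly enabled logprobs. A minimal sketch of the behavioral difference (config_logprobs stands in for self.config.logprobs; illustration only, not the module's API):

    def old_flag(kwargs, config_logprobs):
        return kwargs.get("logprobs", config_logprobs is not None)

    def new_flag(kwargs, config_logprobs):
        logprobs = kwargs.get("logprobs", config_logprobs)
        return logprobs is not None and logprobs is not False

    print(old_flag({}, False), new_flag({}, False))  # True vs. False: config False now disables
    print(old_flag({"logprobs": None}, 5), new_flag({"logprobs": None}, 5))  # None vs. False: always a bool now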

trinity/common/verl_config.py

Lines changed: 20 additions & 10 deletions
@@ -66,10 +66,10 @@ class Optim:
     total_training_steps: int = -1  # ! DO NOT SET, use trainer.total_steps
     betas: List[float] = field(default_factory=lambda: [0.9, 0.999])
     clip_grad: float = 1.0
-    lr_warmup_init: float = 0.0
+    lr_warmup_init: Optional[float] = None  # 0.0
     lr_decay_steps: Optional[int] = None
-    lr_decay_style: str = "constant"
-    min_lr: float = 0.0
+    lr_decay_style: Optional[str] = None  # "constant"
+    min_lr: Optional[float] = None  # 0.0
     weight_decay: float = 0.01
     weight_decay_incr_style: str = "constant"
     lr_wsd_decay_style: str = "exponential"
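Changing these defaults to Optional[...] = None turns them into unset sentinels: synchronize_config (next hunk) can then tell whether a Megatron-facing field was set explicitly or should be derived from the backend-agnostic optimizer config. The previous defaults are preserved in the trailing comments.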
@@ -607,22 +607,32 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
             self.critic.strategy = "fsdp"

         # Algorithm related config
-        for field_name in config.algorithm.optimizer.__dataclass_fields__:
-            field_value = getattr(config.algorithm.optimizer, field_name)
+        actor_optim = self.actor_rollout_ref.actor.optim
+        critic_optim = self.critic.optim
+        optim_config = config.algorithm.optimizer
+        for field_name in optim_config.__dataclass_fields__:
+            field_value = getattr(optim_config, field_name)
             if field_name == "optimizer_type":
-                setattr(self.actor_rollout_ref.actor.optim, "optimizer", field_value)
-            elif hasattr(self.actor_rollout_ref.actor.optim, field_name):
-                setattr(self.actor_rollout_ref.actor.optim, field_name, field_value)
+                setattr(actor_optim, "optimizer", field_value)
+            elif hasattr(actor_optim, field_name):
+                setattr(actor_optim, field_name, field_value)
+        # ensure megatron optimizer config compatibility
+        set_if_none(actor_optim, "lr_warmup_init", optim_config.min_lr_ratio * optim_config.lr)
+        set_if_none(actor_optim, "lr_decay_steps", self.trainer.total_training_steps)
+        set_if_none(actor_optim, "lr_decay_style", optim_config.lr_scheduler_type)
+        set_if_none(actor_optim, "min_lr", optim_config.min_lr_ratio * optim_config.lr)
+        set_if_none(critic_optim, "lr_warmup_init", 0.0)
+        set_if_none(critic_optim, "lr_decay_steps", self.trainer.total_training_steps)
+        set_if_none(critic_optim, "lr_decay_style", "constant")
+        set_if_none(critic_optim, "min_lr", 0.0)
         # fix optimizer type for fsdp
         if config.trainer.trainer_strategy.startswith("fsdp"):
             optim_map = {
                 "adam": "AdamW",
                 "adamw": "AdamW",
                 "sgd": "SGD",
             }
-            actor_optim = self.actor_rollout_ref.actor.optim
             actor_optim.optimizer = optim_map.get(actor_optim.optimizer, actor_optim.optimizer)
-            critic_optim = self.critic.optim
             critic_optim.optimizer = optim_map.get(critic_optim.optimizer, critic_optim.optimizer)
         self.actor_rollout_ref.actor.use_kl_loss = config.algorithm.kl_loss_fn != "none"
         self.algorithm.use_kl_in_reward = config.algorithm.kl_penalty_fn != "none"
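The new compatibility block relies on a set_if_none helper whose definition is not part of this diff. A minimal sketch of what it presumably does (an assumption, not the repository's actual code):

    def set_if_none(obj, attr_name, default_value):
        # Hypothetical reconstruction: fill the attribute only while it still
        # holds the None sentinel, so explicit user settings are never overwritten.
        if getattr(obj, attr_name) is None:
            setattr(obj, attr_name, default_value)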
