
Conversation

@COLAZERO2
Contributor

Bug Fixes:
Fixes bugs that caused the loss to remain high due to unstable gradients when training with gradient checkpointing enabled. After the fix, the acceptance rate increases as intended when using the gradient checkpointing memory optimization.

Modification:
Refactors the draft model’s forward function by separating the retrieval of the target model’s hidden states from the draft model’s layer flow. The training-time test predictions over the drafting length are now wrapped as a single unit, removing the per-iteration torch.checkpoint() calls that previously produced a complicated computation graph and incorrect gradient flow; a sketch of the pattern follows.
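
Below is a rough sketch of the pattern described above, assuming a simplified draft model (`DraftModel`, `TinyDraftLayer`, `_draft_all_steps`, `dim`, and `draft_len` are all hypothetical names, not the actual identifiers in this repository): the target-model hidden states enter the forward pass once, and the whole drafting-length loop is checkpointed as one unit via `torch.utils.checkpoint.checkpoint` rather than calling checkpoint inside each iteration.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyDraftLayer(nn.Module):
    # Stand-in for one draft-model decoder layer (hypothetical).
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(hidden_states))


class DraftModel(nn.Module):
    # Hypothetical draft model: target-model hidden states are passed in once,
    # and the whole drafting loop is checkpointed as a single unit.
    def __init__(self, dim: int, num_layers: int, draft_len: int):
        super().__init__()
        self.layers = nn.ModuleList(TinyDraftLayer(dim) for _ in range(num_layers))
        self.draft_len = draft_len

    def _draft_all_steps(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The entire training-time prediction over the drafting length.
        outputs = []
        for _ in range(self.draft_len):
            for layer in self.layers:
                hidden_states = layer(hidden_states)
            outputs.append(hidden_states)
        return torch.stack(outputs, dim=1)

    def forward(self, target_hidden: torch.Tensor, use_checkpoint: bool = True):
        # target_hidden: hidden states retrieved from the target model,
        # fetched outside the checkpointed region.
        if use_checkpoint and self.training:
            # One checkpoint call around the whole drafting loop, instead of a
            # torch.checkpoint() call inside each iteration; use_reentrant=False
            # keeps gradients flowing correctly through the recomputed graph.
            return checkpoint(self._draft_all_steps, target_hidden, use_reentrant=False)
        return self._draft_all_steps(target_hidden)


# Usage sketch: gradients should reach the input even with checkpointing on.
model = DraftModel(dim=16, num_layers=2, draft_len=3).train()
target_hidden = torch.randn(4, 16, requires_grad=True)
model(target_hidden).sum().backward()
print(target_hidden.grad.shape)  # torch.Size([4, 16])
```

Whether the checkpoint covers all drafting steps at once or each step individually is a memory/recompute trade-off; this sketch only mirrors the single-wrap approach the description refers to, not necessarily the exact structure of the merged code.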

@jasonyong

It works.
