[rollout] fix: forward max_tokens from rollout config to vLLM backends #5027
Code Review
This pull request aims to forward max_tokens and max_new_tokens from the rollout configuration to the vLLM/sglang backends. I've identified a few critical issues that prevent this from working as intended. Specifically, there's a bug in how sampling parameters are updated in verl/experimental/agent_loop/agent_loop.py, an incorrect type hint in verl/workers/config/rollout.py that would cause a runtime error, and the new logic is missing entirely from verl/experimental/fully_async_policy/agent_loop/agent_loop.py. I have provided detailed comments and suggestions to address these problems.
    for param_name in ["max_tokens", "max_new_tokens"]:
        param_value = getattr(config, param_name, None)
        if param_value is not None:
            sampling_params[param_value] = param_value
There's a bug in how the sampling parameters are being updated. You're using the parameter's value (param_value) as the dictionary key, but it should be the parameter's name (param_name). This will cause a TypeError if the value is not hashable, or will add an incorrect key to the sampling_params dictionary, preventing the setting from being applied.
-            sampling_params[param_value] = param_value
+            sampling_params[param_name] = param_value
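To make the effect concrete, here is a minimal, self-contained sketch of the loop before and after the suggested change. `RolloutConfig` below is a hypothetical stand-in dataclass, not the actual verl config class, and the function names and values are illustrative only.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class RolloutConfig:
    # Hypothetical stand-in: only the two fields discussed in this review.
    max_tokens: Optional[int] = None
    max_new_tokens: Optional[int] = None


def forward_buggy(config, sampling_params):
    """Loop as originally submitted: the value is used as the dictionary key."""
    for param_name in ["max_tokens", "max_new_tokens"]:
        param_value = getattr(config, param_name, None)
        if param_value is not None:
            sampling_params[param_value] = param_value
    return sampling_params


def forward_fixed(config, sampling_params):
    """Loop with the suggested fix: the parameter name is the key."""
    for param_name in ["max_tokens", "max_new_tokens"]:
        param_value = getattr(config, param_name, None)
        if param_value is not None:
            sampling_params[param_name] = param_value
    return sampling_params


config = RolloutConfig(max_tokens=512)
buggy = forward_buggy(config, {"temperature": 1.0})
fixed = forward_fixed(config, {"temperature": 1.0})
print(buggy)
print(fixed)
```

In the buggy version the limit ends up under the integer key `512`, so a backend that looks up `max_tokens` would never see it.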
verl/workers/config/rollout.py
Outdated
    max_tokens: Optional[list] = None
    max_new_tokens: Optional[list] = None
The type hints for max_tokens and max_new_tokens are incorrectly defined as Optional[list]. These parameters are integer limits on the number of generated tokens, so they should be Optional[int]. With the current loop, which uses the value as the dictionary key, a list value raises a TypeError at runtime because lists are not hashable. Even once the key is corrected, a list would be forwarded to a backend parameter that is meant to hold an integer.
-    max_tokens: Optional[list] = None
-    max_new_tokens: Optional[list] = None
+    max_tokens: Optional[int] = None
+    max_new_tokens: Optional[int] = None
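The hashability point can be checked in isolation. The snippet below is a small sketch using plain Python dictionaries (no verl or vLLM code involved); the variable names are illustrative.

```python
sampling_params = {}

# A list used as a dictionary key is rejected, which is what the
# value-as-key loop would hit if the config field held a list.
list_value = [512]
try:
    sampling_params[list_value] = list_value
    raised = False
    message = ""
except TypeError as exc:
    raised = True
    message = str(exc)

# An integer limit stored under the parameter name works as expected.
int_value = 512
sampling_params["max_tokens"] = int_value
print(raised, sampling_params)
```

Note that Python does not enforce dataclass type hints at runtime on its own, so the annotation mainly documents the intended type unless the config system validates it.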
        repetition_penalty=1.0,
        logprobs=config.calculate_log_probs,
    )
This file (verl/experimental/fully_async_policy/agent_loop/agent_loop.py) is missing the logic to forward max_tokens and max_new_tokens from the rollout configuration to sampling_params. Forwarding these values is the main goal of the PR, and omitting it here will lead to inconsistent behavior between the two agent loop implementations. Please add the same forwarding logic here as in verl/experimental/agent_loop/agent_loop.py.
+        # configure max generation tokens for vllm/sglang
+        for param_name in ["max_tokens", "max_new_tokens"]:
+            param_value = getattr(config, param_name, None)
+            if param_value is not None:
+                sampling_params[param_name] = param_value
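One possible way to keep the two agent loops from drifting apart is to put the forwarding in a single shared helper that both call. The sketch below is only an illustration: `apply_max_token_limits` and `MAX_TOKEN_PARAMS` are hypothetical names that do not exist in verl, and `SimpleNamespace` stands in for the real rollout config.

```python
from types import SimpleNamespace
from typing import Any, Dict

# Hypothetical constant: the config fields to forward to the backend.
MAX_TOKEN_PARAMS = ("max_tokens", "max_new_tokens")


def apply_max_token_limits(config: Any, sampling_params: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of sampling_params with any configured token limits added."""
    updated = dict(sampling_params)
    for param_name in MAX_TOKEN_PARAMS:
        # Fields that are absent or None are skipped, so backend defaults apply.
        param_value = getattr(config, param_name, None)
        if param_value is not None:
            updated[param_name] = param_value
    return updated


base = {"repetition_penalty": 1.0}
with_limit = apply_max_token_limits(SimpleNamespace(max_new_tokens=256), base)
without_limit = apply_max_token_limits(SimpleNamespace(), base)
print(with_limit)
print(without_limit)
```

Returning a copy rather than mutating the input is a design choice of this sketch; the suggestion above updates sampling_params in place, which works equally well.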
What does this PR do?
Forwards max_tokens/max_new_tokens from the rollout config to the vLLM/sglang backends.