
Commit 9268115
Author: psyloy

    configure max generation tokens for vllm/sglang

1 parent: 74f3224

File tree (5 files changed: +17 -0 lines changed)

  verl/experimental/agent_loop/agent_loop.py
  verl/experimental/fully_async_policy/agent_loop/agent_loop.py
  verl/trainer/config/rollout/rollout.yaml
  verl/workers/config/rollout.py
  verl/workers/rollout/sglang_rollout/async_sglang_server.py


verl/experimental/agent_loop/agent_loop.py
Lines changed: 6 additions & 0 deletions

@@ -433,6 +433,12 @@ async def generate_sequences(self, batch: DataProto) -> DataProto:
             logprobs=config.calculate_log_probs,
         )
 
+        # configure max generation tokens for vllm/sglang
+        for param_name in ["max_tokens", "max_new_tokens"]:
+            param_value = getattr(config, param_name, None)
+            if param_value is not None:
+                sampling_params[param_name] = param_value
+
         # override sampling params for validation
         if batch.meta_info.get("validate", False):
             sampling_params["top_p"] = config.val_kwargs.top_p

verl/experimental/fully_async_policy/agent_loop/agent_loop.py
Lines changed: 2 additions & 0 deletions

@@ -101,13 +101,15 @@ async def generate_sequences_no_post(
         sampling_params = dict(
             temperature=config.temperature,
             top_p=config.top_p,
+            top_k=config.top_k,
             repetition_penalty=1.0,
             logprobs=config.calculate_log_probs,
         )
 
         # override sampling params for validation
         if batch.meta_info.get("validate", False):
             sampling_params["top_p"] = config.val_kwargs.top_p
+            sampling_params["top_k"] = config.val_kwargs.top_k
             sampling_params["temperature"] = config.val_kwargs.temperature
 
         if "agent_name" not in batch.non_tensor_batch:

verl/trainer/config/rollout/rollout.yaml
Lines changed: 6 additions & 0 deletions

@@ -16,6 +16,12 @@ top_k: -1
 # Top-p sampling parameter. Default 1.0.
 top_p: 1
 
+# max number of tokens to generate for vllm.
+max_tokens: null
+
+# max number of tokens to generate for sglang.
+max_new_tokens: null
+
 # typically the same as data max prompt length
 # same as data.max_prompt_length if it exists
 prompt_length: ${oc.select:data.max_prompt_length,512}
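A usage sketch of the new defaults, using PyYAML only to demonstrate the parsing (verl itself loads this file through Hydra/OmegaConf): `null` deserializes to Python None, so neither cap is forwarded unless a user sets one explicitly.

    import yaml  # PyYAML, used here only to show how null deserializes

    cfg = yaml.safe_load("""
    top_k: -1
    top_p: 1
    max_tokens: null
    max_new_tokens: 512
    """)

    # null -> None, so only the explicitly set sglang cap is forwarded.
    forwarded = {k: v for k, v in cfg.items()
                 if k in ("max_tokens", "max_new_tokens") and v is not None}
    print(forwarded)  # {'max_new_tokens': 512}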

verl/workers/config/rollout.py
Lines changed: 2 additions & 0 deletions

@@ -130,6 +130,8 @@ class RolloutConfig(BaseConfig):
     do_sample: bool = True
     n: int = 1
     repetition_penalty: float = 1.0
+    max_tokens: Optional[int] = None
+    max_new_tokens: Optional[int] = None
 
     # Early termination threshold for multi-turn rollout in sglang.
     # Abort remaining requests when (1 - over_sample_rate) * total_requests are completed.
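A minimal stand-in for the new fields (not verl's full RolloutConfig): defaulting both to None keeps existing configs valid, and the getattr(..., None) call sites above also tolerate config objects that predate these fields entirely.

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class RolloutConfigStub:  # stand-in, not verl's BaseConfig subclass
        do_sample: bool = True
        n: int = 1
        repetition_penalty: float = 1.0
        max_tokens: Optional[int] = None      # cap consumed by vllm
        max_new_tokens: Optional[int] = None  # cap consumed by sglang

    cfg = RolloutConfigStub()
    assert getattr(cfg, "max_tokens", None) is None  # unset -> not forwarded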

verl/workers/rollout/sglang_rollout/async_sglang_server.py
Lines changed: 1 addition & 0 deletions

@@ -327,6 +327,7 @@ async def generate(
             f"({self.config.max_model_len})."
         )
 
+        # Determine max_new_tokens from sampling_params or use configured response_length as default
         if "max_new_tokens" in sampling_params:
            max_new_tokens = sampling_params.pop("max_new_tokens")
         elif "max_tokens" in sampling_params:
