2 changes: 2 additions & 0 deletions src/rank_llm/data.py
@@ -30,6 +30,8 @@ class InferenceInvocation:
response: str
input_token_count: int
output_token_count: int
reasoning: Optional[str] = None
token_usage: Optional[Dict[str, Any]] = None
output_validation_regex: Optional[str] = None
output_extraction_regex: Optional[str] = None

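For reference, a minimal sketch of how the two new optional fields might be filled in. The literal values are invented for illustration; the positional fields follow the same order used by the constructor calls later in this diff.

```python
from rank_llm.data import InferenceInvocation

# Invented values for illustration; real ones come from the model response.
invocation = InferenceInvocation(
    "Rank the following passages by relevance ...",  # prompt
    "[2] > [1] > [3]",                               # response
    812,                                             # input_token_count
    17,                                              # output_token_count
    reasoning="Passage 2 answers the query directly ...",
    token_usage={"prompt_tokens": 812, "completion_tokens": 17},
)
```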
21 changes: 20 additions & 1 deletion src/rank_llm/demo/rerank_oss_bright.py
@@ -1,6 +1,7 @@
import json
import os
import sys
from importlib.resources import files
from pathlib import Path

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
@@ -13,6 +14,23 @@
from rank_llm.rerank import Reranker
from rank_llm.rerank.listwise import RankListwiseOSLLM

"""
Note: You need to run the vllm server with the following command:
```bash
RANK_MODEL_ID="openai/gpt-oss-20b"
RANK_PORT=48003
RANK_VLLM_LOG="vllm_server_48003.log"
CUDA_VISIBLE_DEVICES=2 vllm serve "$RANK_MODEL_ID" \
--port "$RANK_PORT" \
--dtype auto \
--gpu-memory-utilization 0.9 \
--max-model-len 32768 \
--enable-prompt-tokens-details \
--enable-prefix-caching \
> "$RANK_VLLM_LOG" 2>&1 &
```
"""
TEMPLATES = files("rank_llm.rerank.prompt_templates")
for fusion in ["bm25-splade", "bge-splade", "bm25-bge"]:
for task in [
"aops",
@@ -39,7 +57,8 @@
use_alpha=True,
is_thinking=True,
reasoning_token_budget=4096 * 4,
base_url="http://localhost:18000/v1",
base_url="http://localhost:48003/v1",
prompt_template_path=(TEMPLATES / "rank_zephyr_template.yaml"),
)
)
kwargs = {"populate_invocations_history": True}
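Before running this demo, it can help to confirm that the server from the note above is actually serving the expected model. A quick sketch using the openai SDK, assuming the server was started on port 48003 as shown and does not validate the API key (the "EMPTY" placeholder is a common convention for local vLLM servers).

```python
from openai import OpenAI

# Point at the locally served vLLM OpenAI-compatible endpoint.
client = OpenAI(base_url="http://localhost:48003/v1", api_key="EMPTY")
served = [m.id for m in client.models.list().data]
print(served)  # should include "openai/gpt-oss-20b"
```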
19 changes: 19 additions & 0 deletions src/rank_llm/demo/rerank_qwen3_bright.py
@@ -24,6 +24,24 @@
from rank_llm.rerank import Reranker
from rank_llm.rerank.listwise import RankListwiseOSLLM

"""
Note: You need to run the vllm server with the following command:
```bash
RANK_MODEL_ID="Qwen/Qwen3-8B"
RANK_PORT=38003
RANK_VLLM_LOG="vllm_server_38003.log"
CUDA_VISIBLE_DEVICES=2 vllm serve "$RANK_MODEL_ID" \
--port "$RANK_PORT" \
--dtype auto \
--gpu-memory-utilization 0.9 \
--max-model-len 32768 \
--enable-prompt-tokens-details \
--enable-prefix-caching \
--reasoning-parser qwen3 \
> "$RANK_VLLM_LOG" 2>&1 &
```
"""


def main():
for fusion in ["bge-splade", "bm25-bge", "bm25-splade"]:
@@ -52,6 +70,7 @@ def main():
use_alpha=True,
is_thinking=True,
reasoning_token_budget=4096 * 4,
base_url="http://localhost:38003/v1",
)
)
kwargs = {"populate_invocations_history": True}
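Since both demos pass populate_invocations_history=True, the reasoning text and token usage added in data.py end up attached to each result. A small helper sketch for inspecting them, assuming Result is the dataclass exposed by rank_llm.data.

```python
from typing import Iterable

from rank_llm.data import Result  # assumed home of the Result dataclass


def print_reasoning(results: Iterable[Result]) -> None:
    """Dump the reasoning text and raw token usage recorded for each LLM call."""
    for result in results:
        for invocation in result.invocations_history or []:
            print(invocation.reasoning)    # parsed by vLLM's --reasoning-parser qwen3
            print(invocation.token_usage)  # raw usage dict from the OpenAI-compatible API
```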
75 changes: 54 additions & 21 deletions src/rank_llm/rerank/listwise/listwise_rankllm.py
@@ -119,28 +119,61 @@ def permutation_pipeline_batched(
batched_results = self.run_llm_batched(
[prompt for prompt, _ in prompts], current_window_size=rank_end - rank_start
)

for index, (result, (prompt, in_token_count)) in enumerate(
zip(results, prompts)
):
permutation, out_token_count = batched_results[index]
if logging:
logger.debug(f"output: {permutation}")
if populate_invocations_history:
if result.invocations_history is None:
result.invocations_history = []
inference_invocation = InferenceInvocation(
prompt,
permutation,
in_token_count,
out_token_count,
self._inference_handler.template["output_validation_regex"],
self._inference_handler.template["output_extraction_regex"],
## Legacy format (text, out_token_count)
## TODO: Remove this once all the listwise rerankers have switched to the new format.
if len(batched_results[0]) == 2:
for index, (result, (prompt, in_token_count)) in enumerate(
zip(results, prompts)
):
permutation, out_token_count = batched_results[index]
if logging:
logger.debug(f"output: {permutation}")
if populate_invocations_history:
if result.invocations_history is None:
result.invocations_history = []
inference_invocation = InferenceInvocation(
prompt,
permutation,
in_token_count,
out_token_count,
self._inference_handler.template["output_validation_regex"],
self._inference_handler.template["output_extraction_regex"],
)
                    result.invocations_history.append(inference_invocation)
                result = self.receive_permutation(
                    result, permutation, rank_start, rank_end, logging
                )
else:
## New format (text, reasoning, usage)
assert len(batched_results[0]) == 3
for index, (result, (prompt, in_token_count)) in enumerate(
zip(results, prompts)
):
permutation, reasoning, usage = batched_results[index]
in_token_count = usage.get("prompt_tokens") or usage.get("input_tokens")
out_token_count = usage.get("completion_tokens") or usage.get(
"output_tokens"
)
result.invocations_history.append(inference_invocation)
result = self.receive_permutation(
result, permutation, rank_start, rank_end, logging
)
if logging:
logger.debug(f"output: {permutation}")
if populate_invocations_history:
if result.invocations_history is None:
result.invocations_history = []
inference_invocation = InferenceInvocation(
prompt,
permutation,
in_token_count,
out_token_count,
reasoning=reasoning,
token_usage=usage,
output_validation_regex=self._inference_handler.template[
"output_validation_regex"
],
output_extraction_regex=self._inference_handler.template[
"output_extraction_regex"
],
)
result.invocations_history.append(inference_invocation)
result = self.receive_permutation(
result, permutation, rank_start, rank_end, logging
)

return results

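The new branch above pulls token counts out of the usage dict with a two-key fallback. A toy sketch of that lookup, on the assumption that Chat Completions-style backends report prompt_tokens/completion_tokens while Responses-style backends report input_tokens/output_tokens.

```python
from typing import Any, Dict, Optional, Tuple


def token_counts(usage: Dict[str, Any]) -> Tuple[Optional[int], Optional[int]]:
    """Mirror the fallback used above: prefer Chat Completions keys, then Responses keys."""
    in_tokens = usage.get("prompt_tokens") or usage.get("input_tokens")
    out_tokens = usage.get("completion_tokens") or usage.get("output_tokens")
    return in_tokens, out_tokens


assert token_counts({"prompt_tokens": 812, "completion_tokens": 17}) == (812, 17)
assert token_counts({"input_tokens": 812, "output_tokens": 17}) == (812, 17)
```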
10 changes: 2 additions & 8 deletions src/rank_llm/rerank/listwise/rank_listwise_os_llm.py
@@ -258,7 +258,7 @@ def run_llm_batched(
self,
prompts: List[str | List[Dict[str, str]]],
current_window_size: Optional[int] = None,
) -> List[Tuple[str, int]]:
) -> List[Tuple[str, int]] | List[Tuple[str, str, Dict[str, Any]]]:
if current_window_size is None:
current_window_size = self._window_size

@@ -286,12 +286,6 @@ def run_llm_batched(
else self.num_output_tokens(current_window_size)
)
if self._base_url:
kwargs = {
"max_tokens": max_tokens,
"temperature": 0,
# TODO: expose the reasoning effort as an init param if needed.
"reasoning": {"effort": "medium", "summary": "detailed"},
}
return self._vllm_handler.chat_completions(
prompts=prompts, max_tokens=max_tokens, temperature=0
)
@@ -345,7 +339,7 @@ def run_llm_batched(

def run_llm(
self, prompt: str, current_window_size: Optional[int] = None
) -> Tuple[str, int]:
) -> Tuple[str, int] | Tuple[str, str, Dict[str, Any]]:
# Now forward the run_llm into run_llm_batched
if current_window_size is None:
current_window_size = self._window_size
18 changes: 8 additions & 10 deletions src/rank_llm/rerank/vllm_handler_with_openai_sdk.py
@@ -1,12 +1,9 @@
import asyncio
from typing import Dict, List, Sequence, Tuple, Union
from typing import Any, Dict, List, Tuple

from openai import AsyncOpenAI, OpenAI
from transformers import AutoTokenizer, PreTrainedTokenizerBase

Message = Dict[str, str]
PromptLike = Union[str, Message, Sequence[Message]]


class VllmHandlerWithOpenAISDK:
def __init__(
@@ -33,7 +30,7 @@ def get_tokenizer(self) -> PreTrainedTokenizerBase:

async def _one_inference(
self, messages: list[dict[str, str]], **kwargs
) -> Tuple[str, int]:
) -> Tuple[str, str, Dict[str, Any]]:
assert isinstance(messages, list)
assert isinstance(messages[0], dict)
response = None
@@ -44,20 +41,21 @@ async def _one_inference(
**kwargs,
)
text = response.choices[0].message.content
toks = len(self._tokenizer.encode(text))
return text, toks
reasoning = response.choices[0].message.reasoning
usage = response.usage.model_dump(mode="json")
return text, reasoning, usage
except Exception as e:
print(response)
print(e)
return str(e), 0
return str(e), "Reasoning tokens redcated due to error", {}

async def _all_inferences(
self, prompts: list[list[dict[str, str]]], **kwargs
) -> List[Tuple[str, int]]:
) -> List[Tuple[str, str, Dict[str, Any]]]:
tasks = [asyncio.create_task(self._one_inference(p, **kwargs)) for p in prompts]
return await asyncio.gather(*tasks)

def chat_completions(
self, prompts: list[list[dict[str, str]]], **kwargs
) -> List[Tuple[str, int]]:
) -> List[Tuple[str, str, Dict[str, Any]]]:
return asyncio.run(self._all_inferences(prompts, **kwargs))
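
For completeness, a hypothetical end-to-end call through the handler. The constructor arguments shown here are guesses, since its __init__ is truncated in this diff; the chat_completions call and the (text, reasoning, usage) triples it returns match the code above.

```python
from rank_llm.rerank.vllm_handler_with_openai_sdk import VllmHandlerWithOpenAISDK

# Hypothetical usage; base_url/model/api_key are assumed constructor parameters.
handler = VllmHandlerWithOpenAISDK(
    base_url="http://localhost:48003/v1",
    model="openai/gpt-oss-20b",
    api_key="EMPTY",
)
prompts = [[{"role": "user", "content": "Rank the passages ..."}]]
for text, reasoning, usage in handler.chat_completions(
    prompts=prompts, max_tokens=256, temperature=0
):
    print(usage.get("completion_tokens"), (reasoning or "")[:80], text[:80])
```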