
Commit fa86a36

fix - some qwen3_reranker checkpoints have no lm_head weight
fix - support qwen3 reranker after the embedding endpoint refactor
1 parent e5a336b

File tree: 2 files changed, +18 −8 lines


rtp_llm/models/downstream_modules/reranker/qwen3_reranker.py

Lines changed: 9 additions & 5 deletions
@@ -107,11 +107,15 @@ def __init__(
         super().__init__(config)
         self.token_false_id = token_false_id
         self.token_true_id = token_true_id
-
+        self.tie_word_embeddings = config.tie_word_embeddings
+        self.lm_head_weight_name = (
+            "model.embed_tokens.weight"
+            if self.tie_word_embeddings
+            else "lm_head.weight"
+        )
+
     def custom_weight_info(self) -> List[CustomAtomicWeight]:
-        w_list = [
-            "lm_head.weight",
-        ]
+        w_list = [self.lm_head_weight_name]
         weights = []
         for k in w_list:
             weights.append(
@@ -121,7 +125,7 @@ def custom_weight_info(self) -> List[CustomAtomicWeight]:
 
     def init(self, tensor_map: Dict[str, torch.Tensor]):
         data_type = to_torch_dtype(self.config_.data_type)
-        linear_weight = tensor_map["lm_head.weight"]
+        linear_weight = tensor_map[self.lm_head_weight_name]
         self.linear = torch.nn.Linear(linear_weight.shape[1], linear_weight.shape[0])
         self.linear.weight.data = linear_weight
         self.linear = self.linear.to(data_type).eval().to(self.device)
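
Background for the first fix: Qwen3 reranker checkpoints exported with tie_word_embeddings=True reuse the input embedding matrix as the output projection, so they ship no separate lm_head.weight tensor; the module now falls back to model.embed_tokens.weight in that case. Below is a minimal standalone sketch of that selection, using a hypothetical toy tensor_map (the real checkpoint tensors and shapes differ):

import torch

def pick_lm_head_weight_name(tie_word_embeddings: bool) -> str:
    # Tied-embedding checkpoints usually omit lm_head.weight, so reuse the embedding matrix.
    return "model.embed_tokens.weight" if tie_word_embeddings else "lm_head.weight"

# Hypothetical tied-embedding checkpoint: note there is no "lm_head.weight" key.
tensor_map = {"model.embed_tokens.weight": torch.randn(8, 4)}

name = pick_lm_head_weight_name(tie_word_embeddings=True)
linear_weight = tensor_map[name]  # indexing with "lm_head.weight" would raise KeyError here
linear = torch.nn.Linear(linear_weight.shape[1], linear_weight.shape[0])
linear.weight.data = linear_weight  # same wiring as Qwen3RerankerModule.init above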

rtp_llm/models/downstream_modules/utils.py

Lines changed: 9 additions & 3 deletions
@@ -2,7 +2,6 @@
 from typing import Optional
 
 from rtp_llm.config.model_config import ModelConfig
-from rtp_llm.ops import TaskType
 from rtp_llm.frontend.tokenizer_factory.tokenizers import BaseTokenizer
 from rtp_llm.models.downstream_modules import (
     ALLEmbeddingModule,
@@ -13,27 +12,32 @@
     RerankerModule,
     SparseEmbeddingModule,
 )
+from rtp_llm.models.downstream_modules.reranker.qwen3_reranker import (
+    Qwen3RerankerModule,
+)
+from rtp_llm.ops import TaskType
 
 
 def create_custom_module(
     config: ModelConfig,
     tokenizer: Optional[BaseTokenizer],
-):
+):
     # try import internal module
     try:
         from internal_source.rtp_llm.models.downstream_modules.utils import (
             create_custom_module,
         )
+
         internal_module = create_custom_module(config, tokenizer)
         if internal_module is not None:
             return internal_module
     except ImportError:
         logging.exception("internal module not found, using external module")
 
-
     task_type = config.task_type
     if task_type == TaskType.LANGUAGE_MODEL:
         return None
+    model_type = config.model_type
     assert tokenizer is not None, "tokenizer should not be None"
     if task_type == TaskType.DENSE_EMBEDDING:
         return DenseEmbeddingModule(config, tokenizer)
@@ -47,6 +51,8 @@ def create_custom_module(
         return ClassifierModule(config, tokenizer)
     elif task_type == TaskType.BGE_M3:
         return BgeM3EmbeddingModule(config, tokenizer)
+    elif model_type == "qwen_3":
+        return Qwen3RerankerModule(config, tokenizer)
     elif task_type == TaskType.RERANKER:
         return RerankerModule(config, tokenizer)
     raise Exception(f"unknown task_type: {task_type}")
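
The second file restores Qwen3 rerank support after the embedding endpoint refactor: create_custom_module now checks model_type == "qwen_3" before the generic TaskType.RERANKER branch, so a Qwen3 rerank deployment gets the specialized Qwen3RerankerModule instead of RerankerModule. A rough sketch of that branch order, using a hypothetical stub config rather than the real ModelConfig:

from types import SimpleNamespace

def pick_module_name(task_type: str, model_type: str) -> str:
    # Condensed mirror of the elif chain after this commit; only the branches
    # relevant to reranking are shown, with task/model types as plain strings.
    if task_type == "DENSE_EMBEDDING":
        return "DenseEmbeddingModule"
    elif model_type == "qwen_3":
        return "Qwen3RerankerModule"  # reached before the generic reranker branch
    elif task_type == "RERANKER":
        return "RerankerModule"
    raise Exception(f"unknown task_type: {task_type}")

cfg = SimpleNamespace(task_type="RERANKER", model_type="qwen_3")  # hypothetical stub
print(pick_module_name(cfg.task_type, cfg.model_type))  # -> Qwen3RerankerModule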
