Support Alibaba-NLP/gte-Qwen2-7B-instruct embedding Model (#1186)
Co-authored-by: Ying Sheng <[email protected]>
zhaochenyang20 and Ying1123 authored Aug 25, 2024
1 parent 66e7dca commit 30b4f77
Showing 15 changed files with 167 additions and 55 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/accuracy-test.yml
@@ -43,4 +43,4 @@ jobs:
run: |
cd test/srt
python3 test_eval_accuracy_large.py
timeout-minutes: 10
timeout-minutes: 20
2 changes: 1 addition & 1 deletion .github/workflows/unit-test.yml
@@ -41,7 +41,7 @@ jobs:
run: |
cd test/srt
python3 run_suite.py --suite minimal
timeout-minutes: 18
timeout-minutes: 20

- name: Test Frontend Language
run: |
15 changes: 15 additions & 0 deletions README.md
@@ -187,6 +187,13 @@ response = client.chat.completions.create(
max_tokens=64,
)
print(response)

# Text embedding
response = client.embeddings.create(
model="default",
input="How are you today",
)
print(response)
```

It supports streaming, vision, and most features of the Chat/Completions/Models/Batch endpoints specified by the [OpenAI API Reference](https://platform.openai.com/docs/api-reference/).
@@ -223,6 +230,8 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct

### Supported Models

**Generative Models**

- Llama / Llama 2 / Llama 3 / Llama 3.1
- Mistral / Mixtral / Mistral NeMo
- Gemma / Gemma 2
@@ -243,6 +252,12 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
- ChatGLM
- InternLM 2

**Embedding Models**

- e5-mistral
- gte-Qwen2
- `python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct --is-embedding`
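
For illustration, once a server is launched with `--is-embedding` as above, embeddings can be requested through the OpenAI-compatible endpoint, mirroring the `client.embeddings.create` snippet earlier in this README. This is a minimal sketch; the base URL and port are assumptions and should match your launch flags.

```python
import openai

# Point the client at the local SGLang server (adjust host/port to your --port setting).
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

# Request an embedding from the model served with --is-embedding.
response = client.embeddings.create(
    model="default",
    input="How are you today",
)
print(len(response.data[0].embedding))  # dimensionality of the returned embedding vector
```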

Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/en/model_support.md).

#### Use Models From ModelScope
5 changes: 4 additions & 1 deletion python/sglang/srt/managers/tokenizer_manager.py
@@ -94,7 +94,10 @@ def __init__(
trust_remote_code=server_args.trust_remote_code,
model_overide_args=model_overide_args,
)
self.is_generation = is_generation_model(self.hf_config.architectures)

self.is_generation = is_generation_model(
self.hf_config.architectures, self.server_args.is_embedding
)

if server_args.context_length is not None:
self.context_len = server_args.context_length
1 change: 1 addition & 0 deletions python/sglang/srt/managers/tp_worker.py
@@ -94,6 +94,7 @@ def __init__(
context_length=server_args.context_length,
model_overide_args=model_overide_args,
)

self.model_runner = ModelRunner(
model_config=self.model_config,
mem_fraction_static=server_args.mem_fraction_static,
17 changes: 13 additions & 4 deletions python/sglang/srt/model_executor/model_runner.py
@@ -204,7 +204,7 @@ def load_model(self):
else None
)
self.is_generation = is_generation_model(
self.model_config.hf_config.architectures
self.model_config.hf_config.architectures, self.server_args.is_embedding
)

logger.info(
@@ -522,9 +522,18 @@ def forward_extend(self, batch: ScheduleBatch):
batch,
forward_mode=ForwardMode.EXTEND,
)
return self.model.forward(
batch.input_ids, input_metadata.positions, input_metadata
)
if self.is_generation:
return self.model.forward(
batch.input_ids, input_metadata.positions, input_metadata
)
else:
# Only embedding models have the get_embedding parameter
return self.model.forward(
batch.input_ids,
input_metadata.positions,
input_metadata,
get_embedding=True,
)

@torch.inference_mode()
def forward_extend_multi_modal(self, batch: ScheduleBatch):
4 changes: 4 additions & 0 deletions python/sglang/srt/models/llama_embedding.py
@@ -29,7 +29,11 @@ def forward(
positions: torch.Tensor,
input_metadata: InputMetadata,
input_embeds: torch.Tensor = None,
get_embedding: bool = True,
) -> EmbeddingPoolerOutput:
assert (
get_embedding
), "LlamaEmbeddingModel / MistralModel is only used for embedding"
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.pooler(hidden_states, input_metadata)

12 changes: 9 additions & 3 deletions python/sglang/srt/models/qwen2.py
@@ -38,6 +38,7 @@
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -275,6 +276,7 @@ def __init__(
self.model = Qwen2Model(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

@torch.no_grad()
def forward(
@@ -283,11 +285,15 @@ def forward(
positions: torch.Tensor,
input_metadata: InputMetadata,
input_embeds: torch.Tensor = None,
get_embedding: bool = False,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
if not get_embedding:
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
else:
return self.pooler(hidden_states, input_metadata)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
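
For context on the `Pooler(pooling_type=PoolingType.LAST, normalize=True)` added above: last-token pooling keeps the hidden state of each sequence's final token and, with `normalize=True`, L2-normalizes it. The sketch below illustrates that idea for a packed `[total_tokens, hidden_size]` batch; it is an assumption-laden stand-in, not SGLang's actual Pooler implementation.

```python
import torch
import torch.nn.functional as F

def last_token_pool(hidden_states: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    # hidden_states: [total_tokens, hidden_size], sequences packed back to back
    # seq_lens:      [num_seqs], number of tokens in each sequence
    last_indices = torch.cumsum(seq_lens, dim=0) - 1  # index of each sequence's last token
    embeddings = hidden_states[last_indices]
    return F.normalize(embeddings, p=2, dim=-1)  # normalize=True -> unit-length embeddings
```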
3 changes: 3 additions & 0 deletions python/sglang/srt/server.py
@@ -333,11 +333,13 @@ def launch_server(
start_process = start_controller_process_single
else:
start_process = start_controller_process_multi

proc_controller = mp.Process(
target=start_process,
args=(server_args, port_args, pipe_controller_writer, model_overide_args),
)
proc_controller.start()

proc_detoken = mp.Process(
target=start_detokenizer_process,
args=(
@@ -515,6 +517,7 @@ def __init__(

self.pid = None
pipe_reader, pipe_writer = mp.Pipe(duplex=False)

proc = mp.Process(
target=launch_server,
args=(self.server_args, model_overide_args, pipe_writer),
11 changes: 11 additions & 0 deletions python/sglang/srt/server_args.py
@@ -38,6 +38,7 @@ class ServerArgs:
quantization: Optional[str] = None
served_model_name: Optional[str] = None
chat_template: Optional[str] = None
is_embedding: bool = False

# Port
host: str = "127.0.0.1"
@@ -200,6 +201,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
)
parser.add_argument(
"--is-embedding",
action="store_true",
help="Whether to use a CausalLM as an embedding model.",
)
parser.add_argument(
"--context-length",
type=int,
@@ -458,6 +464,11 @@ def check_server_args(self):
assert not (
self.dp_size > 1 and self.node_rank is not None
), "multi-node data parallel is not supported"
if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
logger.info(
"Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
)
self.trust_remote_code = False
if "gemma-2" in self.model_path.lower():
logger.info("When using sliding window in gemma-2, turn on flashinfer.")
self.disable_flashinfer = False
9 changes: 7 additions & 2 deletions python/sglang/srt/utils.py
@@ -224,13 +224,18 @@ def is_multimodal_model(model):
raise ValueError("unrecognized type")


def is_generation_model(model_architectures):
def is_generation_model(model_architectures, is_embedding: bool = False):
# We have two ways to determine whether a model is a generative model.
# 1. Check the model architecture.
# 2. Check the `is_embedding` server arg.

if (
"LlamaEmbeddingModel" in model_architectures
or "MistralModel" in model_architectures
):
return False
return True
else:
return not is_embedding


def decode_video_base64(video_base64):
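
A few hypothetical calls illustrating how the new `is_embedding` flag interacts with the architecture check in `is_generation_model`; the `Qwen2ForCausalLM` string stands in for any causal-LM architecture:

```python
from sglang.srt.utils import is_generation_model

# Dedicated embedding architectures are never treated as generative.
assert is_generation_model(["LlamaEmbeddingModel"]) is False
assert is_generation_model(["MistralModel"]) is False

# A causal LM is generative by default...
assert is_generation_model(["Qwen2ForCausalLM"]) is True
# ...unless the server was launched with --is-embedding (e.g. for gte-Qwen2).
assert is_generation_model(["Qwen2ForCausalLM"], is_embedding=True) is False
```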
32 changes: 16 additions & 16 deletions python/sglang/test/runners.py
@@ -14,7 +14,7 @@
"""

import json
import multiprocessing
import multiprocessing as mp
import os
from dataclasses import dataclass
from typing import List, Union
@@ -63,37 +63,35 @@ def __init__(
self,
model_path,
torch_dtype,
is_generation_model,
is_generation,
):
self.in_queue = multiprocessing.Queue()
self.out_queue = multiprocessing.Queue()
self.is_generation = is_generation

self.model_proc = multiprocessing.Process(
self.in_queue = mp.Queue()
self.out_queue = mp.Queue()

self.model_proc = mp.Process(
target=self.start_model_process,
args=(
self.in_queue,
self.out_queue,
model_path,
torch_dtype,
is_generation_model,
),
)
self.model_proc.start()

def start_model_process(
self, in_queue, out_queue, model_path, torch_dtype, is_generation_model
):
def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
self.tokenizer = AutoTokenizer.from_pretrained(
model_path,
torch_dtype=torch_dtype,
)

self.is_generation_model = is_generation_model

if self.is_generation_model:
if self.is_generation:
self.model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch_dtype,
trust_remote_code=False,
low_cpu_mem_usage=True,
).cuda()
else:
@@ -107,7 +105,7 @@ def start_model_process(
while True:
prompts, max_new_tokens = in_queue.get()
if prompts is not None:
if self.is_generation_model:
if self.is_generation:
output_strs = []
prefill_logprobs = []
for p in prompts:
@@ -171,25 +169,27 @@ def __init__(
self,
model_path,
torch_dtype,
is_generation_model,
is_generation,
tp_size=1,
port=5157,
):
self.is_generation_model = is_generation_model
self.is_generation = is_generation
self.runtime = Runtime(
model_path=model_path,
tp_size=tp_size,
dtype=get_dtype_str(torch_dtype),
port=port,
mem_fraction_static=0.7,
trust_remote_code=False,
is_embedding=not self.is_generation,
)

def forward(
self,
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
max_new_tokens=8,
):
if self.is_generation_model:
if self.is_generation:
# the return value contains logprobs from prefill
output_strs = []
top_input_logprobs = []
28 changes: 13 additions & 15 deletions test/srt/models/test_embedding_models.py
@@ -20,7 +20,10 @@
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
from sglang.test.test_utils import get_similarities

MODELS = [("intfloat/e5-mistral-7b-instruct", 1, 0.2)]
MODELS = [
("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, 1e-5),
("intfloat/e5-mistral-7b-instruct", 1, 1e-5),
]
TORCH_DTYPES = [torch.float16]


@@ -32,22 +35,20 @@ def assert_close_prefill_logits(
model_path,
tp_size,
torch_dtype,
long_context_tolerance,
prefill_tolerance,
) -> None:
with HFRunner(
model_path, torch_dtype=torch_dtype, is_generation_model=False
model_path, torch_dtype=torch_dtype, is_generation=False
) as hf_runner:
hf_outputs = hf_runner.forward(prompts)

with SRTRunner(
model_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
is_generation_model=False,
is_generation=False,
) as srt_runner:
srt_outputs = srt_runner.forward(
prompts,
)
srt_outputs = srt_runner.forward(prompts)

for i in range(len(prompts)):
hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
@@ -57,18 +58,15 @@ def forward(
print("similarity diff", abs(similarity - 1))

if len(prompts[i]) <= 1000:
tolerance = 1e-5
else:
tolerance = long_context_tolerance
assert torch.all(
abs(similarity - 1) < tolerance
), "embeddings are not all close"
assert torch.all(
abs(similarity - 1) < prefill_tolerance
), "embeddings are not all close"

def test_prefill_logits(self):
for model, tp_size, long_context_tolerance in MODELS:
for model, tp_size, prefill_tolerance in MODELS:
for torch_dtype in TORCH_DTYPES:
self.assert_close_prefill_logits(
DEFAULT_PROMPTS, model, tp_size, torch_dtype, long_context_tolerance
DEFAULT_PROMPTS, model, tp_size, torch_dtype, prefill_tolerance
)


