Commit

Add mistral backend (#5)
* Fix some mypy errors

* Delete a useless line

* Change the query function to reflect changes in vLLM

* Add mistral backend

* Added docstrings and fixed the existing tests

* Added new unit tests

* Fix a mypy error

* Refactor to have BackEnd classes

* Fix mypy errors

* Update readme

* Refactor to take into account /metrics endpoint which can be implemented
gsolard authored Aug 1, 2024
1 parent 2b13830 commit 150d655
Showing 20 changed files with 770 additions and 198 deletions.
1 change: 1 addition & 0 deletions .env.example
@@ -1,5 +1,6 @@
# MODEL=
# BASE_URL=
# MODEL_NAME=
HOST="localhost"
PORT=8000

9 changes: 5 additions & 4 deletions README.md
@@ -6,7 +6,7 @@
![Build & Tests](https://github.com/France-Travail/benchmark_llm_serving/actions/workflows/build_and_tests.yaml/badge.svg)
![Wheel setup](https://github.com/France-Travail/benchmark_llm_serving/actions/workflows/wheel.yaml/badge.svg)

benchmark_llm_serving is a script aimed at benchmarking the serving API of LLMs. For now, it is focused on LLMs served via [vllm](https://github.com/vllm-project/vllm) and more specifically via [happy-vllm](https://github.com/France-Travail/happy_vllm) which is an API layer on vLLM adding new endpoints and permitting a configuration via environment variables.
benchmark_llm_serving is a script aimed at benchmarking the serving API of LLMs. For now, two backends are implemented : [mistral](https://docs.mistral.ai/api/) and [vLLM](https://github.com/vllm-project/vllm) (via [happy-vllm](https://github.com/France-Travail/happy_vllm), which is an API layer on vLLM adding new endpoints and permitting configuration via environment variables).

## Installation

@@ -45,8 +45,8 @@ After the bench suite ends, you obtain a folder containing :
- `prompt_ingestion_graph.png` containing the graph of the speed of prompt ingestion by the model. It is the time taken to produce the first token vs the length of the prompt. The speed is the slope of this line and is indicated in the title of the graph. The data used for this graph is contained in the `data` folder.
- `thresholds.csv` is a .csv containing, for each couple of input length/output length, the number of parallel requests such that : the kv cache usage is below 100% and the generation speed is above a specified threshold (by default, 20 tokens per second)
- `total_speed_generation_graph.png` is a graph containing, for each couple of input length/output length, the total generation speed vs the number of parallel requests. So, for example, if the model can answer 10 parallel requests each at a speed of 20 tokens per second, the value on the graph will be 200 tokens per second (20 x 10). The data used for this graph is contained in the `data` folder.
- A folder `kv_cache_profile` containing, for each couple of input length/output length, a graph showing the response of the LLMs to n requests launched at the same time. On the y-axis, you have the kv cache usage, the number of requests running and the number of requests waiting. On the x-axis, you have the time. The graph is obtained by sending one request, watching the response of the LLM then two requests, then three requests, ...
- A folder `speed_generation` containing, for each couple of input length/output length, a graph showing the speed generation (per request) in token per second vs the number of parallel requests. The graph also shows the time to the first token generated in milliseconds and the max kv cache usage for this number of parallel requests. The corresponding data is in the `data` folder
- If the backend is `happy_vllm` : a folder `kv_cache_profile` containing, for each couple of input length/output length, a graph showing the response of the LLM to n requests launched at the same time. On the y-axis, you have the kv cache usage, the number of requests running and the number of requests waiting. On the x-axis, you have the time. The graph is obtained by sending one request, watching the response of the LLM, then sending two requests, then three requests, and so on.
- A folder `speed_generation` containing, for each couple of input length/output length, a graph showing the generation speed (per request) in tokens per second vs the number of parallel requests. The graph also shows the time to first token in milliseconds. If the backend is `happy_vllm`, it also shows the max kv cache usage for this number of parallel requests. The corresponding data is in the `data` folder

Note that the various input lengths are "32", "1024" and "4096" to simulate small, medium and long prompts. These lengths are to be understood as approximate (and generally speaking a bit above the stated size). The various output lengths are 16, 128 and 1024. Contrary to the input lengths, these are exact : the model produces exactly this number of tokens.

@@ -70,9 +70,10 @@ Here is a list of the arguments:
- `min-duration-speed-generation` : For each individual script benchmarking the speed generation, if this min duration (in seconds) is reached and the target-queries-nb is also reached, the script will end (default `60`)
- `target-queries-nb-speed-generation` : For each individual script benchmarking the speed generation, if this target-queries-nb is reached and the min-duration is also reached, the script will end (default `100`)
- `min-number-of-valid-queries`: The minimal number of valid queries that should be present in a file to be considered for graph drawing (default `50`)
- `backend` : For now, only happy_vllm is supported.
- `backend` : Only `happy_vllm` and `mistral` are supported (see the sketch after this list).
- `completions-endpoint` : The endpoint for completions (default `/v1/completions`)
- `metrics-endpoint` : The endpoint for the metrics (default `/metrics/`)
- `info-endpoint` : The info endpoint (default `/v1/info`)
- `launch-arguments-endpoint` : The endpoint for getting the launch arguments of the API (default `/v1/launch_arguments`)
- `speed-threshold` : The generation speed above which the model is considered ok (default value `20`). It is only used when writing `thresholds.csv`
- `model-name`: The name that should be displayed on the graph (default value : `None`). If it is `None`, the name displayed will be that of the `model` argument
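
The `backend` value is resolved by `get_backend` in `src/benchmark_llm_serving/backends.py` (added in this commit). Below is a minimal sketch of how the chosen name maps to a backend object, using only names defined in that file; the `metrics_endpoint_exists` flag is presumably what gates the `/metrics`-based outputs (kv cache profile) described above.

```python
from benchmark_llm_serving.backends import get_backend

# Resolve the CLI value into a BackEnd instance.
backend = get_backend("mistral")        # -> BackEndMistral
print(backend.backend_name)             # "mistral"

# The mistral backend exposes no /metrics endpoint.
print(backend.metrics_endpoint_exists)  # False

# An unsupported name raises a ValueError listing the implemented backends.
try:
    get_backend("unknown_backend")
except ValueError as error:
    print(error)
```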
10 changes: 5 additions & 5 deletions pyproject.toml
@@ -15,10 +15,10 @@ requires-python = ">=3.10,<4.0"
dependencies = [
"aiohttp>=3.9.5,<4.0",
"prometheus_client>=0.20.0,<1.0",
"matplotlib>=3.8.4,<4.0",
"pydantic>=2.7.1,<3.0",
"pydantic-settings>=2.2.1,<3.0",
"requests>=2.32.0,<3.0",
"matplotlib>=3.9.1,<4.0",
"pydantic>=2.8.2,<3.0",
"pydantic-settings>=2.3.4,<3.0",
"requests>=2.32.3,<3.0",
"mdutils>=1.6.0,<2.0"
]
classifiers = [
@@ -44,7 +44,7 @@ include = ["benchmark_llm_serving*"]
bench-suite = "benchmark_llm_serving.bench_suite:main"

[project.optional-dependencies]
test = ["httpx>=0.23,<1.0", "pytest>=8.2.0,<9.0", "pytest-cov>=5.0.0,<6.0", "mypy>=1.7.1,<2.0", "pytest-asyncio>=0.23.6,<1.0",
test = ["httpx>=0.27,<1.0", "pytest>=8.3.2,<9.0", "pytest-cov>=5.0.0,<6.0", "mypy>=1.11.0,<2.0", "pytest-asyncio>=0.23.8,<1.0",
"aioresponses>=0.7.6,<1.0", "requests-mock>=1.12.1,<2.0"]

[tool.pytest.ini_options]
8 changes: 4 additions & 4 deletions requirements.txt
@@ -1,7 +1,7 @@
aiohttp==3.9.5
prometheus_client==0.20.0
matplotlib==3.8.4
pydantic==2.7.1
pydantic-settings==2.2.1
requests==2.32.0
matplotlib==3.9.1
pydantic==2.8.2
pydantic-settings==2.3.4
requests==2.32.3
mdutils==1.6.0
226 changes: 226 additions & 0 deletions src/benchmark_llm_serving/backends.py
@@ -0,0 +1,226 @@
import argparse

from benchmark_llm_serving.io_classes import QueryOutput, QueryInput


class BackEnd():

    TEMPERATURE = 0
    REPETITION_PENALTY = 1.2

    def __init__(self, backend_name: str, chunk_prefix: str = "data: ", last_chunk: str = "[DONE]", metrics_endpoint_exists: bool = True):
        self.backend_name = backend_name
        self.chunk_prefix = chunk_prefix
        self.last_chunk = last_chunk
        self.metrics_endpoint_exists = metrics_endpoint_exists

    def get_payload(self, query_input: QueryInput, args: argparse.Namespace) -> dict:
        """Gets the payload to give to the model
        Args:
            query_input (QueryInput) : The query input to use
            args (argparse.Namespace) : The cli args
        Returns:
            dict : The payload
        """
        raise NotImplementedError("The subclass should implement this method")

    def get_newly_generated_text(self, json_chunk: dict) -> str:
        """Gets the newly generated text
        Args:
            json_chunk (dict) : The chunk containing the generated text
        Returns:
            str : The newly generated text
        """
        raise NotImplementedError("The subclass should implement this method")

    def get_metrics_from_metrics_dict(self, metrics_dict: dict) -> dict:
        """Gets the useful metrics from the parsed output of the /metrics endpoint
        Args:
            metrics_dict (dict) : The parsed output of the /metrics endpoint
        Returns:
            dict : The useful metrics
        """
        raise NotImplementedError("The subclass should implement this method if metrics_endpoint_exists")

    def test_chunk_validity(self, chunk: str) -> bool:
        """Tests if the chunk is valid or should not be considered.
        Args:
            chunk (str) : The chunk to consider
        Returns:
            bool : Whether the chunk is valid or not
        """
        return True

    def get_completions_headers(self) -> dict:
        """Gets the headers (depending on the backend) to use for the request
        Returns:
            dict: The headers
        """
        return {}

    def remove_response_prefix(self, chunk: str) -> str:
        """Removes the prefix in the response of a model
        Args:
            chunk (str) : The chunk received
        Returns:
            str : The string without the prefix
        """
        return chunk.removeprefix(self.chunk_prefix)

    def check_end_of_stream(self, chunk: str) -> bool:
        """Checks whether this is the last chunk of the stream
        Args:
            chunk (str) : The chunk to test
        Returns:
            bool : Whether it is the last chunk of the stream
        """
        return chunk == self.last_chunk

    def add_prompt_length(self, json_chunk: dict, output: QueryOutput) -> None:
        """Adds the prompt length to the QueryOutput if the key "usage" is in the chunk
        Args:
            json_chunk (dict) : The chunk containing the prompt length
            output (QueryOutput) : The output
        """
        if "usage" in json_chunk:
            if json_chunk['usage'] is not None:
                output.prompt_length = json_chunk['usage']['prompt_tokens']


class BackendHappyVllm(BackEnd):

    def get_payload(self, query_input: QueryInput, args: argparse.Namespace) -> dict:
        """Gets the payload to give to the model
        Args:
            query_input (QueryInput) : The query input to use
            args (argparse.Namespace) : The cli args
        Returns:
            dict : The payload
        """
        return {"prompt": query_input.prompt,
                "model": args.model,
                "max_tokens": args.output_length,
                "min_tokens": args.output_length,
                "temperature": self.TEMPERATURE,
                "repetition_penalty": self.REPETITION_PENALTY,
                "stream": True,
                "stream_options": {"include_usage": True}
                }

    def get_newly_generated_text(self, json_chunk: dict) -> str:
        """Gets the newly generated text
        Args:
            json_chunk (dict) : The chunk containing the generated text
        Returns:
            str : The newly generated text
        """
        if len(json_chunk['choices']):
            data = json_chunk['choices'][0]['text']
            return data
        else:
            return ""

    def get_metrics_from_metrics_dict(self, metrics_dict: dict) -> dict:
        """Gets the useful metrics from the parsed output of the /metrics endpoint
        Args:
            metrics_dict (dict) : The parsed output of the /metrics endpoint
        Returns:
            dict : The useful metrics
        """
        metrics = {}
        metrics['num_requests_running'] = metrics_dict['vllm:num_requests_running'][0]['value']
        metrics['num_requests_waiting'] = metrics_dict['vllm:num_requests_waiting'][0]['value']
        metrics['gpu_cache_usage_perc'] = metrics_dict['vllm:gpu_cache_usage_perc'][0]['value']
        return metrics


class BackEndMistral(BackEnd):

    def get_payload(self, query_input: QueryInput, args: argparse.Namespace) -> dict:
        """Gets the payload to give to the model
        Args:
            query_input (QueryInput) : The query input to use
            args (argparse.Namespace) : The cli args
        Returns:
            dict : The payload
        """
        return {"messages": [{"role": "user", "content": query_input.prompt}],
                "model": args.model,
                "max_tokens": args.output_length,
                "min_tokens": args.output_length,
                "temperature": self.TEMPERATURE,
                "stream": True
                }

    def test_chunk_validity(self, chunk: str) -> bool:
        """Tests if the chunk is valid or should not be considered.
        Args:
            chunk (str) : The chunk to consider
        Returns:
            bool : Whether the chunk is valid or not
        """
        if chunk[:4] == "tok-":
            return False
        else:
            return True

    def get_completions_headers(self) -> dict:
        """Gets the headers (depending on the backend) to use for the request
        Returns:
            dict: The headers
        """
        return {"Accept": "application/json",
                "Content-Type": "application/json"}

    def get_newly_generated_text(self, json_chunk: dict) -> str:
        """Gets the newly generated text
        Args:
            json_chunk (dict) : The chunk containing the generated text
        Returns:
            str : The newly generated text
        """
        if len(json_chunk['choices']):
            data = json_chunk['choices'][0]['delta']["content"]
            return data
        else:
            return ""


def get_backend(backend_name: str) -> BackEnd:
    """Returns the BackEnd instance corresponding to the given backend name
    Args:
        backend_name (str) : The name of the backend ("happy_vllm" or "mistral")
    Returns:
        BackEnd : The corresponding backend
    """
    implemented_backends = ["mistral", "happy_vllm"]
    if backend_name not in implemented_backends:
        raise ValueError(f"The specified backend {backend_name} is not implemented. Please use one of the following : {implemented_backends}")
    if backend_name == "happy_vllm":
        return BackendHappyVllm(backend_name, chunk_prefix="data: ", last_chunk="[DONE]", metrics_endpoint_exists=True)
    if backend_name == "mistral":
        return BackEndMistral(backend_name, chunk_prefix="data: ", last_chunk="[DONE]", metrics_endpoint_exists=False)
    return BackEnd("not_implemented")
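
A minimal usage sketch (not part of the commit) of how the helpers above chain together when reading a streamed completion; the raw chunk and the metrics dictionary are illustrative assumptions based on the parsing code in this file.

```python
import json

from benchmark_llm_serving.backends import get_backend

backend = get_backend("happy_vllm")
generated_text = ""

# One server-sent-events line as the completions endpoint might stream it
# (shape assumed from get_newly_generated_text and add_prompt_length above).
raw_chunk = 'data: {"choices": [{"text": " Hello"}], "usage": null}'

chunk = backend.remove_response_prefix(raw_chunk)  # strips the "data: " prefix
if backend.test_chunk_validity(chunk) and not backend.check_end_of_stream(chunk):
    generated_text += backend.get_newly_generated_text(json.loads(chunk))

# Only backends with metrics_endpoint_exists=True parse /metrics; the dictionary
# below is an assumed example of the parsed Prometheus output.
if backend.metrics_endpoint_exists:
    metrics_dict = {"vllm:num_requests_running": [{"value": 2.0}],
                    "vllm:num_requests_waiting": [{"value": 0.0}],
                    "vllm:gpu_cache_usage_perc": [{"value": 0.15}]}
    print(backend.get_metrics_from_metrics_dict(metrics_dict))
```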