Merge pull request #144 from tackhwa/lmdeploy_backend
[Feature] support lmdeploy backend
LarFii authored Oct 28, 2024
2 parents 4d078e9 + 81d5b90 commit b22049c
Showing 3 changed files with 208 additions and 1 deletion.
75 changes: 75 additions & 0 deletions examples/lightrag_lmdeploy_demo.py
@@ -0,0 +1,75 @@
import os

from lightrag import LightRAG, QueryParam
from lightrag.llm import lmdeploy_model_if_cache, hf_embedding
from lightrag.utils import EmbeddingFunc
from transformers import AutoModel, AutoTokenizer

WORKING_DIR = "./dickens"

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)


async def lmdeploy_model_complete(
    prompt=None, system_prompt=None, history_messages=[], **kwargs
) -> str:
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    return await lmdeploy_model_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        ## Specify chat_template if your local model path does not follow the
        ## original HF naming, or if model_name is a PyTorch model on huggingface.co.
        ## See https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/model.py
        ## for the list of chat templates available in lmdeploy.
        chat_template="llama3",
        # model_format="awq",  # if you are using an AWQ quantization model
        # quant_policy=8,  # online KV cache quantization: 4 = kv int4, 8 = kv int8
        **kwargs,
    )


rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=lmdeploy_model_complete,
    llm_model_name="meta-llama/Llama-3.1-8B-Instruct",  # for a local model, use its full local path
    embedding_func=EmbeddingFunc(
        embedding_dim=384,
        max_token_size=5000,
        func=lambda texts: hf_embedding(
            texts,
            tokenizer=AutoTokenizer.from_pretrained(
                "sentence-transformers/all-MiniLM-L6-v2"
            ),
            embed_model=AutoModel.from_pretrained(
                "sentence-transformers/all-MiniLM-L6-v2"
            ),
        ),
    ),
)


with open("./book.txt", "r", encoding="utf-8") as f:
    rag.insert(f.read())

# Perform naive search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)

# Perform local search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)

# Perform global search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)

# Perform hybrid search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)
133 changes: 132 additions & 1 deletion lightrag/llm.py
@@ -286,7 +286,9 @@ async def hf_model_if_cache(
    output = hf_model.generate(
        **input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True
    )
-   response_text = hf_tokenizer.decode(output[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
+   response_text = hf_tokenizer.decode(
+       output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
+   )
    if hashing_kv is not None:
        await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
    return response_text
@@ -322,6 +324,135 @@ async def ollama_model_if_cache(
    return result


@lru_cache(maxsize=1)
def initialize_lmdeploy_pipeline(
    model,
    tp=1,
    chat_template=None,
    log_level="WARNING",
    model_format="hf",
    quant_policy=0,
):
    from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig

    lmdeploy_pipe = pipeline(
        model_path=model,
        backend_config=TurbomindEngineConfig(
            tp=tp, model_format=model_format, quant_policy=quant_policy
        ),
        chat_template_config=ChatTemplateConfig(model_name=chat_template)
        if chat_template
        else None,
        log_level=log_level,  # honor the caller-supplied log level instead of hardcoding "WARNING"
    )
    return lmdeploy_pipe


async def lmdeploy_model_if_cache(
    model,
    prompt,
    system_prompt=None,
    history_messages=[],
    chat_template=None,
    model_format="hf",
    quant_policy=0,
    **kwargs,
) -> str:
    """
    Args:
        model (str): The path to the model.
            It could be one of the following options:
                - i) A local directory path of a turbomind model, converted by
                    the `lmdeploy convert` command or downloaded from ii) and iii).
                - ii) The model_id of an lmdeploy-quantized model hosted
                    inside a model repo on huggingface.co, such as
                    "InternLM/internlm-chat-20b-4bit",
                    "lmdeploy/llama2-chat-70b-4bit", etc.
                - iii) The model_id of a model hosted inside a model repo
                    on huggingface.co, such as "internlm/internlm-chat-7b",
                    "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat",
                    and so on.
        chat_template (str): needed when the model is a PyTorch model hosted on
            huggingface.co, such as "internlm-chat-7b", "Qwen-7B-Chat",
            "Baichuan2-7B-Chat", and so on, or when the model name of a local
            path does not match the original model name on HF.
        tp (int): tensor parallel degree.
        prompt (Union[str, List[str]]): input texts to be completed.
        do_preprocess (bool): whether to pre-process the messages. Defaults to
            True, which means the chat_template will be applied.
        skip_special_tokens (bool): whether or not to remove special tokens
            during decoding. Defaults to True.
        do_sample (bool): whether or not to use sampling; greedy decoding is
            used otherwise. Defaults to False, which means greedy decoding is applied.
    """
    try:
        import lmdeploy
        from lmdeploy import version_info, GenerationConfig
    except Exception:
        raise ImportError(
            "Please install lmdeploy before initializing the lmdeploy backend."
        )

    kwargs.pop("response_format", None)
    max_new_tokens = kwargs.pop("max_tokens", 512)
    tp = kwargs.pop("tp", 1)
    skip_special_tokens = kwargs.pop("skip_special_tokens", True)
    do_preprocess = kwargs.pop("do_preprocess", True)
    do_sample = kwargs.pop("do_sample", False)
    gen_params = kwargs

    version = version_info
    if do_sample is not None and version < (0, 6, 0):
        raise RuntimeError(
            "`do_sample` parameter is not supported by lmdeploy until "
            f"v0.6.0, but currently using lmdeploy {lmdeploy.__version__}"
        )
    else:
        do_sample = True
        gen_params.update(do_sample=do_sample)

    lmdeploy_pipe = initialize_lmdeploy_pipeline(
        model=model,
        tp=tp,
        chat_template=chat_template,
        model_format=model_format,
        quant_policy=quant_policy,
        log_level="WARNING",
    )

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    gen_config = GenerationConfig(
        skip_special_tokens=skip_special_tokens,
        max_new_tokens=max_new_tokens,
        **gen_params,
    )

    response = ""
    async for res in lmdeploy_pipe.generate(
        messages,
        gen_config=gen_config,
        do_preprocess=do_preprocess,
        stream_response=False,
        session_id=1,
    ):
        response += res.response

    if hashing_kv is not None:
        await hashing_kv.upsert({args_hash: {"return": response, "model": model}})
    return response


async def gpt_4o_complete(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
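For reference, a minimal sketch of calling the new helper directly, outside of the LightRAG wrapper (the model id and prompt below are illustrative, and lmdeploy is assumed to be installed):

import asyncio

from lightrag.llm import lmdeploy_model_if_cache


async def main():
    answer = await lmdeploy_model_if_cache(
        "meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model id or local path
        "Briefly explain what retrieval-augmented generation is.",
        system_prompt="You are a concise technical assistant.",
        history_messages=[],
        chat_template="llama3",  # needed if the local path does not follow HF naming
        tp=1,  # tensor parallel degree, popped from kwargs
        max_tokens=256,  # mapped to GenerationConfig.max_new_tokens
    )
    print(answer)


asyncio.run(main())

Without a hashing_kv argument the caching branch is skipped, so this runs the pipeline on every call.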
1 change: 1 addition & 0 deletions requirements.txt
@@ -13,3 +13,4 @@ tiktoken
torch
transformers
xxhash
# lmdeploy[all]
