Improve endpoint tests and fix a bug in the endpoint model
sadra-barikbin committed Sep 6, 2024
1 parent b291871 commit 8c0018e
Showing 6 changed files with 181 additions and 36 deletions.
7 changes: 4 additions & 3 deletions src/lighteval/models/endpoint_model.py
@@ -205,11 +205,10 @@ def _process_generate_response(self, response: EndpointOutput, request: GreedyUn
def _process_logprob_response(
self, response: TextGenerationOutput, request: LoglikelihoodRequest | LoglikelihoodRollingRequest
) -> LoglikelihoodResponse:
len_choice = len(request.tokenized_continuation)
logits = sum([t.logprob for t in response.details.prefill[1:][-len_choice:]])
logits = sum([t.logprob for t in response.details.prefill[len(request.tokenized_context):]])
return LoglikelihoodResponse(
result=(logits, True) if isinstance(request, LoglikelihoodRequest) else logits,
input_tokens=[t.id for t in response.details.prefill[:-len_choice]],
input_tokens=[t.id for t in response.details.prefill[len(request.tokenized_context):]],
generated_tokens=-1,
truncated_tokens_count=-1,
padded_tokens_count=-1,
@@ -255,6 +254,7 @@ def _prepare_request(self, request: Request) -> EndpointInput:
context = request.context + [ChatCompletionInputMessage(role="assistant", content=request.choice)]
if not isinstance(context, str):
context = self.tokenizer.apply_chat_template(context, tokenize=False)
context = context.split(self.tokenizer.bos_token, 1)[-1]

if isinstance(context, str):
prepared_request = TextGenerationInput(
@@ -290,6 +290,7 @@ def greedy_until(
override_bs: Optional[int] = None,
) -> List[GenerativeResponse]:
for request in requests:
# Why don't we set the context to an empty list here?
request.tokenized_context = self.tok_encode(request.context)
request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token]

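As context for the change above: with `details=True` and `decoder_input_details=True`, a TGI endpoint returns a log-probability for every prompt (prefill) token except the leading BOS token, whose logprob is `None`. The continuation log-likelihood is therefore the sum of the prefill log-probabilities that come after the tokenized context, which is exactly what the corrected slice computes. A minimal standalone sketch (illustrative only, not lighteval code; the endpoint address is a placeholder and the context/choice split is assumed to be token-aligned):

from huggingface_hub import InferenceClient
from transformers import AutoTokenizer

client = InferenceClient(base_url="http://localhost:8080")  # placeholder TGI address
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")

context, choice = "Tell me:\n\nHow are you?", "Fine, thanks!"
out = client.text_generation(
    context + choice,
    max_new_tokens=1,
    details=True,
    decoder_input_details=True,  # also return logprobs for the prompt tokens
)

# Slice the prefill by the length of the tokenized context (the new behaviour),
# rather than by the length of the tokenized continuation (the old behaviour).
context_length = len(tokenizer(context)["input_ids"])
continuation_tokens = out.details.prefill[context_length:]
loglikelihood = sum(t.logprob for t in continuation_tokens)
continuation_token_ids = [t.id for t in continuation_tokens]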
2 changes: 1 addition & 1 deletion src/lighteval/models/tgi_model.py
@@ -53,7 +53,7 @@ def __init__(self, address, auth_token=None, model_id=None) -> None:
self.model_info = ModelInfo(
model_name=model_id or self.name,
model_sha=info["model_sha"],
model_dtype=info["model_dtype"] or "default",
model_dtype=info["model_dtype"] if "model_dtype" in info else "default",
model_size=-1,
)
self._tokenizer = AutoTokenizer.from_pretrained(self.model_info.model_name)
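Side note, not part of the diff: the new conditional guards against the `/info` payload omitting the key entirely, which the old `or "default"` could not do (a missing key raised `KeyError`). An equivalent, slightly more idiomatic spelling would be:

# Equivalent to the committed conditional: fall back only when the key is absent.
model_dtype = info.get("model_dtype", "default")

Combining the two, `info.get("model_dtype") or "default"`, would additionally cover a present-but-`None` value, which is the case the old expression handled.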
12 changes: 12 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,12 @@
from typing import Iterator
import pytest

from lighteval.models.model_config import BaseModelConfig
from lighteval.models.abstract_model import EnvConfig
from lighteval.models.base_model import BaseModel


@pytest.fixture(scope="module")
def base_model() -> Iterator[BaseModel]:
config = BaseModelConfig("hf-internal-testing/tiny-random-LlamaForCausalLM")
return BaseModel(config, EnvConfig())
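Because the fixture now lives in `tests/conftest.py`, any test module can request it by argument name with no import, and the module scope means the tiny random Llama is instantiated once per module. Note it also builds `EnvConfig()` with defaults rather than the explicit `CACHE_PATH`/`TOKEN` pair removed from `tests/test_base_model.py` below. A hypothetical consumer (not part of this commit) would look like:

# tests/test_some_feature.py -- hypothetical example module, not in this commit
from lighteval.models.base_model import BaseModel


def test_uses_shared_fixture(base_model: BaseModel):
    # pytest resolves `base_model` from tests/conftest.py automatically.
    assert isinstance(base_model, BaseModel)
    base_model.use_chat_template = True  # same attribute the integration tests toggle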
12 changes: 0 additions & 12 deletions tests/test_base_model.py
@@ -41,16 +41,6 @@
)


TOKEN = os.environ.get("HF_TOKEN")
CACHE_PATH = os.getenv("HF_HOME", ".")


@pytest.fixture(scope="module")
def base_model() -> Iterator[BaseModel]:
config = BaseModelConfig("hf-internal-testing/tiny-random-LlamaForCausalLM")
return BaseModel(config, EnvConfig(CACHE_PATH, TOKEN))


RequestDict: TypeAlias = dict[RequestType, list[Request]]


@@ -122,11 +112,9 @@ def task(self) -> LightevalTask:
def test_integration(self, task: LightevalTask, base_model: BaseModel, num_fewshot: int, use_chat_template: bool):
base_model.use_chat_template = use_chat_template

env_config = EnvConfig(token=TOKEN, cache_dir=CACHE_PATH)
evaluation_tracker = EvaluationTracker()
pipeline_params = PipelineParameters(
launcher_type=ParallelismManager.NONE,
env_config=env_config,
use_chat_template=use_chat_template,
override_batch_size=1,
)
81 changes: 61 additions & 20 deletions tests/test_endpoint_model.py
@@ -27,27 +27,27 @@
from typing import Iterator, TypeAlias
from unittest.mock import patch

import torch
import docker
import docker.errors
import pytest
import requests
from huggingface_hub import ChatCompletionInputMessage
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, PreTrainedTokenizerFast

from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.metrics.metrics import Metrics
from lighteval.models.tgi_model import ModelClient as TGIModel
from lighteval.models.base_model import BaseModel
from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig, create_requests_from_tasks
from lighteval.tasks.requests import (
LoglikelihoodRequest,
LoglikelihoodRollingRequest,
Doc,
Request,
RequestType,
)
from lighteval.utils.utils import EnvConfig


TOKEN = os.environ.get("HF_TOKEN")
CACHE_PATH = os.getenv("HF_HOME", ".")


@pytest.fixture(scope="module")
@@ -83,12 +83,19 @@ def tgi_model() -> Iterator[TGIModel]:
raise RuntimeError("Couldn't setup TGI server.")
model = TGIModel(address)
yield model
container.stop()
container.wait()
container.remove()
# container.stop()
# container.wait()
# container.remove()
model.cleanup()


@pytest.fixture(scope="module")
def reference_model_tokenizer() -> tuple[LlamaForCausalLM, PreTrainedTokenizerFast]:
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
return model, tokenizer


RequestDict: TypeAlias = dict[RequestType, list[Request]]


@@ -150,28 +157,62 @@ def zero_shot_request_dict(self, task: LightevalTask) -> RequestDict:
result[req_type].extend(doc_result[req_type])
return result

def test_greedy_until(self, zero_shot_request_dict: RequestDict, tgi_model: TGIModel):
returns = tgi_model.greedy_until(zero_shot_request_dict[RequestType.GREEDY_UNTIL])
def test_greedy_until(self, reference_model_tokenizer: tuple[LlamaForCausalLM, PreTrainedTokenizerFast], zero_shot_request_dict: RequestDict, tgi_model: TGIModel):
requests = zero_shot_request_dict[RequestType.GREEDY_UNTIL]
returns = tgi_model.greedy_until(requests)
model, tokenizer = reference_model_tokenizer
assert len(returns) == 2
assert all(r.result is not None for r in returns)

def test_loglikelihood(self, zero_shot_request_dict: RequestDict, tgi_model: TGIModel):
returns = tgi_model.loglikelihood(zero_shot_request_dict[RequestType.LOGLIKELIHOOD])
for req, res in zip(requests, returns):
is_chat = not isinstance(req.context, str)
tokenized_context = tokenizer.apply_chat_template(req.context, return_tensors='pt') if is_chat else tokenizer(req.context, return_tensors='pt')['input_ids']
ref_context_continuation = model.generate(tokenized_context, tokenizer=tokenizer, stop_strings=req.stop_sequence, max_new_tokens=req.generation_size)[0].tolist()
continuation = tokenizer.decode(ref_context_continuation)[len(tokenizer.decode(tokenized_context[0].tolist())):]
assert continuation == res.result

def test_loglikelihood(self, reference_model_tokenizer: tuple[LlamaForCausalLM, PreTrainedTokenizerFast], zero_shot_request_dict: RequestDict, tgi_model: TGIModel):
requests: list[LoglikelihoodRequest] = zero_shot_request_dict[RequestType.LOGLIKELIHOOD]
returns = tgi_model.loglikelihood(requests)
model, tokenizer = reference_model_tokenizer
assert len(returns) == 4
assert all(r.result is not None for r in returns)

for req, res in zip(requests, returns):
is_chat = not isinstance(req.context, str)
sequence = req.context + [ChatCompletionInputMessage(role='assistant',content=req.choice)] if is_chat else req.context+req.choice
tokenized_sequence = tokenizer.apply_chat_template(sequence, return_tensors='pt') if is_chat else tokenizer(sequence, return_tensors='pt')['input_ids']

output = model.generate(tokenized_sequence, max_new_tokens=1, return_dict_in_generate=True, output_hidden_states=True)
with torch.no_grad():
logprobs = torch.log_softmax(model.lm_head(output.hidden_states[0][-1]),dim=-1)
logprobs = logprobs.gather(dim=-1, index=tokenized_sequence[:,1:].unsqueeze(-1))
context_length = len(tokenizer.apply_chat_template(req.context)) if is_chat else len(tokenizer.encode(req.context))
continuation_logprob = logprobs[:, context_length-1:].sum()

tokenized_choice = tokenized_sequence[:, context_length:]
assert tokenized_choice[0].tolist() == res.input_tokens
assert torch.allclose(torch.tensor(res.result[0]), continuation_logprob)

def test_loglikelihood_rolling(self, reference_model_tokenizer: tuple[LlamaForCausalLM, PreTrainedTokenizerFast], zero_shot_request_dict: RequestDict, tgi_model: TGIModel):
model, tokenizer = reference_model_tokenizer
requests: list[LoglikelihoodRollingRequest] = zero_shot_request_dict[RequestType.LOGLIKELIHOOD_ROLLING]
returns = tgi_model.loglikelihood_rolling(requests)
assert len(returns) == 2
assert all(r.result is not None for r in returns)
for req, res in zip(requests, returns):
is_chat = not isinstance(req.context, str)
tokenized_context = tokenizer.apply_chat_template(req.context, return_tensors='pt') if is_chat else tokenizer(req.context, return_tensors='pt')['input_ids']
output = model.generate(tokenized_context, max_new_tokens=1, return_dict_in_generate=True, output_hidden_states=True)
with torch.no_grad():
logprobs = torch.log_softmax(model.lm_head(output.hidden_states[0][-1]),dim=-1)
logprob = logprobs.gather(dim=-1, index=tokenized_context[:,1:].unsqueeze(-1)).sum()

assert tokenized_context[0, 1:].tolist() == res.input_tokens
assert torch.allclose(torch.tensor(res.result), logprob)

@pytest.mark.parametrize("num_fewshot", [0, 2])
@pytest.mark.parametrize("use_chat_template", [False, True])
def test_integration(self, task: LightevalTask, tgi_model: TGIModel, num_fewshot: int, use_chat_template: bool):
env_config = EnvConfig(token=TOKEN, cache_dir=CACHE_PATH)
def test_integration(self, task: LightevalTask, base_model: BaseModel, tgi_model: TGIModel, num_fewshot: int, use_chat_template: bool):
#TODO
evaluation_tracker = EvaluationTracker()
pipeline_params = PipelineParameters(
launcher_type=ParallelismManager.NONE,
env_config=env_config,
use_chat_template=use_chat_template,
)

103 changes: 103 additions & 0 deletions tests/test_test.py
@@ -0,0 +1,103 @@
import time
import random
import asyncio
from typing import Iterator

import pytest
import docker
import requests
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM
from huggingface_hub import (
InferenceClient,
AsyncInferenceClient,
TextGenerationOutput,
)


@pytest.fixture(params=["sync", "async"])
def tgi_client(request) -> Iterator[InferenceClient|AsyncInferenceClient]:
client = docker.from_env()

try:
container = client.containers.get("lighteval-tgi-model-test")
port = container.ports["80/tcp"][0]["HostPort"]
except docker.errors.NotFound:
port = random.randint(8000, 9000)
container = client.containers.run(
"ghcr.io/huggingface/text-generation-inference:2.2.0",
command=[
"--model-id",
"hf-internal-testing/tiny-random-LlamaForCausalLM",
"--dtype",
"float16",
],
detach=True,
name="lighteval-tgi-model-test",
auto_remove=False,
ports={"80/tcp": port},
)
address = f"http://localhost:{port}"
for _ in range(40):
try:
if requests.get(f"{address}/health"):
break
except Exception:
time.sleep(1)
else:
raise RuntimeError("Couldn't setup TGI server.")

if request.param == "async":
yield AsyncInferenceClient(base_url=address)
elif request.param == "sync":
yield InferenceClient(base_url=address)
else:
raise RuntimeError()


def test_logprobs(tgi_client: InferenceClient|AsyncInferenceClient):
model: LlamaForCausalLM = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")

# It raises an error in the async setting unless the size of `prompts` is < 3
prompts = [
"Tell me:\n\nHow are you?Fine, thanks!",
"Tell me:\n\nHow are you?Not bad!",
"Tell me:\n\nComment vas-tu?Comme ci, comme ça",
"Tell me:\n\nComment vas-tu?Ca va! Merci!",
]
responses = []
for prompt in prompts:
responses.append(tgi_client.text_generation(
prompt,
details=True,
decoder_input_details=True,
max_new_tokens=1,
stop_sequences=None,
do_sample=False,
return_full_text=False,
seed=42,
))
if isinstance(tgi_client, AsyncInferenceClient):
loop = asyncio.get_event_loop()
responses: list[TextGenerationOutput] = loop.run_until_complete(asyncio.gather(*responses))

error = False
for prompt, response in zip(prompts, responses):

tgi_logprobs = torch.tensor([t.logprob for t in response.details.prefill[1:]]) # Skipping <s> whose logprob is None

tokenized_sequence = tokenizer(prompt, return_tensors='pt')['input_ids']
output = model.generate(tokenized_sequence, max_new_tokens=1, return_dict_in_generate=True, output_hidden_states=True)
with torch.no_grad():
logprobs = torch.log_softmax(model.lm_head(output.hidden_states[0][-1]),dim=-1)
logprobs = logprobs.gather(dim=-1, index=tokenized_sequence[:,1:].unsqueeze(-1)).squeeze()

if not torch.allclose(logprobs.sum(), tgi_logprobs.sum()):
print(f"====== prompt: {repr(prompt)} ======")
print("TGI logprobs:", tgi_logprobs.tolist())
print("TGI tokens:",[t.id for t in response.details.prefill])
print("Ref. logprobs:", logprobs.tolist())
print("Ref. tokens:", tokenized_sequence[0].tolist())
error = True
assert not error
