From 8c0018e9933c29a1916b8c6f95e8b0785084cfb4 Mon Sep 17 00:00:00 2001
From: Sadra Barikbin
Date: Thu, 5 Sep 2024 11:26:40 +0330
Subject: [PATCH] Improve endpoint tests and fix a bug in the endpoint model

---
 src/lighteval/models/endpoint_model.py |   7 +-
 src/lighteval/models/tgi_model.py      |   2 +-
 tests/conftest.py                      |  12 +++
 tests/test_base_model.py               |  12 ---
 tests/test_endpoint_model.py           |  81 ++++++++++++++-----
 tests/test_test.py                     | 103 +++++++++++++++++++++++++
 6 files changed, 181 insertions(+), 36 deletions(-)
 create mode 100644 tests/conftest.py
 create mode 100644 tests/test_test.py

diff --git a/src/lighteval/models/endpoint_model.py b/src/lighteval/models/endpoint_model.py
index 3d85d67f..cfd71162 100644
--- a/src/lighteval/models/endpoint_model.py
+++ b/src/lighteval/models/endpoint_model.py
@@ -205,11 +205,10 @@ def _process_generate_response(self, response: EndpointOutput, request: GreedyUn
     def _process_logprob_response(
         self, response: TextGenerationOutput, request: LoglikelihoodRequest | LoglikelihoodRollingRequest
     ) -> LoglikelihoodResponse:
-        len_choice = len(request.tokenized_continuation)
-        logits = sum([t.logprob for t in response.details.prefill[1:][-len_choice:]])
+        logits = sum([t.logprob for t in response.details.prefill[len(request.tokenized_context):]])
         return LoglikelihoodResponse(
             result=(logits, True) if isinstance(request, LoglikelihoodRequest) else logits,
-            input_tokens=[t.id for t in response.details.prefill[:-len_choice]],
+            input_tokens=[t.id for t in response.details.prefill[len(request.tokenized_context):]],
             generated_tokens=-1,
             truncated_tokens_count=-1,
             padded_tokens_count=-1,
@@ -255,6 +254,7 @@ def _prepare_request(self, request: Request) -> EndpointInput:
             context = request.context + [ChatCompletionInputMessage(role="assistant", content=request.choice)]
         if not isinstance(context, str):
             context = self.tokenizer.apply_chat_template(context, tokenize=False)
+            context = context.split(self.tokenizer.bos_token, 1)[-1]
 
         if isinstance(context, str):
             prepared_request = TextGenerationInput(
@@ -290,6 +290,7 @@ def greedy_until(
         override_bs: Optional[int] = None,
    ) -> List[GenerativeResponse]:
         for request in requests:
+            # Why don't we set context to an empty list here?
             request.tokenized_context = self.tok_encode(request.context)
             request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token]
 
diff --git a/src/lighteval/models/tgi_model.py b/src/lighteval/models/tgi_model.py
index 0621ea12..c359a086 100644
--- a/src/lighteval/models/tgi_model.py
+++ b/src/lighteval/models/tgi_model.py
@@ -53,7 +53,7 @@ def __init__(self, address, auth_token=None, model_id=None) -> None:
         self.model_info = ModelInfo(
             model_name=model_id or self.name,
             model_sha=info["model_sha"],
-            model_dtype=info["model_dtype"] or "default",
+            model_dtype=info["model_dtype"] if "model_dtype" in info else "default",
             model_size=-1,
         )
         self._tokenizer = AutoTokenizer.from_pretrained(self.model_info.model_name)
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000..555b4396
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,12 @@
+from typing import Iterator
+import pytest
+
+from lighteval.models.model_config import BaseModelConfig
+from lighteval.models.abstract_model import EnvConfig
+from lighteval.models.base_model import BaseModel
+
+
+@pytest.fixture(scope="module")
+def base_model() -> Iterator[BaseModel]:
+    config = BaseModelConfig("hf-internal-testing/tiny-random-LlamaForCausalLM")
+    return BaseModel(config, EnvConfig())
\ No newline at end of file
diff --git a/tests/test_base_model.py b/tests/test_base_model.py
index f622ff77..0d3c72be 100644
--- a/tests/test_base_model.py
+++ b/tests/test_base_model.py
@@ -41,16 +41,6 @@
 )
 
 
-TOKEN = os.environ.get("HF_TOKEN")
-CACHE_PATH = os.getenv("HF_HOME", ".")
-
-
-@pytest.fixture(scope="module")
-def base_model() -> Iterator[BaseModel]:
-    config = BaseModelConfig("hf-internal-testing/tiny-random-LlamaForCausalLM")
-    return BaseModel(config, EnvConfig(CACHE_PATH, TOKEN))
-
-
 RequestDict: TypeAlias = dict[RequestType, list[Request]]
 
 
@@ -122,11 +112,9 @@ def task(self) -> LightevalTask:
     def test_integration(self, task: LightevalTask, base_model: BaseModel, num_fewshot: int, use_chat_template: bool):
         base_model.use_chat_template = use_chat_template
 
-        env_config = EnvConfig(token=TOKEN, cache_dir=CACHE_PATH)
         evaluation_tracker = EvaluationTracker()
         pipeline_params = PipelineParameters(
             launcher_type=ParallelismManager.NONE,
-            env_config=env_config,
             use_chat_template=use_chat_template,
             override_batch_size=1,
         )
diff --git a/tests/test_endpoint_model.py b/tests/test_endpoint_model.py
index 9550eb6a..98e33de0 100644
--- a/tests/test_endpoint_model.py
+++ b/tests/test_endpoint_model.py
@@ -27,27 +27,27 @@
 from typing import Iterator, TypeAlias
 from unittest.mock import patch
 
+import torch
 import docker
 import docker.errors
 import pytest
 import requests
 from huggingface_hub import ChatCompletionInputMessage
+from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, PreTrainedTokenizerFast
 
 from lighteval.logging.evaluation_tracker import EvaluationTracker
 from lighteval.metrics.metrics import Metrics
 from lighteval.models.tgi_model import ModelClient as TGIModel
+from lighteval.models.base_model import BaseModel
 from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
 from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig, create_requests_from_tasks
 from lighteval.tasks.requests import (
+    LoglikelihoodRequest,
+    LoglikelihoodRollingRequest,
     Doc,
     Request,
     RequestType,
 )
-from lighteval.utils.utils import EnvConfig
-
-
-TOKEN = os.environ.get("HF_TOKEN")
-CACHE_PATH = os.getenv("HF_HOME", ".")
 
 
 @pytest.fixture(scope="module")
@@ -83,12 +83,19 @@ def tgi_model() -> Iterator[TGIModel]:
         raise RuntimeError("Couldn't setup TGI server.")
     model = TGIModel(address)
     yield model
-    container.stop()
-    container.wait()
-    container.remove()
+    # container.stop()
+    # container.wait()
+    # container.remove()
     model.cleanup()
 
 
+@pytest.fixture(scope="module")
+def reference_model_tokenizer() -> tuple[LlamaForCausalLM, PreTrainedTokenizerFast]:
+    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+    return model, tokenizer
+
+
 RequestDict: TypeAlias = dict[RequestType, list[Request]]
 
 
@@ -150,28 +157,62 @@ def zero_shot_request_dict(self, task: LightevalTask) -> RequestDict:
                 result[req_type].extend(doc_result[req_type])
         return result
 
-    def test_greedy_until(self, zero_shot_request_dict: RequestDict, tgi_model: TGIModel):
-        returns = tgi_model.greedy_until(zero_shot_request_dict[RequestType.GREEDY_UNTIL])
+    def test_greedy_until(self, reference_model_tokenizer: tuple[LlamaForCausalLM, PreTrainedTokenizerFast], zero_shot_request_dict: RequestDict, tgi_model: TGIModel):
+        requests = zero_shot_request_dict[RequestType.GREEDY_UNTIL]
+        returns = tgi_model.greedy_until(requests)
+        model, tokenizer = reference_model_tokenizer
         assert len(returns) == 2
-        assert all(r.result is not None for r in returns)
-
-    def test_loglikelihood(self, zero_shot_request_dict: RequestDict, tgi_model: TGIModel):
-        returns = tgi_model.loglikelihood(zero_shot_request_dict[RequestType.LOGLIKELIHOOD])
+        for req, res in zip(requests, returns):
+            is_chat = not isinstance(req.context, str)
+            tokenized_context = tokenizer.apply_chat_template(req.context, return_tensors='pt') if is_chat else tokenizer(req.context, return_tensors='pt')['input_ids']
+            ref_context_continuation = model.generate(tokenized_context, tokenizer=tokenizer, stop_strings=req.stop_sequence, max_new_tokens=req.generation_size)[0].tolist()
+            continuation = tokenizer.decode(ref_context_continuation)[len(tokenizer.decode(tokenized_context[0].tolist())):]
+            assert continuation == res.result
+
+    def test_loglikelihood(self, reference_model_tokenizer: tuple[LlamaForCausalLM, PreTrainedTokenizerFast], zero_shot_request_dict: RequestDict, tgi_model: TGIModel):
+        requests: list[LoglikelihoodRequest] = zero_shot_request_dict[RequestType.LOGLIKELIHOOD]
+        returns = tgi_model.loglikelihood(requests)
+        model, tokenizer = reference_model_tokenizer
         assert len(returns) == 4
-        assert all(r.result is not None for r in returns)
-
+        for req, res in zip(requests, returns):
+            is_chat = not isinstance(req.context, str)
+            sequence = req.context + [ChatCompletionInputMessage(role='assistant', content=req.choice)] if is_chat else req.context + req.choice
+            tokenized_sequence = tokenizer.apply_chat_template(sequence, return_tensors='pt') if is_chat else tokenizer(sequence, return_tensors='pt')['input_ids']
+
+            output = model.generate(tokenized_sequence, max_new_tokens=1, return_dict_in_generate=True, output_hidden_states=True)
+            with torch.no_grad():
+                logprobs = torch.log_softmax(model.lm_head(output.hidden_states[0][-1]), dim=-1)
+                logprobs = logprobs.gather(dim=-1, index=tokenized_sequence[:, 1:].unsqueeze(-1))
+            context_length = len(tokenizer.apply_chat_template(req.context)) if is_chat else len(tokenizer.encode(req.context))
+            continuation_logprob = logprobs[:, context_length - 1:].sum()
+
+            tokenized_choice = tokenized_sequence[:, context_length:]
+            assert tokenized_choice[0].tolist() == res.input_tokens
+            assert torch.allclose(torch.tensor(res.result[0]), continuation_logprob)
+
+    def test_loglikelihood_rolling(self, reference_model_tokenizer: tuple[LlamaForCausalLM, PreTrainedTokenizerFast], zero_shot_request_dict: RequestDict, tgi_model: TGIModel):
+        model, tokenizer = reference_model_tokenizer
+        requests: list[LoglikelihoodRollingRequest] = zero_shot_request_dict[RequestType.LOGLIKELIHOOD_ROLLING]
         returns = tgi_model.loglikelihood_rolling(zero_shot_request_dict[RequestType.LOGLIKELIHOOD_ROLLING])
         assert len(returns) == 2
-        assert all(r.result is not None for r in returns)
+        for req, res in zip(requests, returns):
+            is_chat = not isinstance(req.context, str)
+            tokenized_context = tokenizer.apply_chat_template(req.context, return_tensors='pt') if is_chat else tokenizer(req.context, return_tensors='pt')['input_ids']
+            output = model.generate(tokenized_context, max_new_tokens=1, return_dict_in_generate=True, output_hidden_states=True)
+            with torch.no_grad():
+                logprobs = torch.log_softmax(model.lm_head(output.hidden_states[0][-1]), dim=-1)
+                logprob = logprobs.gather(dim=-1, index=tokenized_context[:, 1:].unsqueeze(-1)).sum()
+
+            assert tokenized_context[0, 1:].tolist() == res.input_tokens
+            assert torch.allclose(torch.tensor(res.result), logprob)
 
     @pytest.mark.parametrize("num_fewshot", [0, 2])
     @pytest.mark.parametrize("use_chat_template", [False, True])
-    def test_integration(self, task: LightevalTask, tgi_model: TGIModel, num_fewshot: int, use_chat_template: bool):
-        env_config = EnvConfig(token=TOKEN, cache_dir=CACHE_PATH)
+    def test_integration(self, task: LightevalTask, base_model: BaseModel, tgi_model: TGIModel, num_fewshot: int, use_chat_template: bool):
+        #TODO
         evaluation_tracker = EvaluationTracker()
         pipeline_params = PipelineParameters(
             launcher_type=ParallelismManager.NONE,
-            env_config=env_config,
             use_chat_template=use_chat_template,
         )
 
diff --git a/tests/test_test.py b/tests/test_test.py
new file mode 100644
index 00000000..1612b650
--- /dev/null
+++ b/tests/test_test.py
@@ -0,0 +1,103 @@
+import time
+import random
+import asyncio
+from typing import Iterator
+
+import pytest
+import docker
+import requests
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM
+from huggingface_hub import (
+    InferenceClient,
+    AsyncInferenceClient,
+    TextGenerationOutput,
+)
+
+
+@pytest.fixture(params=["sync", "async"])
+def tgi_client(request) -> Iterator[InferenceClient|AsyncInferenceClient]:
+    client = docker.from_env()
+
+    try:
+        container = client.containers.get("lighteval-tgi-model-test")
+        port = container.ports["80/tcp"][0]["HostPort"]
+    except docker.errors.NotFound:
+        port = random.randint(8000, 9000)
+        container = client.containers.run(
+            "ghcr.io/huggingface/text-generation-inference:2.2.0",
+            command=[
+                "--model-id",
+                "hf-internal-testing/tiny-random-LlamaForCausalLM",
+                "--dtype",
+                "float16",
+            ],
+            detach=True,
+            name="lighteval-tgi-model-test",
+            auto_remove=False,
+            ports={"80/tcp": port},
+        )
+    address = f"http://localhost:{port}"
+    for _ in range(40):
+        try:
+            if requests.get(f"{address}/health"):
+                break
+        except Exception:
+            time.sleep(1)
+    else:
+        raise RuntimeError("Couldn't setup TGI server.")
+
+    if request.param == "async":
+        yield AsyncInferenceClient(base_url=address)
+    elif request.param == "sync":
+        yield InferenceClient(base_url=address)
+    else:
+        raise RuntimeError()
+
+
+def test_logprobs(tgi_client: InferenceClient|AsyncInferenceClient):
+    model: LlamaForCausalLM = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+
+    # It raises an error in the async setting unless the size of `prompts` is < 3
+    prompts = [
+        "Tell me:\n\nHow are you?Fine, thanks!",
+        "Tell me:\n\nHow are you?Not bad!",
+        "Tell me:\n\nComment vas-tu?Comme ci, comme ça",
+        "Tell me:\n\nComment vas-tu?Ca va! Merci!",
+    ]
+    responses = []
+    for prompt in prompts:
+        responses.append(tgi_client.text_generation(
+            prompt,
+            details=True,
+            decoder_input_details=True,
+            max_new_tokens=1,
+            stop_sequences=None,
+            do_sample=False,
+            return_full_text=False,
+            seed=42,
+        ))
+    if isinstance(tgi_client, AsyncInferenceClient):
+        loop = asyncio.get_event_loop()
+        responses: list[TextGenerationOutput] = loop.run_until_complete(asyncio.gather(*responses))
+
+    error = False
+    for prompt, response in zip(prompts, responses):
+
+        tgi_logprobs = torch.tensor([t.logprob for t in response.details.prefill[1:]])  # Skip the first token, whose logprob is None
+
+        tokenized_sequence = tokenizer(prompt, return_tensors='pt')['input_ids']
+        output = model.generate(tokenized_sequence, max_new_tokens=1, return_dict_in_generate=True, output_hidden_states=True)
+        with torch.no_grad():
+            logprobs = torch.log_softmax(model.lm_head(output.hidden_states[0][-1]), dim=-1)
+            logprobs = logprobs.gather(dim=-1, index=tokenized_sequence[:, 1:].unsqueeze(-1)).squeeze()
+
+        if not torch.allclose(logprobs.sum(), tgi_logprobs.sum()):
+            print(f"====== prompt: {repr(prompt)} ======")
+            print("TGI logprobs:", tgi_logprobs.tolist())
+            print("TGI tokens:", [t.id for t in response.details.prefill])
+            print("Ref. logprobs:", logprobs.tolist())
+            print("Ref. tokens:", tokenized_sequence[0].tolist())
+            error = True
+    assert not error
\ No newline at end of file