From de60b3629f45c4967d21db608f9709b95917248a Mon Sep 17 00:00:00 2001
From: Sadra Barikbin
Date: Mon, 26 Aug 2024 13:32:34 +0330
Subject: [PATCH] Adapt integration test to new Pipeline

---
 src/lighteval/models/tgi_model.py | 19 ++++++++++----
 tests/test_base_model.py          | 48 ++++++++++++++++++++-----------
 tests/test_endpoint_model.py      | 48 ++++++++++++++++++++-----------
 3 files changed, 74 insertions(+), 41 deletions(-)

diff --git a/src/lighteval/models/tgi_model.py b/src/lighteval/models/tgi_model.py
index 92c3f65b..0621ea12 100644
--- a/src/lighteval/models/tgi_model.py
+++ b/src/lighteval/models/tgi_model.py
@@ -24,6 +24,7 @@
 from huggingface_hub import AsyncInferenceClient, InferenceClient
 from transformers import AutoTokenizer
 
+from lighteval.models.abstract_model import ModelInfo
 from lighteval.models.endpoint_model import InferenceEndpointModel
 
 
@@ -45,15 +46,19 @@ def __init__(self, address, auth_token=None, model_id=None) -> None:
         self.client = InferenceClient(base_url=address, headers=headers, timeout=240)
         self.async_client = AsyncInferenceClient(base_url=address, headers=headers, timeout=240)
         self._max_gen_toks = 256
-        self.model_info = requests.get(f"{address}/info", headers=headers).json()
-        if "model_id" not in self.model_info:
-            raise ValueError("Error occured when fetching info: " + str(self.model_info))
-        if model_id:
-            self.model_info["model_id"] = model_id
-        self._tokenizer = AutoTokenizer.from_pretrained(self.model_info["model_id"])
+        info = requests.get(f"{address}/info", headers=headers).json()
+        if "model_id" not in info:
+            raise ValueError("Error occurred when fetching info: " + str(info))
+        self.name = info["model_id"]
+        self.model_info = ModelInfo(
+            model_name=model_id or self.name,
+            model_sha=info["model_sha"],
+            model_dtype=info["model_dtype"] or "default",
+            model_size=-1,
+        )
+        self._tokenizer = AutoTokenizer.from_pretrained(self.model_info.model_name)
         self._add_special_tokens = True
         self.use_async = True
-        self.name = self.model_info["model_id"]
 
     def set_cache_hook(self, cache_hook):
         self.cache_hook = cache_hook
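
For context on the tgi_model.py change: the raw /info payload is no longer stored as a plain dict; it is wrapped in lighteval's ModelInfo record, self.name is taken from the reported model_id, and a caller-supplied model_id only overrides the name recorded in model_info (and hence the tokenizer that gets loaded). A minimal usage sketch, assuming a text-generation-inference server is already listening locally; the address below is a placeholder, not part of the patch:

    # Sketch: inspect the metadata the adapted ModelClient exposes.
    from lighteval.models.tgi_model import ModelClient

    client = ModelClient(address="http://localhost:8080", auth_token=None)
    print(client.name)                    # model_id reported by the server's /info route
    print(client.model_info.model_name)   # same, unless model_id= was passed to override it
    print(client.model_info.model_dtype)  # dtype from /info, or "default" if the server reports none
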
diff --git a/tests/test_base_model.py b/tests/test_base_model.py
index 050b5ffb..f622ff77 100644
--- a/tests/test_base_model.py
+++ b/tests/test_base_model.py
@@ -22,15 +22,17 @@
 
 import os
 from typing import Iterator, TypeAlias
+from unittest.mock import patch
 
 import pytest
 from huggingface_hub import ChatCompletionInputMessage
 from transformers import BatchEncoding
 
-from lighteval.evaluator import EvaluationTracker, evaluate
+from lighteval.logging.evaluation_tracker import EvaluationTracker
 from lighteval.metrics.metrics import Metrics
 from lighteval.models.base_model import BaseModel
 from lighteval.models.model_config import BaseModelConfig, EnvConfig
+from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
 from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig, create_requests_from_tasks
 from lighteval.tasks.requests import (
     Doc,
@@ -120,28 +122,40 @@ def task(self) -> LightevalTask:
     @pytest.mark.parametrize("use_chat_template", [False, True])
     def test_integration(self, task: LightevalTask, base_model: BaseModel, num_fewshot: int, use_chat_template: bool):
         base_model.use_chat_template = use_chat_template
+        env_config = EnvConfig(token=TOKEN, cache_dir=CACHE_PATH)
         evaluation_tracker = EvaluationTracker()
+        pipeline_params = PipelineParameters(
+            launcher_type=ParallelismManager.NONE,
+            env_config=env_config,
+            use_chat_template=use_chat_template,
+            override_batch_size=1,
+        )
+
+        with patch("lighteval.pipeline.Pipeline._init_tasks_and_requests"):
+            pipeline = Pipeline(
+                tasks=f"custom|test|{num_fewshot}|0",
+                pipeline_parameters=pipeline_params,
+                evaluation_tracker=evaluation_tracker,
+                model=base_model,
+            )
         task_dict = {"custom|test": task}
         evaluation_tracker.task_config_logger.log(task_dict)
-        requests_dict, docs = create_requests_from_tasks(
+        fewshot_dict = {"custom|test": [(num_fewshot, False)]}
+        pipeline.task_names_list = ["custom|test"]
+        pipeline.task_dict = task_dict
+        pipeline.fewshot_dict = fewshot_dict
+        requests, docs = create_requests_from_tasks(
             task_dict=task_dict,
-            fewshot_dict={"custom|test": [(num_fewshot, False)]},
-            num_fewshot_seeds=0,
+            fewshot_dict=fewshot_dict,
+            num_fewshot_seeds=pipeline_params.num_fewshot_seeds,
             lm=base_model,
-            max_samples=1,
+            max_samples=pipeline_params.max_samples,
             evaluation_tracker=evaluation_tracker,
             use_chat_template=use_chat_template,
-            system_prompt=None,
+            system_prompt=pipeline_params.system_prompt,
         )
+        pipeline.requests = requests
+        pipeline.docs = docs
+        evaluation_tracker.task_config_logger.log(task_dict)
 
-        evaluation_tracker = evaluate(
-            lm=base_model,
-            requests_dict=requests_dict,
-            docs=docs,
-            task_dict=task_dict,
-            override_bs=1,
-            evaluation_tracker=evaluation_tracker,
-        )
-        evaluation_tracker.metrics_logger.aggregate(task_dict=task_dict)
-        evaluation_tracker.details_logger.aggregate()
-        evaluation_tracker.generate_final_dict()
+        pipeline.evaluate()
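
The test_base_model.py rewrite above drops the old evaluate() helper in favour of the new Pipeline: Pipeline._init_tasks_and_requests is patched out so the constructor does not build tasks itself, the test injects its own task, fewshot, request and doc objects, and then calls pipeline.evaluate(). Stripped of the pytest fixtures, the pattern reduces to roughly the sketch below; the task, model, env_config, requests and docs arguments stand in for the fixtures and for the create_requests_from_tasks call shown in the diff, and are assumptions rather than part of the patch.

    # Rough sketch (not part of the patch) of the injection pattern the adapted tests use.
    from unittest.mock import patch

    from lighteval.logging.evaluation_tracker import EvaluationTracker
    from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters


    def run_pipeline_manually(task, model, env_config, requests, docs, num_fewshot=0):
        tracker = EvaluationTracker()
        params = PipelineParameters(
            launcher_type=ParallelismManager.NONE,
            env_config=env_config,
            override_batch_size=1,
        )
        # Patch out task/request construction in Pipeline.__init__; the caller supplies them instead.
        with patch("lighteval.pipeline.Pipeline._init_tasks_and_requests"):
            pipeline = Pipeline(
                tasks=f"custom|test|{num_fewshot}|0",
                pipeline_parameters=params,
                evaluation_tracker=tracker,
                model=model,
            )
        pipeline.task_names_list = ["custom|test"]
        pipeline.task_dict = {"custom|test": task}
        pipeline.fewshot_dict = {"custom|test": [(num_fewshot, False)]}
        pipeline.requests = requests
        pipeline.docs = docs
        pipeline.evaluate()
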
fewshot_dict={"custom|test": [(num_fewshot, False)]}, - num_fewshot_seeds=0, + fewshot_dict=fewshot_dict, + num_fewshot_seeds=pipeline_params.num_fewshot_seeds, lm=tgi_model, - max_samples=1, + max_samples=pipeline_params.max_samples, evaluation_tracker=evaluation_tracker, use_chat_template=use_chat_template, - system_prompt=None, + system_prompt=pipeline_params.system_prompt, ) + pipeline.requests = requests + pipeline.docs = docs + evaluation_tracker.task_config_logger.log(task_dict) - evaluation_tracker = evaluate( - lm=tgi_model, - requests_dict=requests_dict, - docs=docs, - task_dict=task_dict, - override_bs=1, - evaluation_tracker=evaluation_tracker, - ) - evaluation_tracker.metrics_logger.aggregate(task_dict=task_dict) - evaluation_tracker.details_logger.aggregate() - evaluation_tracker.generate_final_dict() + pipeline.evaluate()