Adapt integration test to new Pipeline
sadra-barikbin committed Aug 26, 2024
1 parent 8e0a075 commit de60b36
Showing 3 changed files with 73 additions and 40 deletions.
17 changes: 11 additions & 6 deletions src/lighteval/models/tgi_model.py
@@ -24,6 +24,7 @@
 from huggingface_hub import AsyncInferenceClient, InferenceClient
 from transformers import AutoTokenizer
 
+from lighteval.models.abstract_model import ModelInfo
 from lighteval.models.endpoint_model import InferenceEndpointModel
 
 
@@ -45,15 +46,19 @@ def __init__(self, address, auth_token=None, model_id=None) -> None:
         self.client = InferenceClient(base_url=address, headers=headers, timeout=240)
         self.async_client = AsyncInferenceClient(base_url=address, headers=headers, timeout=240)
         self._max_gen_toks = 256
-        self.model_info = requests.get(f"{address}/info", headers=headers).json()
-        if "model_id" not in self.model_info:
+        info = requests.get(f"{address}/info", headers=headers).json()
+        if "model_id" not in info:
             raise ValueError("Error occured when fetching info: " + str(self.model_info))
-        if model_id:
-            self.model_info["model_id"] = model_id
-        self._tokenizer = AutoTokenizer.from_pretrained(self.model_info["model_id"])
+        self.name = info["model_id"]
+        self.model_info = ModelInfo(
+            model_name=model_id or self.name,
+            model_sha=info["model_sha"],
+            model_dtype=info["model_dtype"] or "default",
+            model_size=-1,
+        )
+        self._tokenizer = AutoTokenizer.from_pretrained(self.model_info.model_name)
         self._add_special_tokens = True
         self.use_async = True
-        self.name = self.model_info["model_id"]
 
     def set_cache_hook(self, cache_hook):
         self.cache_hook = cache_hook
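
The net effect of this file's change: ModelClient now publishes its metadata through the structured ModelInfo container from lighteval.models.abstract_model instead of mutating the raw /info payload, and keeps the server-reported id on self.name. A minimal usage sketch, assuming only what the diff shows plus a placeholder address for a running text-generation-inference server (the address is not part of this commit):

# Hedged sketch: exercising the reworked constructor against a TGI server
# assumed to be reachable at a placeholder address.
from lighteval.models.tgi_model import ModelClient

tgi_model = ModelClient(address="http://localhost:8080", auth_token=None, model_id=None)

# Metadata fetched from GET <address>/info is now exposed as attributes of a
# ModelInfo object rather than keys of a plain dict.
print(tgi_model.name)                    # model id reported by the server
print(tgi_model.model_info.model_name)   # same id, unless model_id was passed to override it
print(tgi_model.model_info.model_sha)    # revision reported by the server
print(tgi_model.model_info.model_dtype)  # dtype reported by the server, or "default"

Callers that previously indexed self.model_info["model_id"] would now read self.model_info.model_name (or self.name).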
48 changes: 31 additions & 17 deletions tests/test_base_model.py
@@ -22,15 +22,17 @@
 
 import os
 from typing import Iterator, TypeAlias
+from unittest.mock import patch
 
 import pytest
 from huggingface_hub import ChatCompletionInputMessage
 from transformers import BatchEncoding
 
-from lighteval.evaluator import EvaluationTracker, evaluate
+from lighteval.logging.evaluation_tracker import EvaluationTracker
 from lighteval.metrics.metrics import Metrics
 from lighteval.models.base_model import BaseModel
 from lighteval.models.model_config import BaseModelConfig, EnvConfig
+from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
 from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig, create_requests_from_tasks
 from lighteval.tasks.requests import (
     Doc,
@@ -120,28 +122,40 @@ def task(self) -> LightevalTask:
     def test_integration(self, task: LightevalTask, base_model: BaseModel, num_fewshot: int, use_chat_template: bool):
         base_model.use_chat_template = use_chat_template
 
+        env_config = EnvConfig(token=TOKEN, cache_dir=CACHE_PATH)
         evaluation_tracker = EvaluationTracker()
+        pipeline_params = PipelineParameters(
+            launcher_type=ParallelismManager.NONE,
+            env_config=env_config,
+            use_chat_template=use_chat_template,
+            override_batch_size=1,
+        )
+
+        with patch("lighteval.pipeline.Pipeline._init_tasks_and_requests"):
+            pipeline = Pipeline(
+                tasks=f"custom|test|{num_fewshot}|0",
+                pipeline_parameters=pipeline_params,
+                evaluation_tracker=evaluation_tracker,
+                model=base_model,
+            )
         task_dict = {"custom|test": task}
         evaluation_tracker.task_config_logger.log(task_dict)
-        requests_dict, docs = create_requests_from_tasks(
+        fewshot_dict = {"custom|test": [(num_fewshot, False)]}
+        pipeline.task_names_list = ["custom|test"]
+        pipeline.task_dict = task_dict
+        pipeline.fewshot_dict = fewshot_dict
+        requests, docs = create_requests_from_tasks(
             task_dict=task_dict,
-            fewshot_dict={"custom|test": [(num_fewshot, False)]},
-            num_fewshot_seeds=0,
+            fewshot_dict=fewshot_dict,
+            num_fewshot_seeds=pipeline_params.num_fewshot_seeds,
             lm=base_model,
-            max_samples=1,
+            max_samples=pipeline_params.max_samples,
             evaluation_tracker=evaluation_tracker,
             use_chat_template=use_chat_template,
-            system_prompt=None,
+            system_prompt=pipeline_params.system_prompt,
         )
+        pipeline.requests = requests
+        pipeline.docs = docs
+        evaluation_tracker.task_config_logger.log(task_dict)
 
-        evaluation_tracker = evaluate(
-            lm=base_model,
-            requests_dict=requests_dict,
-            docs=docs,
-            task_dict=task_dict,
-            override_bs=1,
-            evaluation_tracker=evaluation_tracker,
-        )
-        evaluation_tracker.metrics_logger.aggregate(task_dict=task_dict)
-        evaluation_tracker.details_logger.aggregate()
-        evaluation_tracker.generate_final_dict()
+        pipeline.evaluate()
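
Because the test exercises an unregistered custom task, it patches Pipeline._init_tasks_and_requests and injects the task, few-shot configuration, requests and docs by hand before calling pipeline.evaluate(), which now covers the evaluation loop plus the metric and detail aggregation the old test performed explicitly. For contrast, a hedged sketch of the unpatched flow this emulates, where Pipeline resolves a registered task string on its own; the task string, token, cache directory and model are placeholders, and the TGI client appears here only because its constructor is shown in this commit (any lighteval model would do):

# Hedged sketch of the ordinary, unpatched Pipeline flow; placeholder values
# are marked in comments, the call signatures mirror those used in the commit.
import os

from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.models.tgi_model import ModelClient
from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
from lighteval.utils.utils import EnvConfig

evaluation_tracker = EvaluationTracker()
pipeline_params = PipelineParameters(
    launcher_type=ParallelismManager.NONE,
    env_config=EnvConfig(token=os.environ.get("HF_TOKEN"), cache_dir="tmp"),  # placeholder cache dir
    use_chat_template=False,
    override_batch_size=1,
)

# No patch here: Pipeline builds its own tasks and requests from the task
# string, so there is no manual create_requests_from_tasks step.
pipeline = Pipeline(
    tasks="leaderboard|arc:challenge|0|0",  # placeholder registered task
    pipeline_parameters=pipeline_params,
    evaluation_tracker=evaluation_tracker,
    model=ModelClient(address="http://localhost:8080"),  # placeholder TGI address
)
pipeline.evaluate()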
48 changes: 31 additions & 17 deletions tests/test_endpoint_model.py
@@ -25,22 +25,25 @@
 import time
 from collections import defaultdict
 from typing import Iterator, TypeAlias
+from unittest.mock import patch
 
 import docker
 import docker.errors
 import pytest
 import requests
 from huggingface_hub import ChatCompletionInputMessage
 
-from lighteval.evaluator import EvaluationTracker, evaluate
+from lighteval.logging.evaluation_tracker import EvaluationTracker
 from lighteval.metrics.metrics import Metrics
 from lighteval.models.tgi_model import ModelClient as TGIModel
+from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
 from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig, create_requests_from_tasks
 from lighteval.tasks.requests import (
     Doc,
     Request,
     RequestType,
 )
+from lighteval.utils.utils import EnvConfig
 
 
 TOKEN = os.environ.get("HF_TOKEN")
@@ -164,28 +167,39 @@ def test_loglikelihood(self, zero_shot_request_dict: RequestDict, tgi_model: TGI
     @pytest.mark.parametrize("num_fewshot", [0, 2])
     @pytest.mark.parametrize("use_chat_template", [False, True])
     def test_integration(self, task: LightevalTask, tgi_model: TGIModel, num_fewshot: int, use_chat_template: bool):
+        env_config = EnvConfig(token=TOKEN, cache_dir=CACHE_PATH)
         evaluation_tracker = EvaluationTracker()
+        pipeline_params = PipelineParameters(
+            launcher_type=ParallelismManager.NONE,
+            env_config=env_config,
+            use_chat_template=use_chat_template,
+        )
+
+        with patch("lighteval.pipeline.Pipeline._init_tasks_and_requests"):
+            pipeline = Pipeline(
+                tasks=f"custom|test|{num_fewshot}|0",
+                pipeline_parameters=pipeline_params,
+                evaluation_tracker=evaluation_tracker,
+                model=tgi_model,
+            )
         task_dict = {"custom|test": task}
         evaluation_tracker.task_config_logger.log(task_dict)
-        requests_dict, docs = create_requests_from_tasks(
+        fewshot_dict = {"custom|test": [(num_fewshot, False)]}
+        pipeline.task_names_list = ["custom|test"]
+        pipeline.task_dict = task_dict
+        pipeline.fewshot_dict = fewshot_dict
+        requests, docs = create_requests_from_tasks(
             task_dict=task_dict,
-            fewshot_dict={"custom|test": [(num_fewshot, False)]},
-            num_fewshot_seeds=0,
+            fewshot_dict=fewshot_dict,
+            num_fewshot_seeds=pipeline_params.num_fewshot_seeds,
             lm=tgi_model,
-            max_samples=1,
+            max_samples=pipeline_params.max_samples,
             evaluation_tracker=evaluation_tracker,
             use_chat_template=use_chat_template,
-            system_prompt=None,
+            system_prompt=pipeline_params.system_prompt,
         )
+        pipeline.requests = requests
+        pipeline.docs = docs
+        evaluation_tracker.task_config_logger.log(task_dict)
 
-        evaluation_tracker = evaluate(
-            lm=tgi_model,
-            requests_dict=requests_dict,
-            docs=docs,
-            task_dict=task_dict,
-            override_bs=1,
-            evaluation_tracker=evaluation_tracker,
-        )
-        evaluation_tracker.metrics_logger.aggregate(task_dict=task_dict)
-        evaluation_tracker.details_logger.aggregate()
-        evaluation_tracker.generate_final_dict()
+        pipeline.evaluate()
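
The tgi_model fixture used by these tests sits outside the diff, but the imports kept and added above (docker, docker.errors, requests, time) suggest it starts a local text-generation-inference container and waits for it before handing a ModelClient to the test. A rough sketch of what such a fixture might look like; the image tag, model id, port mapping and timing are assumptions rather than details from this commit:

# Hedged sketch of a TGI-backed pytest fixture. Image, model id, port and
# timeouts are placeholder assumptions; only the /info polling mirrors what
# ModelClient's constructor depends on.
import time
from typing import Iterator

import docker
import pytest
import requests

from lighteval.models.tgi_model import ModelClient as TGIModel


@pytest.fixture(scope="module")
def tgi_model() -> Iterator[TGIModel]:
    client = docker.from_env()
    container = client.containers.run(
        "ghcr.io/huggingface/text-generation-inference:latest",  # placeholder tag
        command="--model-id hf-internal-testing/tiny-random-gpt2",  # placeholder model
        ports={"80/tcp": 8080},
        detach=True,
    )
    address = "http://localhost:8080"
    try:
        # Wait for the server to answer on /info, the endpoint that
        # ModelClient queries in its own constructor.
        for _ in range(60):
            try:
                if requests.get(f"{address}/info", timeout=5).status_code == 200:
                    break
            except requests.exceptions.ConnectionError:
                pass
            time.sleep(5)
        else:
            raise RuntimeError("TGI server did not become ready in time")
        yield TGIModel(address=address)
    finally:
        container.stop()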
