-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add simple and full knowledge pipeline functional tests
This is a port of the old `scripts/test_knowledge.py` into functional tests that we can run in CI. These tests take longer to run, so they are marked as `slow` in pytest. Signed-off-by: Ben Browning <[email protected]>
- Loading branch information
Showing 7 changed files with 210 additions and 59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,112 @@ | ||
# Standard | ||
from importlib import resources | ||
import pathlib | ||
import typing | ||
|
||
# Third Party | ||
from datasets import Dataset | ||
from llama_cpp.server.app import create_app | ||
from llama_cpp.server.settings import ModelSettings, ServerSettings | ||
from openai import OpenAI | ||
from starlette.testclient import TestClient | ||
import pytest | ||
|
||
# First Party | ||
from src.instructlab.sdg.pipeline import Pipeline, PipelineContext | ||
|
||
|
||
# Absolute path of the directory two levels above this file (the parent of the
# directory that contains it); local test data is resolved relative to it.
TESTS_PATH = pathlib.Path(__file__).parent.parent.absolute()
|
||
|
||
@pytest.fixture
def testdata_path() -> typing.Generator[pathlib.Path, None, None]:
    """Yield the ``testdata`` directory located directly under ``TESTS_PATH``."""
    data_dir = TESTS_PATH.joinpath("testdata")
    yield data_dir
|
||
|
||
@pytest.fixture
def num_gpu_layers():
    """Value forwarded as ``n_gpu_layers`` to the model settings (0 here)."""
    gpu_layers = 0
    return gpu_layers
|
||
|
||
@pytest.fixture
def openai_client(model, model_repo_id, num_gpu_layers):
    """Yield an OpenAI client backed by an in-process llama-cpp-python app.

    The llama-cpp server application is created with a single model and the
    OpenAI client is given a Starlette ``TestClient`` wrapping that app as its
    HTTP client. The client is closed when the test finishes so the underlying
    HTTP client does not outlive the test.

    Args:
        model: Model file name passed to ``ModelSettings``.
        model_repo_id: Hugging Face repository id the model is fetched from.
        num_gpu_layers: Value passed as ``n_gpu_layers`` to ``ModelSettings``.

    Yields:
        An ``OpenAI`` client whose requests go through the test client.
    """
    app = create_app(
        server_settings=ServerSettings(),
        model_settings=[
            ModelSettings(
                model=model,
                hf_model_repo_id=model_repo_id,
                n_gpu_layers=num_gpu_layers,  # just run on the CPU
                verbose=True,
            )
        ],
    )

    # Extra root route registered on top of the routes create_app() provides.
    @app.get("/")
    def read_root():
        return {"message": "Hello from InstructLab! Visit us at https://instructlab.ai"}

    client = OpenAI(
        api_key="EMPTY",
        base_url="http://localhost:8000/v1",
        http_client=TestClient(app),
    )
    try:
        yield client
    finally:
        # Closing the OpenAI client also closes the HTTP client it was given.
        client.close()
|
||
|
||
@pytest.fixture
def teacher_model(openai_client):
    """Return the id of the first model in the client's model listing."""
    available = openai_client.models.list().data
    return available[0].id
|
||
|
||
@pytest.fixture
def max_num_tokens():
    """Token limit handed to the pipeline context as ``max_num_tokens``."""
    token_limit = 256
    return token_limit
|
||
|
||
@pytest.fixture
def pipeline_context(
    openai_client,
    model_family,
    teacher_model,
    num_instructions_to_generate,
    max_num_tokens,
):
    """Assemble the ``PipelineContext`` shared by the pipeline fixtures.

    The first four fixture values are passed positionally, in the order
    listed in the signature; ``max_num_tokens`` is passed by keyword.
    """
    context = PipelineContext(
        openai_client,
        model_family,
        teacher_model,
        num_instructions_to_generate,
        max_num_tokens=max_num_tokens,
    )
    return context
|
||
|
||
@pytest.fixture
def knowledge_dataset():
    """Build a one-row ``Dataset`` describing a knowledge task about tonsils.

    The row carries three in-context query/response pairs, the task
    description, the source document, a shorter in-context document, the
    domain and a document outline.
    """
    tonsils_row = {
        "icl_query_1": "what is the location of the tubal tonsils?",
        "icl_response_1": "The location of the tubal tonsils is the roof of the pharynx.",
        "icl_query_2": "How long does the adenoid grow?",
        "task_description": "Teaching about human anatomy, specifically tonsils",
        "icl_response_2": "The adenoid grows until the age of 5, starts to shrink at the age of 7 and becomes small in adulthood.",
        "icl_query_3": "What is the immune systems first line of defense against ingested or inhaled foreign pathogens?",
        "icl_response_3": "The tonsils are the immune systems first line of defense.",
        "document": "The **tonsils** are a set of lymphoid organs facing into the aerodigestive tract, which is known as Waldeyer's tonsillar ring and consists of the adenoid tonsil or pharyngeal tonsil, two tubal tonsils, two palatine tonsils, and the lingual tonsils. These organs play an important role in the immune system. When used unqualified, the term most commonly refers specifically to the palatine tonsils, which are two lymphoid organs situated at either side of the back of the human throat. The palatine tonsils and the adenoid tonsil are organs consisting of lymphoepithelial tissue located near the oropharynx and nasopharynx parts of the throat",
        "icl_document": "The **tonsils** are a set of lymphoid organs facing into the aerodigestive tract, which is known as Waldeyer's tonsillar ring and consists of the adenoid tonsil or pharyngeal tonsil, two tubal tonsils, two palatine tonsils, and the lingual tonsils.",
        "domain": "textbook",
        "document_outline": "Medical description of tonsils",
    }
    return Dataset.from_list([tonsils_row])
|
||
|
||
@pytest.fixture
def knowledge_pipeline(pipeline_context, pipelines_package):
    """Load the ``knowledge.yaml`` pipeline from the selected pipelines package."""
    package_files = resources.files(pipelines_package)
    knowledge_yaml = package_files.joinpath("knowledge.yaml")
    return Pipeline.from_file(pipeline_context, knowledge_yaml)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Standard | ||
import unittest | ||
|
||
# Third Party | ||
import pytest | ||
|
||
# First Party | ||
from src.instructlab.sdg.datamixing import _get_question_hack, _get_response_hack | ||
from src.instructlab.sdg.pipeline import FULL_PIPELINES_PACKAGE | ||
|
||
|
||
@pytest.fixture
def model():
    """GGUF model file name used by the full pipeline tests."""
    # Alternative quantization file names kept for reference:
    #   mistral-7b-instruct-v0.2.Q4_K_M.gguf
    #   mistral-7b-instruct-v0.2.Q3_K_S.gguf
    model_file = "mistral-7b-instruct-v0.2.Q5_K_M.gguf"
    return model_file
|
||
|
||
@pytest.fixture
def model_family():
    """Model family name used by the full pipeline tests."""
    family = "mixtral"
    return family
|
||
|
||
@pytest.fixture
def model_repo_id():
    """Repository id the full pipeline model file is fetched from."""
    repo_id = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
    return repo_id
|
||
|
||
@pytest.fixture
def num_instructions_to_generate():
    """Instruction count requested for the full pipeline tests."""
    instruction_count = 1
    return instruction_count
|
||
|
||
@pytest.fixture
def pipelines_package():
    """Select the full pipelines package for the fixtures that load pipelines."""
    package = FULL_PIPELINES_PACKAGE
    return package
|
||
|
||
@pytest.mark.slow
class TestFullPipeline(unittest.TestCase):
    """Functional test running the full knowledge pipeline end to end."""

    @pytest.fixture(autouse=True)
    def _setup_fixtures(self, knowledge_dataset, knowledge_pipeline):
        """Copy the pytest fixture values onto the test case instance."""
        self.knowledge_pipeline = knowledge_pipeline
        self.knowledge_dataset = knowledge_dataset

    def test_knowledge(self):
        """Generate samples and check each has a non-empty question and response."""
        generated = self.knowledge_pipeline.generate(self.knowledge_dataset)
        print(generated)
        self.assertGreater(len(generated), 0)
        for item in generated:
            print(item)
            question = _get_question_hack(item)
            response = _get_response_hack(item)
            self.assertGreater(len(question), 0)
            self.assertGreater(len(response), 0)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Standard | ||
import unittest | ||
|
||
# Third Party | ||
import pytest | ||
|
||
# First Party | ||
from src.instructlab.sdg.datamixing import _get_question_hack, _get_response_hack | ||
from src.instructlab.sdg.pipeline import SIMPLE_PIPELINES_PACKAGE | ||
|
||
|
||
@pytest.fixture
def model():
    """GGUF model file name used by the simple pipeline tests."""
    model_file = "merlinite-7b-Q4_K_M.gguf"
    return model_file
|
||
|
||
@pytest.fixture
def model_family():
    """Model family name used by the simple pipeline tests."""
    family = "merlinite"
    return family
|
||
|
||
@pytest.fixture
def model_repo_id():
    """Repository id the simple pipeline model file is fetched from."""
    repo_id = "ibm/merlinite-7b-GGUF"
    return repo_id
|
||
|
||
@pytest.fixture
def num_instructions_to_generate():
    """Instruction count requested for the simple pipeline tests."""
    instruction_count = 2
    return instruction_count
|
||
|
||
@pytest.fixture
def pipelines_package():
    """Select the simple pipelines package for the fixtures that load pipelines."""
    package = SIMPLE_PIPELINES_PACKAGE
    return package
|
||
|
||
@pytest.mark.slow
class TestSimplePipeline(unittest.TestCase):
    """Functional test running the simple knowledge pipeline end to end."""

    @pytest.fixture(autouse=True)
    def _setup_fixtures(self, knowledge_dataset, knowledge_pipeline):
        """Copy the pytest fixture values onto the test case instance."""
        self.knowledge_pipeline = knowledge_pipeline
        self.knowledge_dataset = knowledge_dataset

    def test_knowledge(self):
        """Generate samples and check each has a non-empty question and response."""
        generated = self.knowledge_pipeline.generate(self.knowledge_dataset)
        self.assertGreater(len(generated), 0)
        for item in generated:
            question = _get_question_hack(item)
            response = _get_response_hack(item)
            self.assertGreater(len(question), 0)
            self.assertGreater(len(response), 0)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters