Skip to content

Commit

Permalink
Add simple and full knowledge pipeline functional tests
Browse files Browse the repository at this point in the history
This is a port of the old `scripts/test_knowledge.py` into functional
tests that we can run with CI. These tests take longer, so are marked as
`slow` in pytest.

Signed-off-by: Ben Browning <[email protected]>
  • Loading branch information
bbrowning committed Nov 14, 2024
1 parent 9d4ed74 commit 6f377bc
Show file tree
Hide file tree
Showing 7 changed files with 210 additions and 59 deletions.
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,8 @@ exclude = [
]
# honor excludes by not following them through imports
follow_imports = "silent"

[tool.pytest.ini_options]
markers = [
"slow: marks tests that are slow (deselect with '-m \"not slow\"')",
]
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

-r requirements.txt

llama-cpp-python[server]>=0.3.0,<1.0.0
pre-commit>=3.0.4,<4.0
pylint>=2.16.2,<4.0
pylint-pydantic
Expand Down
52 changes: 0 additions & 52 deletions scripts/test_knowledge.py

This file was deleted.

98 changes: 98 additions & 0 deletions tests/functional/conftest.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,112 @@
# Standard
from importlib import resources
import pathlib
import typing

# Third Party
from datasets import Dataset
from llama_cpp.server.app import create_app
from llama_cpp.server.settings import ModelSettings, ServerSettings
from openai import OpenAI
from starlette.testclient import TestClient
import pytest

# First Party
from src.instructlab.sdg.pipeline import Pipeline, PipelineContext


# Absolute path of the directory one level above this file's own directory.
TESTS_PATH = pathlib.Path(__file__).parent.parent.absolute()


@pytest.fixture
def testdata_path() -> typing.Generator[pathlib.Path, None, None]:
    """Yield the ``testdata`` directory located under ``TESTS_PATH``."""
    data_dir = TESTS_PATH.joinpath("testdata")
    yield data_dir


@pytest.fixture
def num_gpu_layers() -> int:
    """Value passed as ``n_gpu_layers`` to the model settings (0 here)."""
    gpu_layers = 0
    return gpu_layers


@pytest.fixture
def openai_client(model, model_repo_id, num_gpu_layers):
    """Yield an OpenAI client wired to an in-process llama.cpp server app.

    The server app is built with ``create_app`` from the given model
    settings, and the OpenAI client sends its requests through a Starlette
    ``TestClient`` wrapping that app instead of a default HTTP client.

    Args:
        model: Model file name forwarded to ``ModelSettings(model=...)``.
        model_repo_id: Repository id forwarded as ``hf_model_repo_id``.
        num_gpu_layers: Value forwarded as ``n_gpu_layers``.

    Yields:
        The configured ``OpenAI`` client. It is closed (together with the
        underlying test client) once the requesting test has finished.
    """
    server_settings = ServerSettings()
    model_settings = [
        ModelSettings(
            model=model,
            hf_model_repo_id=model_repo_id,
            n_gpu_layers=num_gpu_layers,  # just run on the CPU
            verbose=True,
        )
    ]
    app = create_app(
        server_settings=server_settings,
        model_settings=model_settings,
    )

    @app.get("/")
    def read_root():
        return {"message": "Hello from InstructLab! Visit us at https://instructlab.ai"}

    test_client = TestClient(app)
    client = OpenAI(
        api_key="EMPTY",
        base_url="http://localhost:8000/v1",
        http_client=test_client,
    )
    try:
        yield client
    finally:
        # Release the HTTP client deterministically instead of relying on
        # garbage collection at some later point in the test session.
        client.close()


@pytest.fixture
def teacher_model(openai_client):
    """Return the id of the first model listed by ``openai_client``."""
    first_model = openai_client.models.list().data[0]
    return first_model.id


@pytest.fixture
def max_num_tokens() -> int:
    """Value passed as ``max_num_tokens`` to ``PipelineContext`` (256)."""
    token_limit = 256
    return token_limit


@pytest.fixture
def pipeline_context(
    openai_client,
    model_family,
    teacher_model,
    num_instructions_to_generate,
    max_num_tokens,
):
    """Build the ``PipelineContext`` shared by the pipeline fixtures.

    The first four fixtures are forwarded positionally, in the order they
    are requested; ``max_num_tokens`` is forwarded by keyword.
    """
    positional_args = (
        openai_client,
        model_family,
        teacher_model,
        num_instructions_to_generate,
    )
    context = PipelineContext(*positional_args, max_num_tokens=max_num_tokens)
    return context


@pytest.fixture
def knowledge_dataset():
    """Return a one-row ``Dataset`` of seed data about tonsils.

    The row carries three in-context query/response pairs, a task
    description, the source document, its in-context excerpt, a domain and
    a document outline.
    """
    seed_row = {
        "icl_query_1": "what is the location of the tubal tonsils?",
        "icl_response_1": "The location of the tubal tonsils is the roof of the pharynx.",
        "icl_query_2": "How long does the adenoid grow?",
        "task_description": "Teaching about human anatomy, specifically tonsils",
        "icl_response_2": "The adenoid grows until the age of 5, starts to shrink at the age of 7 and becomes small in adulthood.",
        "icl_query_3": "What is the immune systems first line of defense against ingested or inhaled foreign pathogens?",
        "icl_response_3": "The tonsils are the immune systems first line of defense.",
        "document": "The **tonsils** are a set of lymphoid organs facing into the aerodigestive tract, which is known as Waldeyer's tonsillar ring and consists of the adenoid tonsil or pharyngeal tonsil, two tubal tonsils, two palatine tonsils, and the lingual tonsils. These organs play an important role in the immune system. When used unqualified, the term most commonly refers specifically to the palatine tonsils, which are two lymphoid organs situated at either side of the back of the human throat. The palatine tonsils and the adenoid tonsil are organs consisting of lymphoepithelial tissue located near the oropharynx and nasopharynx parts of the throat",
        "icl_document": "The **tonsils** are a set of lymphoid organs facing into the aerodigestive tract, which is known as Waldeyer's tonsillar ring and consists of the adenoid tonsil or pharyngeal tonsil, two tubal tonsils, two palatine tonsils, and the lingual tonsils.",
        "domain": "textbook",
        "document_outline": "Medical description of tonsils",
    }
    rows = [seed_row]
    return Dataset.from_list(rows)


@pytest.fixture
def knowledge_pipeline(pipeline_context, pipelines_package):
    """Load ``knowledge.yaml`` from ``pipelines_package`` as a ``Pipeline``."""
    package_root = resources.files(pipelines_package)
    return Pipeline.from_file(
        pipeline_context,
        package_root.joinpath("knowledge.yaml"),
    )
55 changes: 55 additions & 0 deletions tests/functional/test_full_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Standard
import unittest

# Third Party
import pytest

# First Party
from src.instructlab.sdg.datamixing import _get_question_hack, _get_response_hack
from src.instructlab.sdg.pipeline import FULL_PIPELINES_PACKAGE


@pytest.fixture
def model() -> str:
    """Model file name used by the full pipeline test."""
    # NOTE: alternative quantizations previously noted here:
    #   mistral-7b-instruct-v0.2.Q4_K_M.gguf
    #   mistral-7b-instruct-v0.2.Q3_K_S.gguf
    model_file = "mistral-7b-instruct-v0.2.Q5_K_M.gguf"
    return model_file


@pytest.fixture
def model_family() -> str:
    """Model family name used by the full pipeline test."""
    family = "mixtral"
    return family


@pytest.fixture
def model_repo_id() -> str:
    """Model repository id used by the full pipeline test."""
    repo_id = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
    return repo_id


@pytest.fixture
def num_instructions_to_generate() -> int:
    """Number of instructions to generate in the full pipeline test."""
    instruction_count = 1
    return instruction_count


@pytest.fixture
def pipelines_package():
    """Pipelines package used by the full pipeline test."""
    package = FULL_PIPELINES_PACKAGE
    return package


@pytest.mark.slow
class TestFullPipeline(unittest.TestCase):
    """Slow functional test that runs the knowledge pipeline end to end."""

    @pytest.fixture(autouse=True)
    def _setup_fixtures(self, knowledge_dataset, knowledge_pipeline):
        # unittest.TestCase test methods cannot receive fixtures as
        # arguments, so this autouse fixture stores them on the instance.
        self.knowledge_pipeline = knowledge_pipeline
        self.knowledge_dataset = knowledge_dataset

    def test_knowledge(self):
        """Generated samples are non-empty and each has a question and response."""
        generated = self.knowledge_pipeline.generate(self.knowledge_dataset)
        print(generated)
        assert len(generated) > 0
        for row in generated:
            print(row)
            # Extract both fields before asserting on either of them.
            question, response = _get_question_hack(row), _get_response_hack(row)
            assert len(question) > 0
            assert len(response) > 0
51 changes: 51 additions & 0 deletions tests/functional/test_simple_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Standard
import unittest

# Third Party
import pytest

# First Party
from src.instructlab.sdg.datamixing import _get_question_hack, _get_response_hack
from src.instructlab.sdg.pipeline import SIMPLE_PIPELINES_PACKAGE


@pytest.fixture
def model() -> str:
    """Model file name used by the simple pipeline test."""
    model_file = "merlinite-7b-Q4_K_M.gguf"
    return model_file


@pytest.fixture
def model_family() -> str:
    """Model family name used by the simple pipeline test."""
    family = "merlinite"
    return family


@pytest.fixture
def model_repo_id() -> str:
    """Model repository id used by the simple pipeline test."""
    repo_id = "ibm/merlinite-7b-GGUF"
    return repo_id


@pytest.fixture
def num_instructions_to_generate() -> int:
    """Number of instructions to generate in the simple pipeline test."""
    instruction_count = 2
    return instruction_count


@pytest.fixture
def pipelines_package():
    """Pipelines package used by the simple pipeline test."""
    package = SIMPLE_PIPELINES_PACKAGE
    return package


@pytest.mark.slow
class TestSimplePipeline(unittest.TestCase):
    """Slow functional test that runs the knowledge pipeline end to end."""

    @pytest.fixture(autouse=True)
    def _setup_fixtures(self, knowledge_dataset, knowledge_pipeline):
        # unittest.TestCase test methods cannot receive fixtures as
        # arguments, so this autouse fixture stores them on the instance.
        self.knowledge_pipeline = knowledge_pipeline
        self.knowledge_dataset = knowledge_dataset

    def test_knowledge(self):
        """Generated samples are non-empty and each has a question and response."""
        generated = self.knowledge_pipeline.generate(self.knowledge_dataset)
        assert len(generated) > 0
        for row in generated:
            # Extract both fields before asserting on either of them.
            question, response = _get_question_hack(row), _get_response_hack(row)
            assert len(question) > 0
            assert len(response) > 0
7 changes: 0 additions & 7 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,6 @@ commands =
unit: {envpython} -m pytest {posargs:tests --ignore=tests/functional}
unitcov: {envpython} -W error::UserWarning -m pytest --cov=instructlab.sdg --cov-report term --cov-report=html:coverage-{env_name} --cov-report=xml:coverage-{env_name}.xml --html=durations/{env_name}.html {posargs:tests --ignore=tests/functional -m "not (examples or slow)"}
functional: {envpython} -m pytest {posargs:tests/functional}
allowlist_externals =
functional: ./scripts/functional-tests.sh

[testenv:py3-functional]
setenv =
OPENAI_API_BASE={env:OPENAI_API_BASE:http://localhost:8000/v1}
OPENAI_API_KEY={env:OPENAI_API_KEY:EMPTY}

# format, check, and linting targets don't build and install the project to
# speed up testing.
Expand Down

0 comments on commit 6f377bc

Please sign in to comment.