Skip to content

Commit

Permalink
Add support for prompt annd embed and also handle default models (C…
Browse files Browse the repository at this point in the history
…loses #2, #4 ) (#7)
  • Loading branch information
Florents-Tselai authored Jun 26, 2024
1 parent aab27c6 commit c0c7483
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 59 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@ jobs:
pip install '.[test]'
- name: Run tests
run: |
pytest
python -m unittest tests/test_tsellm.py
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ venv
dist
build
.idea/
docs/_build
docs/_build
*.db
7 changes: 6 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@ def get_long_description():
packages=["tsellm"],
install_requires=["llm", "setuptools", "pip"],
extras_require={
"test": ["pytest", "pytest-cov", "black", "ruff", "sqlite_utils", "llm-markov"]
"test": [
"black",
"sqlite_utils",
"llm-markov",
"llm-embed-hazo",
]
},
python_requires=">=3.11",
)
40 changes: 0 additions & 40 deletions tests/conftest.py

This file was deleted.

116 changes: 100 additions & 16 deletions tests/test_tsellm.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,108 @@
import llm.cli
from sqlite_utils import Database
from tsellm.cli import cli
import unittest
from test.support import captured_stdout, captured_stderr, captured_stdin, os_helper
from test.support.os_helper import TESTFN, unlink
from llm import models
import sqlite3
from llm import cli as llm_cli


def test_cli_prompt_mock(existing_db_path):
db = Database(existing_db_path)
class CommandLineInterface(unittest.TestCase):

assert db.execute("select prompt from prompts").fetchall() == [
("hello world!",),
("how are you?",),
("is this real life?",),
("1+1=?",),
]
def _do_test(self, *args, expect_success=True):
with (
captured_stdout() as out,
captured_stderr() as err,
self.assertRaises(SystemExit) as cm,
):
cli(args)
return out.getvalue(), err.getvalue(), cm.exception.code

cli([existing_db_path, "UPDATE prompts SET generated=prompt(prompt, 'markov')"])
def expect_success(self, *args):
out, err, code = self._do_test(*args)
self.assertEqual(code, 0, "\n".join([f"Unexpected failure: {args=}", out, err]))

for prompt, generated in db.execute(
"select prompt, generated from prompts"
).fetchall():
words = generated.strip().split()
# This makes DeprecationWarning and other warnings cause a failure.
# Let's not be that harsh yet.
# See https://github.com/Florents-Tselai/llm/tree/fix-utc-warning-312
# self.assertEqual(err, "")
return out

def expect_failure(self, *args):
out, err, code = self._do_test(*args, expect_success=False)
self.assertNotEqual(
code, 0, "\n".join([f"Unexpected failure: {args=}", out, err])
)
self.assertEqual(out, "")
return err

def test_cli_help(self):
out = self.expect_success("-h")
self.assertIn("usage: python -m tsellm", out)

def test_cli_version(self):
out = self.expect_success("-v")
self.assertIn(sqlite3.sqlite_version, out)

def test_cli_execute_sql(self):
out = self.expect_success(":memory:", "select 1")
self.assertIn("(1,)", out)

def test_cli_execute_too_much_sql(self):
stderr = self.expect_failure(":memory:", "select 1; select 2")
err = "ProgrammingError: You can only execute one statement at a time"
self.assertIn(err, stderr)

def test_cli_execute_incomplete_sql(self):
stderr = self.expect_failure(":memory:", "sel")
self.assertIn("OperationalError (SQLITE_ERROR)", stderr)

def test_cli_on_disk_db(self):
self.addCleanup(unlink, TESTFN)
out = self.expect_success(TESTFN, "create table t(t)")
self.assertEqual(out, "")
out = self.expect_success(TESTFN, "select count(t) from t")
self.assertIn("(0,)", out)


class SQLiteLLMFunction(CommandLineInterface):

def setUp(self):
super().setUp()
llm_cli.set_default_model("markov")
llm_cli.set_default_embedding_model("hazo")

def assertMarkovResult(self, prompt, generated):
# Every word should be one of the original prompt (see https://github.com/simonw/llm-markov/blob/657ca504bcf9f0bfc1c6ee5fe838cde9a8976381/tests/test_llm_markov.py#L20)
prompt_words = prompt.split()
for word in words:
assert word in prompt_words
for w in prompt.split(" "):
self.assertIn(w, generated)

def test_prompt_markov(self):
out = self.expect_success(":memory:", "select prompt('hello world', 'markov')")
self.assertMarkovResult("hello world", out)

def test_prompt_default_markov(self):
self.assertEqual(llm_cli.get_default_model(), "markov")
out = self.expect_success(":memory:", "select prompt('hello world')")
self.assertMarkovResult("hello world", out)

def test_embed_hazo(self):
out = self.expect_success(":memory:", "select embed('hello world', 'hazo')")
self.assertEqual(
"('[5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]',)\n",
out,
)

def test_embed_default_hazo(self):
self.assertEqual(llm_cli.get_default_embedding_model(), "hazo")
out = self.expect_success(":memory:", "select embed('hello world')")
self.assertEqual(
"('[5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]',)\n",
out,
)


if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions tsellm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,5 @@ def cli(*args):
console.interact(banner, exitmsg="")
finally:
con.close()

sys.exit(0)
23 changes: 23 additions & 0 deletions tsellm/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import llm
import json

from llm import cli as llm_cli

TSELLM_CONFIG_SQL = """
-- tsellm configuration table
-- need to be taken care of accross migrations and versions.
CREATE TABLE IF NOT EXISTS __tsellm (
x text
);
Expand All @@ -12,7 +18,24 @@ def _prompt_model(prompt, model):
return llm.get_model(model).prompt(prompt).text()


def _prompt_model_default(prompt):
return llm.get_model("markov").prompt(prompt).text()


def _embed_model(text, model):
return json.dumps(llm.get_embedding_model(model).embed(text))


def _embed_model_default(text):
return json.dumps(
llm.get_embedding_model(llm_cli.get_default_embedding_model()).embed(text)
)


def _tsellm_init(con):
"""Entry-point for tsellm initialization."""
con.execute(TSELLM_CONFIG_SQL)
con.create_function("prompt", 2, _prompt_model)
con.create_function("prompt", 1, _prompt_model_default)
con.create_function("embed", 2, _embed_model)
con.create_function("embed", 1, _embed_model_default)

0 comments on commit c0c7483

Please sign in to comment.