From c0c748358f6a0dd8e98327871442e52eb71c2213 Mon Sep 17 00:00:00 2001 From: Florents Tselai Date: Wed, 26 Jun 2024 12:39:49 +0300 Subject: [PATCH] Add support for `prompt` annd `embed` and also handle default models (Closes #2, #4 ) (#7) --- .github/workflows/test.yml | 2 +- .gitignore | 3 +- setup.py | 7 ++- tests/conftest.py | 40 ------------- tests/test_tsellm.py | 116 ++++++++++++++++++++++++++++++++----- tsellm/cli.py | 2 + tsellm/core.py | 23 ++++++++ 7 files changed, 134 insertions(+), 59 deletions(-) delete mode 100644 tests/conftest.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 11c4741..9fe7d6d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -24,5 +24,5 @@ jobs: pip install '.[test]' - name: Run tests run: | - pytest + python -m unittest tests/test_tsellm.py diff --git a/.gitignore b/.gitignore index f3b8a4e..7482cf0 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ venv dist build .idea/ -docs/_build \ No newline at end of file +docs/_build +*.db \ No newline at end of file diff --git a/setup.py b/setup.py index d32bac1..cd9c95f 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,12 @@ def get_long_description(): packages=["tsellm"], install_requires=["llm", "setuptools", "pip"], extras_require={ - "test": ["pytest", "pytest-cov", "black", "ruff", "sqlite_utils", "llm-markov"] + "test": [ + "black", + "sqlite_utils", + "llm-markov", + "llm-embed-hazo", + ] }, python_requires=">=3.11", ) diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index 1dc46dc..0000000 --- a/tests/conftest.py +++ /dev/null @@ -1,40 +0,0 @@ -from sqlite_utils import Database -import pytest - -import pytest -import json -import llm -from llm.plugins import pm -from typing import Optional -import sqlite_utils -from pydantic import Field - - -def pytest_configure(config): - import sys - - sys._called_from_test = True - - -@pytest.fixture -def db_path(tmpdir): - path = str(tmpdir / "test.db") - return path - - -@pytest.fixture -def fresh_db_path(db_path): - return db_path - - -@pytest.fixture -def existing_db_path(fresh_db_path): - db = Database(fresh_db_path) - table = db.create_table("prompts", {"prompt": str, "generated": str}) - - table.insert({"prompt": "hello world!"}) - table.insert({"prompt": "how are you?"}) - table.insert({"prompt": "is this real life?"}) - table.insert({"prompt": "1+1=?"}) - - return fresh_db_path diff --git a/tests/test_tsellm.py b/tests/test_tsellm.py index 79df5f1..cdaa6f2 100644 --- a/tests/test_tsellm.py +++ b/tests/test_tsellm.py @@ -1,24 +1,108 @@ +import llm.cli from sqlite_utils import Database from tsellm.cli import cli +import unittest +from test.support import captured_stdout, captured_stderr, captured_stdin, os_helper +from test.support.os_helper import TESTFN, unlink +from llm import models +import sqlite3 +from llm import cli as llm_cli -def test_cli_prompt_mock(existing_db_path): - db = Database(existing_db_path) +class CommandLineInterface(unittest.TestCase): - assert db.execute("select prompt from prompts").fetchall() == [ - ("hello world!",), - ("how are you?",), - ("is this real life?",), - ("1+1=?",), - ] + def _do_test(self, *args, expect_success=True): + with ( + captured_stdout() as out, + captured_stderr() as err, + self.assertRaises(SystemExit) as cm, + ): + cli(args) + return out.getvalue(), err.getvalue(), cm.exception.code - cli([existing_db_path, "UPDATE prompts SET generated=prompt(prompt, 'markov')"]) + def expect_success(self, *args): + out, err, code = self._do_test(*args) + self.assertEqual(code, 0, "\n".join([f"Unexpected failure: {args=}", out, err])) - for prompt, generated in db.execute( - "select prompt, generated from prompts" - ).fetchall(): - words = generated.strip().split() + # This makes DeprecationWarning and other warnings cause a failure. + # Let's not be that harsh yet. + # See https://github.com/Florents-Tselai/llm/tree/fix-utc-warning-312 + # self.assertEqual(err, "") + return out + + def expect_failure(self, *args): + out, err, code = self._do_test(*args, expect_success=False) + self.assertNotEqual( + code, 0, "\n".join([f"Unexpected failure: {args=}", out, err]) + ) + self.assertEqual(out, "") + return err + + def test_cli_help(self): + out = self.expect_success("-h") + self.assertIn("usage: python -m tsellm", out) + + def test_cli_version(self): + out = self.expect_success("-v") + self.assertIn(sqlite3.sqlite_version, out) + + def test_cli_execute_sql(self): + out = self.expect_success(":memory:", "select 1") + self.assertIn("(1,)", out) + + def test_cli_execute_too_much_sql(self): + stderr = self.expect_failure(":memory:", "select 1; select 2") + err = "ProgrammingError: You can only execute one statement at a time" + self.assertIn(err, stderr) + + def test_cli_execute_incomplete_sql(self): + stderr = self.expect_failure(":memory:", "sel") + self.assertIn("OperationalError (SQLITE_ERROR)", stderr) + + def test_cli_on_disk_db(self): + self.addCleanup(unlink, TESTFN) + out = self.expect_success(TESTFN, "create table t(t)") + self.assertEqual(out, "") + out = self.expect_success(TESTFN, "select count(t) from t") + self.assertIn("(0,)", out) + + +class SQLiteLLMFunction(CommandLineInterface): + + def setUp(self): + super().setUp() + llm_cli.set_default_model("markov") + llm_cli.set_default_embedding_model("hazo") + + def assertMarkovResult(self, prompt, generated): # Every word should be one of the original prompt (see https://github.com/simonw/llm-markov/blob/657ca504bcf9f0bfc1c6ee5fe838cde9a8976381/tests/test_llm_markov.py#L20) - prompt_words = prompt.split() - for word in words: - assert word in prompt_words + for w in prompt.split(" "): + self.assertIn(w, generated) + + def test_prompt_markov(self): + out = self.expect_success(":memory:", "select prompt('hello world', 'markov')") + self.assertMarkovResult("hello world", out) + + def test_prompt_default_markov(self): + self.assertEqual(llm_cli.get_default_model(), "markov") + out = self.expect_success(":memory:", "select prompt('hello world')") + self.assertMarkovResult("hello world", out) + + def test_embed_hazo(self): + out = self.expect_success(":memory:", "select embed('hello world', 'hazo')") + self.assertEqual( + "('[5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]',)\n", + out, + ) + + def test_embed_default_hazo(self): + self.assertEqual(llm_cli.get_default_embedding_model(), "hazo") + out = self.expect_success(":memory:", "select embed('hello world')") + self.assertEqual( + "('[5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]',)\n", + out, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tsellm/cli.py b/tsellm/cli.py index 055d7ba..aa320ee 100644 --- a/tsellm/cli.py +++ b/tsellm/cli.py @@ -125,3 +125,5 @@ def cli(*args): console.interact(banner, exitmsg="") finally: con.close() + + sys.exit(0) diff --git a/tsellm/core.py b/tsellm/core.py index e28be0d..f844207 100644 --- a/tsellm/core.py +++ b/tsellm/core.py @@ -1,6 +1,12 @@ import llm +import json + +from llm import cli as llm_cli TSELLM_CONFIG_SQL = """ +-- tsellm configuration table +-- need to be taken care of accross migrations and versions. + CREATE TABLE IF NOT EXISTS __tsellm ( x text ); @@ -12,7 +18,24 @@ def _prompt_model(prompt, model): return llm.get_model(model).prompt(prompt).text() +def _prompt_model_default(prompt): + return llm.get_model("markov").prompt(prompt).text() + + +def _embed_model(text, model): + return json.dumps(llm.get_embedding_model(model).embed(text)) + + +def _embed_model_default(text): + return json.dumps( + llm.get_embedding_model(llm_cli.get_default_embedding_model()).embed(text) + ) + + def _tsellm_init(con): """Entry-point for tsellm initialization.""" con.execute(TSELLM_CONFIG_SQL) con.create_function("prompt", 2, _prompt_model) + con.create_function("prompt", 1, _prompt_model_default) + con.create_function("embed", 2, _embed_model) + con.create_function("embed", 1, _embed_model_default)