Skip to content

Commit 63d2b19

Browse files
Handle default models for embed and prompt
1 parent af080a7 commit 63d2b19

File tree

2 files changed

+36
-9
lines changed

2 files changed

+36
-9
lines changed

tests/test_tsellm.py

+28-9
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
import llm.cli
12
from sqlite_utils import Database
23
from tsellm.cli import cli
34
import unittest
45
from test.support import captured_stdout, captured_stderr, captured_stdin, os_helper
56
from test.support.os_helper import TESTFN, unlink
6-
7+
from llm import models
78
import sqlite3
9+
from llm import cli as llm_cli
810

911

1012
class CommandLineInterface(unittest.TestCase):
@@ -25,7 +27,7 @@ def expect_success(self, *args):
2527
# This makes DeprecationWarning and other warnings cause a failure.
2628
# Let's not be that harsh yet.
2729
# See https://github.com/Florents-Tselai/llm/tree/fix-utc-warning-312
28-
#self.assertEqual(err, "")
30+
# self.assertEqual(err, "")
2931
return out
3032

3133
def expect_failure(self, *args):
@@ -65,17 +67,26 @@ def test_cli_on_disk_db(self):
6567
self.assertIn("(0,)", out)
6668

6769

68-
class PromptFunction(CommandLineInterface):
69-
"""Testing the SELECT prompt(...) function"""
70+
class SQLiteLLMFunction(CommandLineInterface):
7071

71-
def test_prompt_markov(self):
72-
out = self.expect_success(":memory:", "select prompt('hello world', 'markov')")
72+
def setUp(self):
73+
super().setUp()
74+
llm_cli.set_default_model("markov")
75+
llm_cli.set_default_embedding_model("hazo")
76+
77+
def assertMarkovResult(self, prompt, generated):
7378
# Every word should be one of the original prompt (see https://github.com/simonw/llm-markov/blob/657ca504bcf9f0bfc1c6ee5fe838cde9a8976381/tests/test_llm_markov.py#L20)
74-
self.assertIn("hello", out)
75-
self.assertIn("world", out)
79+
for w in prompt.split(" "):
80+
self.assertIn(w, generated)
7681

82+
def test_prompt_markov(self):
83+
out = self.expect_success(":memory:", "select prompt('hello world', 'markov')")
84+
self.assertMarkovResult("hello world", out)
7785

78-
class EmbedFunction(CommandLineInterface):
86+
def test_prompt_default_markov(self):
87+
self.assertEquals(llm_cli.get_default_model(), "markov")
88+
out = self.expect_success(":memory:", "select prompt('hello world')")
89+
self.assertMarkovResult("hello world", out)
7990

8091
def test_embed_hazo(self):
8192
out = self.expect_success(":memory:", "select embed('hello world', 'hazo')")
@@ -84,6 +95,14 @@ def test_embed_hazo(self):
8495
out,
8596
)
8697

98+
def test_embed_default_hazo(self):
99+
self.assertEquals(llm_cli.get_default_embedding_model(), "hazo")
100+
out = self.expect_success(":memory:", "select embed('hello world')")
101+
self.assertEqual(
102+
"('[5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]',)\n",
103+
out,
104+
)
105+
87106
#
88107
# def test_cli_prompt(existing_db_path):
89108
# db = Database(existing_db_path)

tsellm/core.py

+8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import llm
22
import json
33

4+
from llm import cli as llm_cli
5+
46
TSELLM_CONFIG_SQL = """
57
-- tsellm configuration table
68
-- need to be taken care of accross migrations and versions.
@@ -15,13 +17,19 @@
1517
def _prompt_model(prompt, model):
1618
return llm.get_model(model).prompt(prompt).text()
1719

20+
def _prompt_model_default(prompt):
21+
return llm.get_model("markov").prompt(prompt).text()
1822

1923
def _embed_model(text, model):
2024
return json.dumps(llm.get_embedding_model(model).embed(text))
2125

26+
def _embed_model_default(text):
27+
return json.dumps(llm.get_embedding_model(llm_cli.get_default_embedding_model()).embed(text))
2228

2329
def _tsellm_init(con):
2430
"""Entry-point for tsellm initialization."""
2531
con.execute(TSELLM_CONFIG_SQL)
2632
con.create_function("prompt", 2, _prompt_model)
33+
con.create_function("prompt", 1, _prompt_model_default)
2734
con.create_function("embed", 2, _embed_model)
35+
con.create_function("embed", 1, _embed_model_default)

0 commit comments

Comments
 (0)