1
+ import llm .cli
1
2
from sqlite_utils import Database
2
3
from tsellm .cli import cli
3
4
import unittest
4
5
from test .support import captured_stdout , captured_stderr , captured_stdin , os_helper
5
6
from test .support .os_helper import TESTFN , unlink
6
-
7
+ from llm import models
7
8
import sqlite3
9
+ from llm import cli as llm_cli
8
10
9
11
10
12
class CommandLineInterface (unittest .TestCase ):
@@ -25,7 +27,7 @@ def expect_success(self, *args):
25
27
# This makes DeprecationWarning and other warnings cause a failure.
26
28
# Let's not be that harsh yet.
27
29
# See https://github.com/Florents-Tselai/llm/tree/fix-utc-warning-312
28
- #self.assertEqual(err, "")
30
+ # self.assertEqual(err, "")
29
31
return out
30
32
31
33
def expect_failure (self , * args ):
@@ -65,17 +67,26 @@ def test_cli_on_disk_db(self):
65
67
self .assertIn ("(0,)" , out )
66
68
67
69
68
- class PromptFunction (CommandLineInterface ):
69
- """Testing the SELECT prompt(...) function"""
70
+ class SQLiteLLMFunction (CommandLineInterface ):
70
71
71
- def test_prompt_markov (self ):
72
- out = self .expect_success (":memory:" , "select prompt('hello world', 'markov')" )
72
+ def setUp (self ):
73
+ super ().setUp ()
74
+ llm_cli .set_default_model ("markov" )
75
+ llm_cli .set_default_embedding_model ("hazo" )
76
+
77
+ def assertMarkovResult (self , prompt , generated ):
73
78
# Every word should be one of the original prompt (see https://github.com/simonw/llm-markov/blob/657ca504bcf9f0bfc1c6ee5fe838cde9a8976381/tests/test_llm_markov.py#L20)
74
- self . assertIn ( "hello" , out )
75
- self .assertIn ("world" , out )
79
+ for w in prompt . split ( " " ):
80
+ self .assertIn (w , generated )
76
81
82
+ def test_prompt_markov (self ):
83
+ out = self .expect_success (":memory:" , "select prompt('hello world', 'markov')" )
84
+ self .assertMarkovResult ("hello world" , out )
77
85
78
- class EmbedFunction (CommandLineInterface ):
86
+ def test_prompt_default_markov (self ):
87
+ self .assertEquals (llm_cli .get_default_model (), "markov" )
88
+ out = self .expect_success (":memory:" , "select prompt('hello world')" )
89
+ self .assertMarkovResult ("hello world" , out )
79
90
80
91
def test_embed_hazo (self ):
81
92
out = self .expect_success (":memory:" , "select embed('hello world', 'hazo')" )
@@ -84,6 +95,14 @@ def test_embed_hazo(self):
84
95
out ,
85
96
)
86
97
98
+ def test_embed_default_hazo (self ):
99
+ self .assertEquals (llm_cli .get_default_embedding_model (), "hazo" )
100
+ out = self .expect_success (":memory:" , "select embed('hello world')" )
101
+ self .assertEqual (
102
+ "('[5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]',)\n " ,
103
+ out ,
104
+ )
105
+
87
106
#
88
107
# def test_cli_prompt(existing_db_path):
89
108
# db = Database(existing_db_path)
0 commit comments