-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
159 lines (131 loc) · 4.37 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import argparse
import inspect
import shutil
import sys
from dataclasses import dataclass

from prompt_toolkit import PromptSession
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.history import InMemoryHistory

from haverscript import Response, connect, echo
def bell():
    """Ring the terminal bell by writing the BEL character to stdout.

    No newline is added, and the output is flushed immediately so the bell
    sounds right away.
    """
    bel_char = "\a"
    print(bel_char, flush=True, end="")
@dataclass(frozen=True)
class Commands:
    """Slash commands available in the shell.

    Every non-dunder method of this class is offered as a ``/name`` command.
    Each method receives the current session and returns the session the
    shell should continue from. A method that declares an extra parameter
    also receives the text typed after the command name.
    """

    def help(self, session):
        """Print the available commands and return ``session`` unchanged."""
        print("/help")
        print("/bye")
        print("/context")
        print("/undo")
        return session

    def bye(self, _):
        """Exit the shell."""
        # Use sys.exit rather than the site-provided ``exit`` helper. That
        # helper is intended for the interactive interpreter and is missing
        # when Python runs without the site module (e.g. ``python -S``).
        sys.exit()

    def context(self, session):
        """Print ``session.render()`` and return ``session`` unchanged."""
        print(session.render())
        return session

    def undo(self, session):
        """Step back to the parent of the current response.

        If ``session`` is not a ``Response``, there is nothing to undo: ring
        the bell and return ``session`` unchanged.
        """
        if isinstance(session, Response):
            print("Undoing...")
            return session.parent
        bell()
        return session
def main():
models = connect().list()
# Set up argument parsing
parser = argparse.ArgumentParser(description="Haverscript Shell")
parser.add_argument(
"--model",
type=str,
choices=models,
required=True,
help="The model to use for processing.",
)
parser.add_argument(
"--context",
type=str,
help="The context to use for processing.",
)
parser.add_argument(
"--cache",
type=str,
help="database to use as a cache.",
)
parser.add_argument(
"--temperature",
type=float,
help="temperature of replies.",
)
parser.add_argument(
"--num_predict",
type=int,
help="cap on number of tokens in reply.",
)
parser.add_argument("--num_ctx", type=int, help="context size")
args = parser.parse_args()
terminal_size = shutil.get_terminal_size(fallback=(80, 24))
terminal_width = terminal_size.columns
llm = connect(args.model) | echo(width=terminal_width - 2, prompt=False)
if args.temperature is not None:
llm = llm.options(temperature=args.temperature)
if args.num_predict is not None:
llm = llm.options(num_predict=args.num_predict)
if args.num_ctx is not None:
llm = llm.options(num_ctx=args.num_ctx)
if args.cache:
llm = llm.cache(args.cache)
if args.context:
with open(args.context) as f:
session = llm.load(markdown=f.read())
else:
session = llm
command_list = [
"/" + method
for method in dir(Commands)
if callable(getattr(Commands, method)) and not method.startswith("__")
]
print(f"connected to {args.model} (/help for help) ")
while True:
try:
if args.cache:
previous = session.children()
if previous:
history = InMemoryHistory()
for prompt in previous.prompt:
history.append_string(prompt)
prompt_session = PromptSession(history=history)
else:
prompt_session = PromptSession()
else:
prompt_session = PromptSession()
completer = WordCompleter(command_list, sentence=True)
try:
text = prompt_session.prompt("> ", completer=completer)
except EOFError:
exit()
if text.startswith("/"):
cmd = text.split(maxsplit=1)[0]
if cmd in command_list:
after = text[len(cmd) :].strip()
function = getattr(Commands(), cmd[1:], None)
if len(inspect.signature(function).parameters) > 1:
if not after:
print(f"{cmd} expects an argument")
continue
session = function(session, after)
else:
if after:
print(f"{cmd} does not take any argument")
continue
session = function(session)
continue
print(f"Unknown command {text}.")
continue
else:
print()
try:
session = session.chat(text)
except KeyboardInterrupt:
print("^C\n")
print()
except KeyboardInterrupt:
print("Use Ctrl + d or /bye to exit.")
if __name__ == "__main__":
    # Start the interactive shell only when run as a script, not on import.
    main()