Major refactoring and optimisation (#462)
* Major refactoring and optimisation.
* Removed redundant code.
* Removed caching for --chat conversations.
* Fixing mypy typing.
* Printer class to handle LLM output formatting.
* Spinner and better handling when DISABLE_STREAMING is True.
TheR1D authored Jan 31, 2024
1 parent c48926a commit ad6d297
Showing 11 changed files with 163 additions and 128 deletions.
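
The Printer class named in the commit message lives in sgpt/printer.py, whose diff is not rendered on this page. As orientation only, here is a minimal sketch of the interface the handler.py changes below appear to rely on: the `(generator, live)` call signature comes from `self.printer(generator, not disable_stream)` in the diff, while the `live_print`/`static_print` method names and their bodies are assumptions, not the actual implementation.

```python
from abc import ABC, abstractmethod
from typing import Generator


class Printer(ABC):
    """Consume a completion generator and return the full response text."""

    def __call__(self, generator: Generator[str, None, None], live: bool) -> str:
        # Handlers call printer(generator, not disable_stream): stream chunk by
        # chunk while generating, or buffer and print once when streaming is off.
        return self.live_print(generator) if live else self.static_print(generator)

    @abstractmethod
    def live_print(self, generator: Generator[str, None, None]) -> str:
        ...  # e.g. echo coloured chunks (TextPrinter) or re-render Markdown (MarkdownPrinter)

    @abstractmethod
    def static_print(self, generator: Generator[str, None, None]) -> str:
        ...  # e.g. show a spinner, join all chunks, then print the result once
```

MarkdownPrinter and TextPrinter, imported in handler.py below, would subclass this and replace the _handle_with_markdown/_handle_with_plain_text methods that this commit removes.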
4 changes: 2 additions & 2 deletions .github/workflows/lint_test.yml
@@ -30,8 +30,8 @@ jobs:
run: isort sgpt tests scripts --check-only
- name: ruff
run: ruff sgpt tests scripts
# - name: mypy
# run: mypy sgpt --exclude function.py --exclude handler.py --exclude llm_functions
- name: mypy
run: mypy sgpt --exclude llm_functions
- name: tests
run: |
export OPENAI_API_KEY=test_api_key
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -82,7 +82,7 @@ skip = "__init__.py"

[tool.mypy]
strict = true
exclude = ["function.py", "handler.py", "llm_functions"]
exclude = ["llm_functions"]

[tool.ruff]
select = [
3 changes: 0 additions & 3 deletions sgpt/__main__.py

This file was deleted.

8 changes: 3 additions & 5 deletions sgpt/app.py
@@ -204,28 +204,26 @@ def main(
if repl:
# Will be in infinite loop here until user exits with Ctrl+C.
ReplHandler(repl, role_class).handle(
prompt,
init_prompt=prompt,
model=model,
temperature=temperature,
top_p=top_p,
chat_id=repl,
caching=cache,
functions=function_schemas,
)

if chat:
full_completion = ChatHandler(chat, role_class).handle(
prompt,
prompt=prompt,
model=model,
temperature=temperature,
top_p=top_p,
chat_id=chat,
caching=cache,
functions=function_schemas,
)
else:
full_completion = DefaultHandler(role_class).handle(
prompt,
prompt=prompt,
model=model,
temperature=temperature,
top_p=top_p,
12 changes: 5 additions & 7 deletions sgpt/cache.py
@@ -28,19 +28,17 @@ def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]:
"""

def wrapper(*args: Any, **kwargs: Any) -> Generator[str, None, None]:
# Exclude self instance from hashing.
cache_key = md5(json.dumps((args[1:], kwargs)).encode("utf-8")).hexdigest()
cache_file = self.cache_path / cache_key
# TODO: Fix caching for chat, should hash last user message, (not entire history).
if kwargs.pop("caching", True) and cache_file.exists():
yield cache_file.read_text()
key = md5(json.dumps((args[1:], kwargs)).encode("utf-8")).hexdigest()
file = self.cache_path / key
if kwargs.pop("caching") and file.exists():
yield file.read_text()
return
result = ""
for i in func(*args, **kwargs):
result += i
yield i
if "@FunctionCall" not in result:
cache_file.write_text(result)
file.write_text(result)
self._delete_oldest_files(self.length) # type: ignore

return wrapper
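
A side effect of the change above: the wrapper now calls `kwargs.pop("caching")` with no default, so every call that goes through the decorator has to pass `caching` explicitly. A small usage sketch under that assumption; the decorated generator is a simplified stand-in rather than the real Handler.get_completion, and the Cache construction mirrors handler.py below.

```python
from pathlib import Path

from sgpt.cache import Cache
from sgpt.config import cfg

# Same construction as in handler.py below: size and location come from config.
cache = Cache(int(cfg.get("CACHE_LENGTH")), Path(cfg.get("CACHE_PATH")))


@cache
def get_completion(**kwargs):
    # Stand-in for the real streamed LLM completion generator.
    yield "Hello "
    yield "world"


messages = [{"role": "user", "content": "Say hello"}]
# Opting in populates the cache on the first call and replays it on identical
# repeats; passing caching=False always runs the generator again.
first = "".join(get_completion(messages=messages, caching=True))
fresh = "".join(get_completion(messages=messages, caching=False))
```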
14 changes: 7 additions & 7 deletions sgpt/function.py
@@ -2,7 +2,7 @@
import sys
from abc import ABCMeta
from pathlib import Path
from typing import Any, Callable
from typing import Any, Callable, Dict, List

from .config import cfg

@@ -16,23 +16,23 @@ def __init__(self, path: str):

@property
def name(self) -> str:
return self._name
return self._name # type: ignore

@property
def openai_schema(self) -> dict[str, Any]:
return self._openai_schema
return self._openai_schema # type: ignore

@property
def execute(self) -> Callable[..., str]:
return self._function
return self._function # type: ignore

@classmethod
def _read(cls, path: str) -> Any:
module_name = path.replace("/", ".").rstrip(".py")
spec = importlib.util.spec_from_file_location(module_name, path)
module = importlib.util.module_from_spec(spec)
module = importlib.util.module_from_spec(spec) # type: ignore
sys.modules[module_name] = module
spec.loader.exec_module(module)
spec.loader.exec_module(module) # type: ignore

if not isinstance(module.Function, ABCMeta):
raise TypeError(
@@ -58,5 +58,5 @@ def get_function(name: str) -> Callable[..., Any]:
raise ValueError(f"Function {name} not found")


def get_openai_schemas() -> [dict[str, Any]]:
def get_openai_schemas() -> List[Dict[str, Any]]:
return [function.openai_schema for function in functions]
23 changes: 12 additions & 11 deletions sgpt/handlers/chat_handler.py
@@ -42,19 +42,20 @@ def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]:

def wrapper(*args: Any, **kwargs: Any) -> Generator[str, None, None]:
chat_id = kwargs.pop("chat_id", None)
messages = kwargs["messages"]
if not kwargs.get("messages"):
return
if not chat_id:
yield from func(*args, **kwargs)
yield from func(*args, **kwargs, caching=False)
return
old_messages = self._read(chat_id)
for message in messages:
old_messages.append(message)
kwargs["messages"] = old_messages
previous_messages = self._read(chat_id)
for message in kwargs["messages"]:
previous_messages.append(message)
kwargs["messages"] = previous_messages
response_text = ""
for word in func(*args, **kwargs):
response_text += word
yield word
old_messages.append({"role": "assistant", "content": response_text})
previous_messages.append({"role": "assistant", "content": response_text})
self._write(kwargs["messages"], chat_id)

return wrapper
@@ -160,8 +161,8 @@ def make_messages(self, prompt: str) -> List[Dict[str, str]]:
return messages

@chat_session
def get_completion(
self,
**kwargs: Any,
) -> Generator[str, None, None]:
def get_completion(self, **kwargs: Any) -> Generator[str, None, None]:
yield from super().get_completion(**kwargs)

def handle(self, **kwargs: Any) -> str: # type: ignore[override]
return super().handle(**kwargs, chat_id=self.chat_id)
142 changes: 61 additions & 81 deletions sgpt/handlers/handler.py
@@ -1,80 +1,47 @@
import json
from pathlib import Path
from typing import Any, Dict, Generator, List
from typing import Any, Dict, Generator, List, Optional

import typer
from openai import OpenAI
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown

from ..cache import Cache
from ..config import cfg
from ..function import get_function
from ..printer import MarkdownPrinter, Printer, TextPrinter
from ..role import DefaultRoles, SystemRole

cache = Cache(int(cfg.get("CACHE_LENGTH")), Path(cfg.get("CACHE_PATH")))


class Handler:
cache = Cache(int(cfg.get("CACHE_LENGTH")), Path(cfg.get("CACHE_PATH")))

def __init__(self, role: SystemRole) -> None:
self.client = OpenAI(
base_url=cfg.get("OPENAI_BASE_URL"),
api_key=cfg.get("OPENAI_API_KEY"),
timeout=int(cfg.get("REQUEST_TIMEOUT")),
)
self.role = role
self.disable_stream = cfg.get("DISABLE_STREAMING") == "true"
self.show_functions_output = cfg.get("SHOW_FUNCTIONS_OUTPUT") == "true"
self.color = cfg.get("DEFAULT_COLOR")
self.theme_name = cfg.get("CODE_THEME")

def _handle_with_markdown(self, prompt: str, **kwargs: Any) -> str:
messages = self.make_messages(prompt.strip())
full_completion = ""
with Live(
Markdown(markup="", code_theme=self.theme_name),
console=Console(),
) as live:
if self.disable_stream:
live.update(
Markdown(markup="Loading...\r", code_theme=self.theme_name),
refresh=True,
)
for word in self.get_completion(messages=messages, **kwargs):
full_completion += word
live.update(
Markdown(markup=full_completion, code_theme=self.theme_name),
refresh=not self.disable_stream,
)
return full_completion

def _handle_with_plain_text(self, prompt: str, **kwargs: Any) -> str:
messages = self.make_messages(prompt.strip())
full_completion = ""
if self.disable_stream:
typer.echo("Loading...\r", nl=False)
for word in self.get_completion(messages=messages, **kwargs):
typer.secho(word, fg=self.color, bold=True, nl=False)
full_completion += word
# Overwrite "loading..."
typer.echo("\033[K" if not self.disable_stream else "")
return full_completion
@property
def printer(self) -> Printer:
use_markdown = "APPLY MARKDOWN" in self.role.role
code_theme, color = cfg.get("CODE_THEME"), cfg.get("DEFAULT_COLOR")
return MarkdownPrinter(code_theme) if use_markdown else TextPrinter(color)

def make_messages(self, prompt: str) -> List[Dict[str, str]]:
raise NotImplementedError

def handle_function_call(
self,
messages: List[dict[str, str]],
messages: List[dict[str, Any]],
name: str,
arguments: str,
) -> Generator[str, None, None]:
messages.append(
{
"role": "assistant",
"content": "",
"function_call": {"name": name, "arguments": arguments}, # type: ignore
"function_call": {"name": name, "arguments": arguments},
}
)

@@ -84,57 +84,70 @@ def handle_function_call(
dict_args = json.loads(arguments)
joined_args = ", ".join(f'{k}="{v}"' for k, v in dict_args.items())
yield f"> @FunctionCall `{name}({joined_args})` \n\n"

result = get_function(name)(**dict_args)
if self.show_functions_output:
if cfg.get("SHOW_FUNCTIONS_OUTPUT") == "true":
yield f"```text\n{result}\n```\n"
messages.append({"role": "function", "content": result, "name": name})

# TODO: Fix MyPy typing errors. This modules is excluded from MyPy checks.
@cache
def get_completion(self, **kwargs: Any) -> Generator[str, None, None]:
func_call = {"name": None, "arguments": ""}

def get_completion(
self,
model: str,
temperature: float,
top_p: float,
messages: List[Dict[str, Any]],
functions: Optional[List[Dict[str, str]]],
) -> Generator[str, None, None]:
name = arguments = ""
is_shell_role = self.role.name == DefaultRoles.SHELL.value
is_code_role = self.role.name == DefaultRoles.CODE.value
is_dsc_shell_role = self.role.name == DefaultRoles.DESCRIBE_SHELL.value
if is_shell_role or is_code_role or is_dsc_shell_role:
kwargs["functions"] = None

if self.disable_stream:
completion = self.client.chat.completions.create(**kwargs)
message = completion.choices[0].message
if completion.choices[0].finish_reason == "function_call":
name, arguments = (
message.function_call.name,
message.function_call.arguments,
)
yield from self.handle_function_call(
kwargs["messages"], name, arguments
)
yield from self.get_completion(**kwargs, caching=False)
yield message.content or ""
return

for chunk in self.client.chat.completions.create(**kwargs, stream=True):
delta = chunk.choices[0].delta
functions = None

for chunk in self.client.chat.completions.create(
model=model,
temperature=temperature,
top_p=top_p,
messages=messages, # type: ignore
functions=functions, # type: ignore
stream=True,
):
delta = chunk.choices[0].delta # type: ignore
if delta.function_call:
if delta.function_call.name:
func_call["name"] = delta.function_call.name
name = delta.function_call.name
if delta.function_call.arguments:
func_call["arguments"] += delta.function_call.arguments
if chunk.choices[0].finish_reason == "function_call":
name, arguments = func_call["name"], func_call["arguments"]
yield from self.handle_function_call(
kwargs["messages"], name, arguments
arguments += delta.function_call.arguments
if chunk.choices[0].finish_reason == "function_call": # type: ignore
yield from self.handle_function_call(messages, name, arguments)
yield from self.get_completion(
model, temperature, top_p, messages, functions, caching=False
)
yield from self.get_completion(**kwargs, caching=False)
return

yield delta.content or ""

def handle(self, prompt: str, **kwargs: Any) -> str:
default = DefaultRoles.DEFAULT.value
shell_descriptor = DefaultRoles.DESCRIBE_SHELL.value
if self.role.name == default or self.role.name == shell_descriptor:
return self._handle_with_markdown(prompt, **kwargs)
return self._handle_with_plain_text(prompt, **kwargs)
def handle(
self,
prompt: str,
model: str,
temperature: float,
top_p: float,
caching: bool,
functions: Optional[List[Dict[str, str]]] = None,
**kwargs: Any,
) -> str:
disable_stream = cfg.get("DISABLE_STREAMING") == "true"
messages = self.make_messages(prompt.strip())
generator = self.get_completion(
model=model,
temperature=temperature,
top_p=top_p,
messages=messages,
functions=functions,
caching=caching,
**kwargs,
)
return self.printer(generator, not disable_stream)
9 changes: 2 additions & 7 deletions sgpt/handlers/repl_handler.py
@@ -49,7 +49,6 @@ def handle(self, init_prompt: str, **kwargs: Any) -> None:  # type: ignore
if prompt == '"""':
prompt = self._get_multiline_input()
if prompt == "exit()":
# This is also useful during tests.
raise typer.Exit()
if init_prompt:
prompt = f"{init_prompt}\n\n\n{prompt}"
@@ -61,11 +60,7 @@ def handle(self, init_prompt: str, **kwargs: Any) -> None:  # type: ignore
rich_print(Rule(style="bold magenta"))
elif self.role.name == DefaultRoles.SHELL.value and prompt == "d":
DefaultHandler(DefaultRoles.DESCRIBE_SHELL.get_role()).handle(
full_completion,
model=kwargs.get("model"),
temperature=kwargs.get("temperature"),
top_p=kwargs.get("top_p"),
caching=kwargs.get("caching"),
prompt=full_completion, **kwargs
)
else:
full_completion = super().handle(prompt, **kwargs)
full_completion = super().handle(prompt=prompt, **kwargs)