Skip to content

Commit

Permalink
feat: Extract command from backticks
Browse files Browse the repository at this point in the history
  • Loading branch information
Columpio committed Nov 27, 2024
1 parent aac2f54 commit 31a566e
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 3 deletions.
4 changes: 2 additions & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ python -m venv env && source ./env/bin/activate
Install the necessary dependencies, including development and test dependencies:

```shell
pip install -e ."[dev,test]"
pip install -e ."[dev,test,litellm]"
```

### Start Coding
Expand All @@ -35,4 +35,4 @@ Before creating a pull request, run `scripts/lint.sh` and `scripts/tests.sh` to
### Code Review
After submitting your pull request, be patient and receptive to feedback from reviewers. Address any concerns they raise and collaborate to refine the code. Together, we can enhance the ShellGPT project.

Thank you once again for your contribution! We're excited to have you join us.
Thank you once again for your contribution! We're excited to have you join us.
55 changes: 54 additions & 1 deletion sgpt/handlers/handler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import re
from pathlib import Path
from typing import Any, Callable, Dict, Generator, List, Optional

Expand Down Expand Up @@ -37,6 +38,7 @@ class Handler:

def __init__(self, role: SystemRole, markdown: bool) -> None:
self.role = role
self.is_shell = role.name == DefaultRoles.SHELL.value

api_base_url = cfg.get("API_BASE_URL")
self.base_url = None if api_base_url == "default" else api_base_url
Expand All @@ -45,6 +47,13 @@ def __init__(self, role: SystemRole, markdown: bool) -> None:
self.markdown = "APPLY MARKDOWN" in self.role.role and markdown
self.code_theme, self.color = cfg.get("CODE_THEME"), cfg.get("DEFAULT_COLOR")

self.backticks_start = re.compile(r"(^|[\r\n]+)```\w*[\r\n]+")
end_regex_parts = [r"[\r\n]+", "`", "`", "`", r"([\r\n]+|$)"]
self.backticks_end_prefixes = [
re.compile("".join(end_regex_parts[: i + 1]))
for i in range(len(end_regex_parts))
]

@property
def printer(self) -> Printer:
return (
Expand Down Expand Up @@ -82,6 +91,48 @@ def handle_function_call(
yield f"```text\n{result}\n```\n"
messages.append({"role": "function", "content": result, "name": name})

def _matches_end_at(self, text: str) -> tuple[bool, int]:
end_of_match = 0
for _i, regex in enumerate(self.backticks_end_prefixes):
m = regex.search(text)
if m:
end_of_match = m.end()
else:
return False, end_of_match
return True, m.start()

def _filter_chunks(
self, chunks: Generator[str, None, None]
) -> Generator[str, None, None]:
buffer = ""
inside_backticks = False
end_of_beginning = 0

for chunk in chunks:
buffer += chunk
if not inside_backticks:
m = self.backticks_start.search(buffer)
if not m:
continue
new_end_of_beginning = m.end()
if new_end_of_beginning > end_of_beginning:
end_of_beginning = new_end_of_beginning
continue
inside_backticks = True
buffer = buffer[end_of_beginning:]
if inside_backticks:
matches_end, index = self._matches_end_at(buffer)
if matches_end:
yield buffer[:index]
return
if index == len(buffer):
continue
else:
yield buffer
buffer = ""
if buffer:
yield buffer

@cache
def get_completion(
self,
Expand Down Expand Up @@ -152,7 +203,7 @@ def handle(
functions: Optional[List[Dict[str, str]]] = None,
**kwargs: Any,
) -> str:
disable_stream = cfg.get("DISABLE_STREAMING") == "true"
disable_stream = True # cfg.get("DISABLE_STREAMING") == "true"
messages = self.make_messages(prompt.strip())
generator = self.get_completion(
model=model,
Expand All @@ -163,4 +214,6 @@ def handle(
caching=caching,
**kwargs,
)
if self.role.name == DefaultRoles.SHELL.value:
generator = self._filter_chunks(generator)
return self.printer(generator, not disable_stream)
50 changes: 50 additions & 0 deletions tests/test_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from pathlib import Path
from unittest.mock import patch

import pytest

from sgpt.config import cfg
from sgpt.role import DefaultRoles, SystemRole

Expand All @@ -22,6 +24,54 @@ def test_shell(completion):
assert "[E]xecute, [D]escribe, [A]bort:" in result.stdout


@patch("sgpt.handlers.handler.completion")
@pytest.mark.parametrize(
"prefix,suffix",
[
("", ""),
("some text before\n```powershell\n", "\n```" ""),
("```powershell\n", "\n```\nsome text after" ""),
("some text before\n```powershell\n", "\n```\nsome text after" ""),
(
"some text with ``` before\n```powershell\n",
"\n```\nsome text with ``` after" "",
),
("```powershell\n", "\n```" ""),
("```\n", "\n```" ""),
("```powershell\r\n", "\r\n```" ""),
("```\r\n", "\r\n```" ""),
("```powershell\r", "\r```" ""),
("```\r", "\r```" ""),
],
)
@pytest.mark.parametrize("group_by_size", range(10))
def test_shell_no_backticks(completion, prefix: str, suffix: str, group_by_size: int):
expected_output = "Get-Process | \nWhere-Object { $_.Port -eq 9000 }\r\n | Select-Object Id | Text \r\nwith '```' inside"
produced_output = prefix + expected_output + suffix
if group_by_size == 0:
produced_tokens = list(produced_output)
else:
produced_tokens = [
produced_output[i : i + group_by_size]
for i in range(0, len(produced_output), group_by_size)
]
assert produced_output == "".join(produced_tokens)

role = SystemRole.get(DefaultRoles.SHELL.value)
completion.return_value = mock_comp(produced_tokens)

args = {"prompt": "find pid by port 9000", "--shell": True}
result = runner.invoke(app, cmd_args(**args))

completion.assert_called_once_with(**comp_args(role, args["prompt"]))
index = result.stdout.find(expected_output)
assert index >= 0
rest = result.stdout[index + len(expected_output) :].strip()
assert "`" not in rest
assert result.exit_code == 0
assert "[E]xecute, [D]escribe, [A]bort:" == rest


@patch("sgpt.printer.TextPrinter.live_print")
@patch("sgpt.printer.MarkdownPrinter.live_print")
@patch("sgpt.handlers.handler.completion")
Expand Down

0 comments on commit 31a566e

Please sign in to comment.