Skip to content

Commit

Permalink
Ollama integration 🦙 (#463)
Browse files Browse the repository at this point in the history
  • Loading branch information
TheR1D authored Feb 9, 2024
1 parent ad6d297 commit 1cb61de
Show file tree
Hide file tree
Showing 9 changed files with 90 additions and 106 deletions.
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ https://github.com/TheR1D/shell_gpt/assets/16740832/9197283c-db6a-4b46-bfea-3eb7
```shell
pip install shell-gpt
```
By default, ShellGPT uses OpenAI's API and GPT-4 model. You'll need an API key, you can generate one [here](https://beta.openai.com/account/api-keys). You will be prompted for your key which will then be stored in `~/.config/shell_gpt/.sgptrc`. OpenAI API is not free of charge, please refer to the [OpenAI pricing](https://openai.com/pricing) for more information.

You'll need an OpenAI API key, you can generate one [here](https://beta.openai.com/account/api-keys).
You will be prompted for your key which will then be stored in `~/.config/shell_gpt/.sgptrc`.
> [!TIP]
> Alternatively, you can use locally hosted open source models which are available for free. To use local models, you will need to run your own LLM backend server such as [Ollama](https://github.com/ollama/ollama). To set up ShellGPT with Ollama, please follow this comprehensive [guide](https://github.com/TheR1D/shell_gpt/wiki/Ollama).
>
> **❗️Note that ShellGPT is not optimized for local models and may not work as expected.**
## Usage
**ShellGPT** is designed to quickly analyse and retrieve information. It's useful for straightforward requests ranging from technical configurations to general knowledge.
Expand All @@ -24,7 +27,7 @@ git diff | sgpt "Generate git commit message, for my changes"
# -> Added main feature details into README.md
```

You can analyze logs from various sources by passing them using stdin, along with a prompt. This enables you to quickly identify errors and get suggestions for possible solutions:
You can analyze logs from various sources by passing them using stdin, along with a prompt. For instance, we can use it to quickly analyze logs, identify errors and get suggestions for possible solutions:
```shell
docker logs -n 20 my_app | sgpt "check logs, find errors, provide possible solutions"
```
Expand All @@ -40,7 +43,7 @@ You can also use all kind of redirection operators to pass input:
sgpt "summarise" < document.txt
# -> The document discusses the impact...
sgpt << EOF
What is the best way to lear Golang.
What is the best way to lear Golang?
Provide simple hello world example.
EOF
# -> The best way to learn Golang...
Expand Down Expand Up @@ -444,9 +447,6 @@ Possible options for `CODE_THEME`: https://pygments.org/styles/
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯
```

## LocalAI
By default, ShellGPT leverages OpenAI's large language models. However, it also provides the flexibility to use locally hosted models, which can be a cost-effective alternative. To use local models, you will need to run your own API server. You can accomplish this by using [LocalAI](https://github.com/go-skynet/LocalAI), a self-hosted, OpenAI-compatible API. Setting up LocalAI allows you to run language models on your own hardware, potentially without the need for an internet connection, depending on your usage. To set up your LocalAI, please follow this comprehensive [guide](https://github.com/TheR1D/shell_gpt/wiki/LocalAI). Remember that the performance of your local models may depend on the specifications of your hardware and the specific language model you choose to deploy.

## Docker
Run the container using the `OPENAI_API_KEY` environment variable, and a docker volume to store cache:
```shell
Expand Down
23 changes: 11 additions & 12 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ build-backend = "hatchling.build"

[project]
name = "shell_gpt"
description = "A command-line productivity tool powered by OpenAI GPT models, will help you accomplish your tasks faster and more efficiently."
keywords = ["shell", "gpt", "openai", "cli", "productivity", "cheet-sheet"]
description = "A command-line productivity tool powered by large language models, will help you accomplish your tasks faster and more efficiently."
keywords = ["shell", "gpt", "openai", "ollama", "cli", "productivity", "cheet-sheet"]
readme = "README.md"
license = "MIT"
requires-python = ">=3.6"
Expand All @@ -28,24 +28,15 @@ classifiers = [
"Programming Language :: Python :: 3.11",
]
dependencies = [
"requests >= 2.28.2, < 3.0.0",
"litellm >= 1.20.1, < 2.0.0",
"typer >= 0.7.0, < 1.0.0",
"click >= 7.1.1, < 9.0.0",
"rich >= 13.1.0, < 14.0.0",
"distro >= 1.8.0, < 2.0.0",
"openai >= 1.6.1, < 2.0.0",
"instructor >= 0.4.5, < 1.0.0",
'pyreadline3 >= 3.4.1, < 4.0.0; sys_platform == "win32"',
]

[project.scripts]
sgpt = "sgpt:cli"

[project.urls]
homepage = "https://github.com/ther1d/shell_gpt"
repository = "https://github.com/ther1d/shell_gpt"
documentation = "https://github.com/TheR1D/shell_gpt/blob/main/README.md"

[project.optional-dependencies]
test = [
"pytest >= 7.2.2, < 8.0.0",
Expand All @@ -61,6 +52,14 @@ dev = [
"pre-commit >= 3.1.1, < 4.0.0",
]

[project.scripts]
sgpt = "sgpt:cli"

[project.urls]
homepage = "https://github.com/ther1d/shell_gpt"
repository = "https://github.com/ther1d/shell_gpt"
documentation = "https://github.com/TheR1D/shell_gpt/blob/main/README.md"

[tool.hatch.version]
path = "sgpt/__version__.py"

Expand Down
2 changes: 1 addition & 1 deletion sgpt/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.2.0"
__version__ = "1.3.0"
33 changes: 16 additions & 17 deletions sgpt/handlers/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,21 @@
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional

from openai import OpenAI
import litellm # type: ignore

from ..cache import Cache
from ..config import cfg
from ..function import get_function
from ..printer import MarkdownPrinter, Printer, TextPrinter
from ..role import DefaultRoles, SystemRole

litellm.suppress_debug_info = True


class Handler:
cache = Cache(int(cfg.get("CACHE_LENGTH")), Path(cfg.get("CACHE_PATH")))

def __init__(self, role: SystemRole) -> None:
self.client = OpenAI(
base_url=cfg.get("OPENAI_BASE_URL"),
api_key=cfg.get("OPENAI_API_KEY"),
timeout=int(cfg.get("REQUEST_TIMEOUT")),
)
self.role = role

@property
Expand Down Expand Up @@ -73,28 +70,30 @@ def get_completion(
if is_shell_role or is_code_role or is_dsc_shell_role:
functions = None

for chunk in self.client.chat.completions.create(
for chunk in litellm.completion(
model=model,
temperature=temperature,
top_p=top_p,
messages=messages, # type: ignore
functions=functions, # type: ignore
messages=messages,
functions=functions,
stream=True,
api_key=cfg.get("OPENAI_API_KEY"),
):
delta = chunk.choices[0].delta # type: ignore
if delta.function_call:
if delta.function_call.name:
name = delta.function_call.name
if delta.function_call.arguments:
arguments += delta.function_call.arguments
if chunk.choices[0].finish_reason == "function_call": # type: ignore
delta = chunk.choices[0].delta
function_call = delta.get("function_call")
if function_call:
if function_call.name:
name = function_call.name
if function_call.arguments:
arguments += function_call.arguments
if chunk.choices[0].finish_reason == "function_call":
yield from self.handle_function_call(messages, name, arguments)
yield from self.get_completion(
model, temperature, top_p, messages, functions, caching=False
)
return

yield delta.content or ""
yield delta.get("content") or ""

def handle(
self,
Expand Down
26 changes: 13 additions & 13 deletions tests/test_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
from sgpt.config import cfg
from sgpt.role import DefaultRoles, SystemRole

from .utils import app, cmd_args, comp_args, comp_chunks, runner
from .utils import app, cmd_args, comp_args, mock_comp, runner

role = SystemRole.get(DefaultRoles.CODE.value)


@patch("openai.resources.chat.Completions.create")
@patch("litellm.completion")
def test_code_generation(mock):
mock.return_value = comp_chunks("print('Hello World')")
mock.return_value = mock_comp("print('Hello World')")

args = {"prompt": "hello world python", "--code": True}
result = runner.invoke(app, cmd_args(**args))
Expand All @@ -21,9 +21,9 @@ def test_code_generation(mock):
assert "print('Hello World')" in result.stdout


@patch("openai.resources.chat.Completions.create")
@patch("litellm.completion")
def test_code_generation_stdin(completion):
completion.return_value = comp_chunks("# Hello\nprint('Hello')")
completion.return_value = mock_comp("# Hello\nprint('Hello')")

args = {"prompt": "make comments for code", "--code": True}
stdin = "print('Hello')"
Expand All @@ -36,11 +36,11 @@ def test_code_generation_stdin(completion):
assert "print('Hello')" in result.stdout


@patch("openai.resources.chat.Completions.create")
@patch("litellm.completion")
def test_code_chat(completion):
completion.side_effect = [
comp_chunks("print('hello')"),
comp_chunks("print('hello')\nprint('world')"),
mock_comp("print('hello')"),
mock_comp("print('hello')\nprint('world')"),
]
chat_name = "_test"
chat_path = Path(cfg.get("CHAT_CACHE_PATH")) / chat_name
Expand Down Expand Up @@ -77,11 +77,11 @@ def test_code_chat(completion):
# TODO: Code chat can be recalled without --code option.


@patch("openai.resources.chat.Completions.create")
@patch("litellm.completion")
def test_code_repl(completion):
completion.side_effect = [
comp_chunks("print('hello')"),
comp_chunks("print('hello')\nprint('world')"),
mock_comp("print('hello')"),
mock_comp("print('hello')\nprint('world')"),
]
chat_name = "_test"
chat_path = Path(cfg.get("CHAT_CACHE_PATH")) / chat_name
Expand Down Expand Up @@ -109,7 +109,7 @@ def test_code_repl(completion):
assert "print('world')" in result.stdout


@patch("openai.resources.chat.Completions.create")
@patch("litellm.completion")
def test_code_and_shell(completion):
args = {"--code": True, "--shell": True}
result = runner.invoke(app, cmd_args(**args))
Expand All @@ -119,7 +119,7 @@ def test_code_and_shell(completion):
assert "Error" in result.stdout


@patch("openai.resources.chat.Completions.create")
@patch("litellm.completion")
def test_code_and_describe_shell(completion):
args = {"--code": True, "--describe-shell": True}
result = runner.invoke(app, cmd_args(**args))
Expand Down
28 changes: 14 additions & 14 deletions tests/test_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
from sgpt.__version__ import __version__
from sgpt.role import DefaultRoles, SystemRole

from .utils import app, cmd_args, comp_args, comp_chunks, runner
from .utils import app, cmd_args, comp_args, mock_comp, runner

role = SystemRole.get(DefaultRoles.DEFAULT.value)
cfg = config.cfg


@patch("openai.resources.chat.Completions.create")
@patch("litellm.completion")
def test_default(completion):
completion.return_value = comp_chunks("Prague")
completion.return_value = mock_comp("Prague")

args = {"prompt": "capital of the Czech Republic?"}
result = runner.invoke(app, cmd_args(**args))
Expand All @@ -26,9 +26,9 @@ def test_default(completion):
assert "Prague" in result.stdout


@patch("openai.resources.chat.Completions.create")
@patch("litellm.completion")
def test_default_stdin(completion):
completion.return_value = comp_chunks("Prague")
completion.return_value = mock_comp("Prague")

stdin = "capital of the Czech Republic?"
result = runner.invoke(app, cmd_args(), input=stdin)
Expand All @@ -38,9 +38,9 @@ def test_default_stdin(completion):
assert "Prague" in result.stdout


@patch("openai.resources.chat.Completions.create")
@patch("litellm.completion")
def test_default_chat(completion):
completion.side_effect = [comp_chunks("ok"), comp_chunks("4")]
completion.side_effect = [mock_comp("ok"), mock_comp("4")]
chat_name = "_test"
chat_path = Path(cfg.get("CHAT_CACHE_PATH")) / chat_name
chat_path.unlink(missing_ok=True)
Expand Down Expand Up @@ -90,9 +90,9 @@ def test_default_chat(completion):
chat_path.unlink()


@patch("openai.resources.chat.Completions.create")
@patch("litellm.completion")
def test_default_repl(completion):
completion.side_effect = [comp_chunks("ok"), comp_chunks("8")]
completion.side_effect = [mock_comp("ok"), mock_comp("8")]
chat_name = "_test"
chat_path = Path(cfg.get("CHAT_CACHE_PATH")) / chat_name
chat_path.unlink(missing_ok=True)
Expand All @@ -119,9 +119,9 @@ def test_default_repl(completion):
assert "8" in result.stdout


@patch("openai.resources.chat.Completions.create")
@patch("litellm.completion")
def test_default_repl_stdin(completion):
completion.side_effect = [comp_chunks("ok init"), comp_chunks("ok another")]
completion.side_effect = [mock_comp("ok init"), mock_comp("ok another")]
chat_name = "_test"
chat_path = Path(cfg.get("CHAT_CACHE_PATH")) / chat_name
chat_path.unlink(missing_ok=True)
Expand Down Expand Up @@ -153,9 +153,9 @@ def test_default_repl_stdin(completion):
assert "ok another" in result.stdout


@patch("openai.resources.chat.Completions.create")
@patch("litellm.completion")
def test_llm_options(completion):
completion.return_value = comp_chunks("Berlin")
completion.return_value = mock_comp("Berlin")

args = {
"prompt": "capital of the Germany?",
Expand All @@ -179,7 +179,7 @@ def test_llm_options(completion):
assert "Berlin" in result.stdout


@patch("openai.resources.chat.Completions.create")
@patch("litellm.completion")
def test_version(completion):
args = {"--version": True}
result = runner.invoke(app, cmd_args(**args))
Expand Down
7 changes: 4 additions & 3 deletions tests/test_roles.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from sgpt.config import cfg
from sgpt.role import SystemRole

from .utils import app, cmd_args, comp_args, comp_chunks, runner
from .utils import app, cmd_args, comp_args, mock_comp, runner


@patch("openai.resources.chat.Completions.create")
@patch("litellm.completion")
def test_role(completion):
completion.return_value = comp_chunks('{"foo": "bar"}')
completion.return_value = mock_comp('{"foo": "bar"}')
path = Path(cfg.get("ROLE_STORAGE_PATH")) / "json_gen_test.json"
path.unlink(missing_ok=True)
args = {"--create-role": "json_gen_test"}
Expand Down Expand Up @@ -44,6 +44,7 @@ def test_role(completion):
assert "foo" in generated_json

# Test with stdin prompt.
completion.return_value = mock_comp('{"foo": "bar"}')
args = {"--role": "json_gen_test"}
stdin = "generate foo, bar"
result = runner.invoke(app, cmd_args(**args), input=stdin)
Expand Down
Loading

0 comments on commit 1cb61de

Please sign in to comment.