Skip to content

Commit

Permalink
recursive json_embed(json, model) for SQLite and DuckDB (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
Florents-Tselai authored Aug 15, 2024
1 parent 4c835b1 commit 2a36547
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 6 deletions.
44 changes: 41 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@
[![codecov](https://codecov.io/gh/Florents-Tselai/tsellm/branch/main/graph/badge.svg)](https://codecov.io/gh/Florents-Tselai/tsellm)
[![License](https://img.shields.io/badge/BSD%20license-blue.svg)](https://github.com/Florents-Tselai/tsellm/blob/main/LICENSE)




**tsellm** is the easiest way to access LLMs from SQLite or DuckDB.

```shell
Expand Down Expand Up @@ -44,12 +41,53 @@ so you can use any of its plugins:
```shell
llm install llm-sentence-transformers
llm sentence-transformers register all-MiniLM-L12-v2
llm install llm-embed-hazo # dummy embedding model for demonstration purposes
```

```sql
tsellm prompts.sqlite3 "select embed(p, 'sentence-transformers/all-MiniLM-L12-v2')"
```

### Embedding `JSON` Recursively

If you have `JSON` columns, you can embed these object recursively.
That is, an embedding vector of floats will replace each text occurrence in the object.

```bash
cat <<EOF | tee >(sqlite3 prompts.sqlite3) | duckdb prompts.duckdb
CREATE TABLE people(d JSON);
INSERT INTO people (d) VALUES
('{"name": "John Doe", "age": 30, "hobbies": ["reading", "biking"]}'),
('{"name": "Jane Smith", "age": 25, "hobbies": ["painting", "traveling"]}')
EOF
```

#### SQLite

```sql
tsellm prompts.sqlite3 "select json_embed(d, 'hazo') from people"
```

*Output*

```
('{"name": [4.0, 3.0,..., 0.0], "age": 30, "hobbies": [[7.0, 0.0,..., 0.0], [6.0, 0.0, ..., 0.0]]}',)
('{"name": [4.0, 5.0, ,..., 0.0], "age": 25, "hobbies": [[8.0, 0.0,..., 0.0], [9.0, 0.0,..., 0.0]]}',)
```

#### DuckDB

```sql
tsellm prompts.duckdb "select json_embed(d, 'hazo') from people"
```

*Output*

```
('{"name": [4.0, 3.0,..., 0.0], "age": 30, "hobbies": [[7.0, 0.0,..., 0.0], [6.0, 0.0, ..., 0.0]]}',)
('{"name": [4.0, 5.0, ,..., 0.0], "age": 25, "hobbies": [[8.0, 0.0,..., 0.0], [9.0, 0.0,..., 0.0]]}',)
```

### Embeddings for binary (`BLOB`) columns

```shell
Expand Down
64 changes: 63 additions & 1 deletion tests/test_tsellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import unittest
from pathlib import Path
from test.support import captured_stdout, captured_stderr, captured_stdin
from test.support.os_helper import TESTFN, unlink

import duckdb
import llm.cli
Expand Down Expand Up @@ -175,6 +174,15 @@ def test_interact_valid_multiline_sql(self):

class InMemorySQLiteTest(TsellmConsoleTest):
path_args = None
alice_json = """{
\"name\": \"Alice\",
\"details\": {
\"age\": 30,
\"hobbies\": [\"reading\", \"cycling\"],
\"location\": \"Wonderland\"
},
\"greeting\": \"Hello, World!\"
}"""

def setUp(self):
super().setUp()
Expand Down Expand Up @@ -225,6 +233,33 @@ def test_embed_hazo_binary(self):
self.assertTrue(llm.get_embedding_model("hazo").supports_binary)
self.expect_success(*self.path_args, "select embed(randomblob(16), 'hazo')")

def test_embed_json_recursive(self):
out = self.expect_success(
*self.path_args,
f"select json_extract('{self.alice_json}', '$.name')",
)
self.assertEqual(
"('Alice',)\n",
out,
)

out = self.expect_success(
*self.path_args,
f"select json_embed('{self.alice_json}', 'hazo')",
)
self.assertEqual(
(
'(\'{"name": [5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
'0.0, 0.0, 0.0, 0.0], "details": {"age": 30, "hobbies": [[7.0, 0.0, 0.0, 0.0, '
"0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [7.0, 0.0, 0.0, "
"0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], "
'"location": [10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
'0.0, 0.0, 0.0, 0.0]}, "greeting": [6.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
"0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}',)\n"
),
out,
)

def test_embed_default_hazo(self):
self.assertEqual(llm_cli.get_default_embedding_model(), "hazo")
out = self.expect_success(*self.path_args, "select embed('hello world')")
Expand Down Expand Up @@ -290,6 +325,33 @@ def test_embed_hazo_binary(self):
# See https://github.com/Florents-Tselai/tsellm/issues/25
pass

def test_embed_json_recursive(self):
out = self.expect_success(
*self.path_args,
f"select '{self.alice_json}'::json -> 'name'",
)
self.assertEqual(
"('\"Alice\"',)\n",
out,
)

out = self.expect_success(
*self.path_args,
f"select json_embed('{self.alice_json}'::json, 'hazo')",
)
self.assertEqual(
(
'(\'{"name": [5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
'0.0, 0.0, 0.0, 0.0], "details": {"age": 30, "hobbies": [[7.0, 0.0, 0.0, 0.0, '
"0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [7.0, 0.0, 0.0, "
"0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], "
'"location": [10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
'0.0, 0.0, 0.0, 0.0]}, "greeting": [6.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
"0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}',)\n"
),
out,
)


class DiskDuckDBTest(InMemoryDuckDBTest):
db_fp = None
Expand Down
2 changes: 1 addition & 1 deletion tsellm/__version__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
__title__ = "tsellm"
__description__ = "Use LLMs in SQLite and DuckDB"
__version__ = "0.1.0a10"
__version__ = "0.1.0a12"
5 changes: 4 additions & 1 deletion tsellm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
_prompt_model,
_prompt_model_default,
_embed_model,
_json_embed_model,
_embed_model_default,
)

Expand Down Expand Up @@ -79,6 +80,7 @@ class TsellmConsole(InteractiveConsole, ABC):
("prompt", 1, _prompt_model_default, False),
("embed", 2, _embed_model, False),
("embed", 1, _embed_model_default, False),
("json_embed", 2, _json_embed_model, False),
]

error_class = None
Expand All @@ -87,7 +89,7 @@ class TsellmConsole(InteractiveConsole, ABC):

@staticmethod
def create_console(
fp: Union[str, Path], in_memory_type: DatabaseType = DatabaseType.UNKNOWN
fp: Union[str, Path], in_memory_type: DatabaseType = DatabaseType.UNKNOWN
):
sniffer = DBSniffer(fp)
if sniffer.is_in_memory:
Expand Down Expand Up @@ -274,6 +276,7 @@ def is_valid_db(self) -> bool:
_functions = [
("prompt", 2, _prompt_model, False),
("embed", 2, _embed_model, False),
("json_embed", 2, _json_embed_model, False),
]

def connect(self):
Expand Down
21 changes: 21 additions & 0 deletions tsellm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,21 @@
"""


def json_recurse_apply(json_obj, f):
if isinstance(json_obj, dict):
# Recursively apply the function to dictionary values
return {k: json_recurse_apply(v, f) for k, v in json_obj.items()}
elif isinstance(json_obj, list):
# Recursively apply the function to list elements
return [json_recurse_apply(item, f) for item in json_obj]
elif isinstance(json_obj, str):
# Apply the function to string values, which returns a list of floats
return f(json_obj)
else:
# Return the object as is if it's neither a dictionary, list, or string
return json_obj


def _prompt_model(prompt: str, model: str) -> str:
return llm.get_model(model).prompt(prompt).text()

Expand All @@ -26,6 +41,12 @@ def _embed_model(text: str, model: str) -> str:
return json.dumps(llm.get_embedding_model(model).embed(text))


def _json_embed_model(js: str, model: str) -> str:
return json.dumps(
json_recurse_apply(json.loads(js), lambda v: json.loads(_embed_model(v, model)))
)


def _embed_model_default(text: str) -> str:
return json.dumps(
llm.get_embedding_model(llm_cli.get_default_embedding_model()).embed(text)
Expand Down

0 comments on commit 2a36547

Please sign in to comment.