This repository has been archived by the owner on Aug 27, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor vector store and encoders, decoupling encoder from vector store (#41)
- Loading branch information
Showing
13 changed files
with
169 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
This subdirectory is meant for random benchmarking experiments on memas internal components. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import datasets | ||
from datetime import datetime | ||
from memas.interface.encoder import TextEncoder | ||
from memas.encoder import openai_ada_encoder, universal_sentence_encoder | ||
from memas.text_parsing import text_parsers | ||
|
||
|
||
def prep_dataset(num_docs: int = 11, chunk_size: int = 1024,
                 batch_sizes: tuple[int, ...] = (5, 10, 20, 50, 100)):
    """Load wikipedia articles and split them into sentence chunks for benchmarking.

    Args:
        num_docs: number of articles to split (default 11, matching the
            original hard-coded `i > 10` cutoff).
        chunk_size: maximum chunk length passed to `text_parsers.split_doc`.
        batch_sizes: batch sizes to pre-group the sentences into.

    Returns:
        A tuple `(test_sentences, batch_sentences)` where `test_sentences` is a
        flat list of sentence chunks and `batch_sentences` maps each batch size
        to a list of equally sized sentence batches.
    """
    from itertools import islice

    wikipedia = datasets.load_dataset("wikipedia", "20220301.en")
    test_sentences = []
    i = 0

    start = datetime.now()
    # Only split the first num_docs articles; the full dump is far too large.
    for row in islice(wikipedia["train"], num_docs):
        test_sentences.extend(text_parsers.split_doc(row["text"], chunk_size))
        i += 1
    end = datetime.now()
    print(f"Splitting {i} documents into {len(test_sentences)} sentences took {(end - start).total_seconds()}s")

    batch_sentences = {}
    for batch_size in batch_sizes:
        batched_list = [test_sentences[i:i+batch_size] for i in range(0, len(test_sentences), batch_size)]
        # Drop the trailing batch only when it is actually short, so every
        # remaining batch has exactly batch_size items. (Previously the last
        # batch was popped unconditionally, discarding a full batch whenever
        # len(test_sentences) divided evenly by batch_size.)
        if batched_list and len(batched_list[-1]) < batch_size:
            batched_list.pop()
        batch_sentences[batch_size] = batched_list

    return test_sentences, batch_sentences
|
||
|
||
def benchmark_single(test_sentences: list[str], encoder: TextEncoder):
    """Encode each sentence one at a time and return the total wall-clock seconds.

    Failures are printed (with the 1-based sentence index) rather than raised,
    so a single bad input does not abort the whole benchmark run.
    """
    begin = datetime.now()
    for idx, sentence in enumerate(test_sentences, start=1):
        try:
            encoder.embed(sentence)
        except Exception as err:
            print(err)
            print(f"{idx}!", sentence)

    finish = datetime.now()
    return (finish - begin).total_seconds()
|
||
|
||
def benchmark_batch(batched_list: list[list[str]], encoder: TextEncoder):
    """Encode the sentences batch by batch and return the total wall-clock seconds.

    Mirrors `benchmark_single`: exceptions are printed (with the 1-based batch
    index) instead of propagated, so the timing loop always runs to completion.
    """
    begin = datetime.now()
    for idx, batch in enumerate(batched_list, start=1):
        try:
            encoder.embed_multiple(batch)
        except Exception as err:
            print(err)
            print(f"{idx}!", batch)
    finish = datetime.now()
    return (finish - begin).total_seconds()
|
||
|
||
def compare_encoders(encoders: dict[str, TextEncoder]):
    """Benchmark every encoder in single-sentence and batched modes.

    Returns a dict keyed by "single" plus each batch size; each entry maps the
    encoder name to a `(total_seconds, seconds_per_item)` tuple.
    """
    test_sentences, batch_sentences = prep_dataset()
    print(len(test_sentences))
    sentence_count = len(test_sentences)

    results = {"single": {}}
    for name, encoder in encoders.items():
        total = benchmark_single(test_sentences, encoder)
        print(f"[{name}] Single: total {total}s, avg {total/len(test_sentences)}s per item")
        results["single"][name] = (total, total / sentence_count)

    for batch_size, batched_list in batch_sentences.items():
        results[batch_size] = {}
        batch_count = len(batched_list)
        for name, encoder in encoders.items():
            elapsed = benchmark_batch(batched_list, encoder)
            results[batch_size][name] = (elapsed, elapsed / batch_count)
            print(f"[{name}] {batch_size} batch: total {elapsed}s, avg {elapsed/len(batched_list)}s per item")
    return results
|
||
|
||
if __name__ == "__main__":
    # The Universal Sentence Encoder requires an explicit init() before use;
    # the ADA encoder only needs an API key (placeholder here — fill in to run).
    use_encoder = universal_sentence_encoder.USETextEncoder()
    use_encoder.init()

    results = compare_encoders({
        "ada": openai_ada_encoder.ADATextEncoder("PLACE_HOLDER"),
        "use": use_encoder,
    })
    print(results)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# This is needed to resolve a bug with apache beam + datasets | ||
# Read https://github.com/huggingface/datasets/issues/5613 for more details | ||
multiprocess==0.70.11 | ||
dill==0.3.6 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
ipykernel | ||
datasets==2.13.1 | ||
apache_beam==2.49.0 | ||
openai | ||
|
||
memas-sdk | ||
memas-client |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
#!/bin/bash
# Install the benchmark dependencies for this subdirectory.
pip install -r requirements.txt
# TODO: remove this after beam/datasets package upgrade
# --no-deps skips dependency resolution for the pinned multiprocess/dill
# versions; see requirements-no-deps.txt for the beam + datasets bug link.
pip install --no-deps -r requirements-no-deps.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import numpy as np | ||
import openai | ||
from memas.interface.encoder import TextEncoder | ||
|
||
|
||
ADA_MODEL="text-embedding-ada-002" | ||
|
||
|
||
class ADATextEncoder(TextEncoder):
    """TextEncoder backed by OpenAI's text-embedding-ada-002 API.

    Embeddings are fetched remotely, so no local model state is loaded;
    vectors come back as 1536-dimensional numpy arrays.
    """

    def __init__(self, api_key) -> None:
        super().__init__(ENCODER_NAME="ADA", VECTOR_DIMENSION=1536)
        # NOTE: the openai client keys off a module-level attribute, so this
        # sets the key globally for the process.
        openai.api_key = api_key

    def init(self):
        # Nothing to warm up — the embedding work happens server-side.
        pass

    def embed(self, text: str) -> np.ndarray:
        response = openai.Embedding.create(input=[text], model=ADA_MODEL)
        return np.array(response['data'][0]['embedding'])

    def embed_multiple(self, text_list: list[str]) -> list[np.ndarray]:
        response = openai.Embedding.create(input=text_list, model=ADA_MODEL)
        return [np.array(item['embedding']) for item in response['data']]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,5 +5,6 @@ pymilvus==2.2.8 | |
elasticsearch==8.8.0 | ||
scylla-driver==3.26.2 | ||
nltk | ||
openai | ||
gunicorn[eventlet] | ||
futurist |