query.py
import chromadb
from customEmbeddings import MyCustomEmbedding
from sentence_transformers import CrossEncoder
import heapq

# Persistent Chroma client backed by ./chroma_db/, using the project's custom
# embedding function for the "myRag" collection.
chroma_client = chromadb.PersistentClient(path="./chroma_db/")
chroma_collection = chroma_client.get_collection(name="myRag", embedding_function=MyCustomEmbedding())

def query(query_text: str):
    """Similarity-search the collection for the top 5 matches to a single query string."""
    return chroma_collection.query(query_texts=query_text, n_results=5,
                                   include=["documents", "embeddings", "distances"])

def return_all_embeddings():
    return chroma_collection.get(include=["embeddings"])["embeddings"]

# Chroma nests results per query; these helpers unwrap the first (only) query's lists.
def return_results_embeddings(results):
    return results["embeddings"][0]

def return_results_documents(results):
    return results["documents"][0]

def return_results_distances(results):
    return results["distances"][0]

def document_distance(results: dict):
    return zip(results["documents"][0], results["distances"][0])

def multi_query(queries: list):
    """Run every query and merge the results, de-duplicating by document so that
    documents, embeddings, and distances stay aligned."""
    found_documents = []
    found_embeddings = []
    found_distances = []
    for query_string in queries:
        results = query(query_string)
        embeddings = return_results_embeddings(results)
        documents = return_results_documents(results)
        distances = return_results_distances(results)
        for document, embedding, distance in zip(documents, embeddings, distances):
            # Skip documents already collected from an earlier query, keeping each
            # embedding and distance paired with its document.
            if document in found_documents:
                continue
            found_documents.append(document)
            found_embeddings.append(embedding)
            found_distances.append(distance)
    return found_documents, found_embeddings, found_distances

def generate_ranked_results(queries: list):
    """Re-rank the merged retrieval results with a cross-encoder and return the
    best-scoring documents."""
    found_documents, found_embeddings, found_distances = multi_query(queries)
    cross_encoder = CrossEncoder("cross-encoder/stsb-roberta-large")
    docscore = []
    for query_text in queries:
        for document, embedding, distance in zip(found_documents, found_embeddings, found_distances):
            # Score each (query, document) pair with the cross-encoder.
            score = cross_encoder.predict([query_text, document])
            docscore.append((score,
                             {
                                 "document": document,
                                 "query_text": query_text,
                                 "embeddings": embedding,
                                 "distance": distance,
                             }))
    # Keep the 10 highest-scoring pairs, then drop anything scoring 0.3 or below.
    top_list = heapq.nlargest(10, docscore, key=lambda x: x[0])
    top_scores = [score for score, document in top_list if score > 0.3]
    top_documents = [document for score, document in top_list if score > 0.3]
    return top_scores, top_documents
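
# Minimal usage sketch (an illustrative assumption, not part of the original module):
# it presumes the "myRag" collection has already been built and populated elsewhere
# with the same MyCustomEmbedding function, and that the example query strings below
# are placeholders for real questions.
if __name__ == "__main__":
    example_queries = [
        "What is retrieval-augmented generation?",
        "How are documents embedded before indexing?",
    ]
    scores, ranked = generate_ranked_results(example_queries)
    for score, entry in zip(scores, ranked):
        print(f"{score:.3f}  {entry['document'][:80]}")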