Skip to content

Commit

Permalink
speed improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
capjamesg committed Nov 19, 2024
1 parent 887587f commit 7eb8505
Showing 1 changed file with 24 additions and 23 deletions.
47 changes: 24 additions & 23 deletions jamesql/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,6 @@ def search(
results_limit = query.get("limit", 10)

metadata = {}
contexts = {}

if not query.get("query"):
return {
Expand Down Expand Up @@ -874,9 +873,14 @@ def search(
if doc_id in self.global_index
]

results = orjson.loads(orjson.dumps(results))

for r in results:
r["_score"] = 0
if r["uuid"] in metadata.get("scores", {}):
r["_score"] = metadata["scores"][r["uuid"]]
if r["uuid"] in metadata.get("highlights", {}):
contexts[r["uuid"]] = metadata["highlights"][r["uuid"]]
r["_context"] = metadata["highlights"][r["uuid"]]

end_time = time.time()

Expand All @@ -885,12 +889,17 @@ def search(

results_sort_by = query["sort_by"]

scores = {}
if query.get("sort_order") == "asc":
results = sorted(results, key=itemgetter(results_sort_by), reverse=False)
else:
results = sorted(results, key=itemgetter(results_sort_by), reverse=True)

if self.enable_experimental_bm25_ranker:
# TODO: Make sure this code can process boosts.

for doc in results:
doc["_score"] = 0

for term in term_queries:
tf = self.tf.get(doc["uuid"], {}).get(term, 0)
idf = self.idf.get(term, 0)
Expand All @@ -907,7 +916,7 @@ def search(
)
term_score *= idf

scores[doc["uuid"]] = scores.get(doc["uuid"], 0) + term_score
doc["_score"] += term_score

for field in fields:
gsi_index = self.gsis.get(field)["gsi"]
Expand All @@ -926,11 +935,14 @@ def search(
first_word_pos &= positions

if first_word_pos and field != "title_lower":
scores[doc["uuid"]] += (
doc["_score"] += (
len(first_word_pos) + 1
) # * len(set(word_pos[term_queries[0]]))
elif first_word_pos and field == "title_lower":
scores[doc["uuid"]] *= 2 + len(first_word_pos)
doc["_score"] *= 2 + len(first_word_pos)

# sort by doc score
results = sorted(results, key=lambda x: x.get("_score", 0), reverse=True)

if query.get("query_score"):
tree = parse_script_score(query["query_score"])
Expand All @@ -943,29 +955,18 @@ def search(

document["_score"] = transformer.transform(tree)

results = sorted(results, key=lambda x: x.get("_score", 0), reverse=True)

if query.get("skip"):
results = results[int(query["skip"]) :]

total_results = len(results)

# get max 5 scores from "scores" dict
scores = dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))
if results_limit:
results = results[:results_limit]

results = results[:results_limit]

if results_sort_by and results_sort_by != "_score":
results = sorted(
results,
key=lambda x: x.get(results_sort_by, 0),
reverse=query.get("sort_order", True),
)

# zip results with contexts
results = [
{**result, "context": contexts.get(result["uuid"]), "_score": scores.get(result["uuid"], 0)}
for result in results
if result
]
if results_limit == 0:
results = []

result = {
"documents": results,
Expand Down

0 comments on commit 7eb8505

Please sign in to comment.