fix bm25 bugs
capjamesg committed Nov 18, 2024
1 parent 82b6537 commit 804e9fd
Showing 3 changed files with 95 additions and 78 deletions.
166 changes: 91 additions & 75 deletions jamesql/index.py
@@ -3,24 +3,24 @@
import math
import os
import string
import threading
import time
import uuid
from collections import defaultdict
from copy import deepcopy
from enum import Enum
from functools import lru_cache
from operator import itemgetter
from typing import Dict, List
from sortedcontainers import SortedDict
import threading

from operator import itemgetter
import orjson
import pybmoore
import pygtrie
from BTrees.OOBTree import OOBTree
from lark import Lark
from nltk.corpus import stopwords
from nltk import download
from nltk.corpus import stopwords
from sortedcontainers import SortedDict

from jamesql.rewriter import grammar as rewriter_grammar
from jamesql.rewriter import string_query_to_jamesql
@@ -95,7 +95,7 @@ def get_trigrams(line):
class JameSQL:
SELF_METHODS = {"close_to": "_close_to"}

def __init__(self, match_limit_for_large_result_pages = 1000) -> None:
def __init__(self, match_limit_for_large_result_pages=1000) -> None:
self.global_index = {}
self.uuids_to_position_in_global_index = {}
self.gsis = {}
@@ -230,11 +230,13 @@ def _create_reverse_index(
# Compute term frequency (TF) for each word in the document
total_words_in_document = len(words)
for word, count in word_count.items():
self.tf[document["uuid"]][word] = count  # / total_words_in_document

# Compute inverse document frequency (IDF) for each word in the corpus
for word, doc_count in document_frequencies.items():
self.idf[word] =math.log((total_documents - doc_count + 0.5) / (doc_count + 0.5) + 1)
self.idf[word] = math.log(
(total_documents - doc_count + 0.5) / (doc_count + 0.5) + 1
)

# Compute TF-IDF for each document
for document in documents:
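
For reference, the reformatted IDF line above is the standard BM25 inverse document frequency, where total_documents is the corpus size and doc_count is the number of documents containing the word. A minimal standalone sketch of the same computation (the three-document corpus is hypothetical, for illustration only):

```python
import math
from collections import defaultdict

# Hypothetical corpus, for illustration only.
documents = ["all too well", "all well too", "too well"]
total_documents = len(documents)

document_frequencies = defaultdict(int)
for doc in documents:
    for word in set(doc.split(" ")):
        document_frequencies[word] += 1

# Same formula as the diff: log((N - n + 0.5) / (n + 0.5) + 1).
# The +1 keeps the IDF positive even for words present in most documents.
idf = {
    word: math.log((total_documents - doc_count + 0.5) / (doc_count + 0.5) + 1)
    for word, doc_count in document_frequencies.items()
}
```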
@@ -244,7 +246,7 @@ def _create_reverse_index(
self.tf_idf[word][score].append(document["uuid"])
else:
self.tf_idf[word][score] = [document["uuid"]]

for w in document[index_by].split(" "):
if self.reverse_tf_idf[w].get(index_by) is None:
self.reverse_tf_idf[w][index_by] = {}
@@ -492,15 +494,17 @@ def add(
)

if self.autosuggest_on and document.get(self.autosuggest_on):
self.autosuggest_index[document[self.autosuggest_on].lower()] = document[
self.autosuggest_on
]
self.autosuggest_index[
document[self.autosuggest_on].lower()
] = document[self.autosuggest_on]

# add to GSI
for key, value in document.items():
if isinstance(value, str):
self.doc_lengths[document["uuid"]][key] = len(value.split(" "))
self.document_length_words[document["uuid"]] += len(value.split(" "))
self.document_length_words[document["uuid"]] += len(
value.split(" ")
)

if key not in self.gsis:
if key == "uuid":
@@ -558,7 +562,9 @@ def add(
self.gsis[key]["gsi"][value] = []

self.gsis[key]["gsi"][value].append(document["uuid"])
elif self.gsis[key]["strategy"] == GSI_INDEX_STRATEGIES.TRIGRAM_CODE.name:
elif (
self.gsis[key]["strategy"] == GSI_INDEX_STRATEGIES.TRIGRAM_CODE.name
):
code_lines = value.split("\n")
total_lines = len(code_lines)
file_name = document.get("file_name")
@@ -578,12 +584,17 @@ def add(
)
self.gsis[key]["id2line"][f"{file_name}:{line_num}"] = line
self.gsis[key]["doc_lengths"][file_name] = total_lines
elif self.gsis[key]["strategy"] == GSI_INDEX_STRATEGIES.NOT_INDEXABLE.name:
elif (
self.gsis[key]["strategy"]
== GSI_INDEX_STRATEGIES.NOT_INDEXABLE.name
):
pass
else:
raise ValueError(
"Invalid GSI strategy. Must be one of: "
+ ", ".join([strategy.name for strategy in GSI_INDEX_STRATEGIES])
+ ", ".join(
[strategy.name for strategy in GSI_INDEX_STRATEGIES]
)
+ "."
)
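
The TRIGRAM_CODE branch above indexes each source line under file_name:line_num via id2line and records per-file line counts in doc_lengths. A sketch of the trigram idea it builds on — this get_trigrams is a plausible reconstruction, not the function from this repository:

```python
def get_trigrams(line: str) -> list[str]:
    # Sliding window of three characters; lines shorter than three
    # characters yield no trigrams.
    return [line[i : i + 3] for i in range(len(line) - 2)]

# Mapping each trigram to the file:line IDs that contain it lets a code
# search intersect candidate lines cheaply before any exact matching.
assert get_trigrams("def add(") == ["def", "ef ", "f a", " ad", "add", "dd("]
```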

@@ -599,7 +610,7 @@ def update(self, uuid: str, document: dict) -> Dict[str, dict]:
"""
Accepts a UUID and a document and updates the document associated with that key.
"""

with self.write_lock:
if uuid not in self.uuids_to_position_in_global_index:
return {"error": "Document not found"}
@@ -819,7 +830,9 @@ def create_gsi(

return gsi

def search(self, query: dict) -> List[str]:
def search(
self, query: dict, term_queries: list = [], fields: list = []
) -> List[str]:
# with self.write_lock:
start_time = time.time()

@@ -852,7 +865,11 @@ def search(self, query: dict) -> List[str]:

metadata, result_ids = self._recursively_parse_query(query["query"])

results = [self.global_index.get(doc_id) for doc_id in result_ids if doc_id in self.global_index]
results = [
self.global_index.get(doc_id)
for doc_id in result_ids
if doc_id in self.global_index
]

results = orjson.loads(orjson.dumps(results))
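
The orjson round-trip above acts as a fast deep copy: serializing and reparsing detaches the returned dicts from the documents stored in global_index, so the ranker below can attach _score to results without mutating the index. A small demonstration of the idiom:

```python
import orjson

stored = {"uuid": "123", "title_lower": "all too well"}
result = orjson.loads(orjson.dumps(stored))  # deep copy via serialize/parse

result["_score"] = 2.0
assert "_score" not in stored  # the indexed document is untouched
```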

@@ -871,32 +888,22 @@ def search(self, query: dict) -> List[str]:
results_sort_by = query["sort_by"]

if query.get("sort_order") == "asc":
results = sorted(
results, key=itemgetter(results_sort_by), reverse=False
)
results = sorted(results, key=itemgetter(results_sort_by), reverse=False)
else:
results = sorted(
results, key=itemgetter(results_sort_by), reverse=True
)
results = sorted(results, key=itemgetter(results_sort_by), reverse=True)

if self.enable_experimental_bm25_ranker:
# TODO: Make sure this code can process boosts.

self.avgdl = sum(self.document_length_words.values()) / len(self.document_length_words)

if query["query"].get("or"):
operator = "or"
term_queries = [term.get("or")[0][list(term.get("or")[0].keys())[0]]["contains"] for term in query["query"]["or"]]
fields = [list(term.get("or")[0].keys()) for term in query["query"]["or"]]
else:
operator = "and"
term_queries = [term[list(term.keys())[0]]["contains"] for term in query["query"]["and"][0]["or"]]
fields = [list(term[list(term.keys())[0]].keys()) for term in query["query"]["and"][0]["or"]]

term_queries = list(set(term_queries))
fields = [field for sublist in fields for field in sublist]
self.avgdl = sum(self.document_length_words.values()) / len(
self.document_length_words
)

gsis = {field: self.gsis[field]["gsi"] for field in fields if self.gsis.get(field)}
gsis = {
field: self.gsis[field]["gsi"]
for field in fields
if self.gsis.get(field)
}

for doc in results:
doc["_score"] = 0
@@ -905,44 +912,48 @@ def search(self, query: dict) -> List[str]:
tf = self.tf.get(doc["uuid"], {}).get(term, 0)
idf = self.idf.get(term, 0)

term_score = (tf * (self.k1 + 1)) / (tf + self.k1 * (1 - self.b + self.b * (self.document_length_words[doc["uuid"]] / self.avgdl)))
term_score = (tf * (self.k1 + 1)) / (
tf
+ self.k1
* (
1
- self.b
+ self.b
* (self.document_length_words[doc["uuid"]] / self.avgdl)
)
)
term_score *= idf

doc["_score"] += term_score

for field in fields:
if not gsis.get(field):
continue
word_pos = gsis[field][term]["documents"]["uuid"][doc["uuid"]]
# give a boost if all terms are within 1 word of each other
# so a doc with "all too well" would do better than "all well too"
if all([w in word_pos for w in term_queries]):
first_word_pos = set(word_pos[term_queries[0]])
for i, term in enumerate(term_queries):
positions = set([x - i for x in word_pos[term]])
first_word_pos &= positions

if first_word_pos:
doc["_score"] += (len(first_word_pos) + 1) * len(first_word_pos)

if field != "title_lower":
# TODO: Run only if query len > 1 word
word_pos = defaultdict(list)

for i, term in enumerate(doc[field].lower().split(" ")):
word_pos[term].append(i)

for term in term_queries:
# give a boost if all terms are within 1 word of each other
# so a doc with "all too well" would do better than "all well too"

if (
all([w in word_pos for w in term_queries])
and len(term_queries) > 1
):
first_word_pos = set(word_pos[term_queries[0]])
for i, term in enumerate(term_queries):
positions = set([x - i for x in word_pos[term]])
first_word_pos &= positions

if first_word_pos:
if first_word_pos and field != "title_lower":
doc["_score"] += (
len(first_word_pos) + 1
) # * len(set(word_pos[term_queries[0]]))
elif first_word_pos and field == "title_lower":
doc["_score"] *= 2 + len(first_word_pos)

# if "title_lower" in fields:
# # TODO: Make this more dynamic
# doc["_score"] *= len([term.get("or")[0].get("title_lower", {}).get("contains") in doc["title"].lower() for term in query["query"]["or"] if str(term.get("or")[0].get("title_lower", {}).get("contains")).lower() in doc["title"].lower()]) + 1

# sort by doc score
results = sorted(
results, key=lambda x: x.get("_score", 0), reverse=True
)
results = sorted(results, key=lambda x: x.get("_score", 0), reverse=True)

if query.get("query_score"):
tree = parse_script_score(query["query_score"])
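
The loop above scores each document with per-term BM25 and then applies a phrase-proximity boost: term positions are collected per field, each term's position set is shifted back by its offset in the query, and the intersection leaves only start positions where the terms appear consecutively. A self-contained sketch of both pieces, with k1 = 1.5 and b = 0.75 assumed as typical defaults (the index stores its own self.k1 and self.b):

```python
from collections import defaultdict

K1, B = 1.5, 0.75  # assumed defaults; not taken from this diff

def bm25_term_score(tf: float, idf: float, doc_len: int, avgdl: float) -> float:
    # Same shape as the diff: idf * tf * (k1 + 1) / (tf + k1 * (1 - b + b * dl / avgdl)).
    return idf * (tf * (K1 + 1)) / (tf + K1 * (1 - B + B * (doc_len / avgdl)))

def consecutive_phrase_hits(text: str, terms: list[str]) -> int:
    # Collect term -> positions, then intersect position sets shifted by each
    # term's offset; survivors are starts where the terms appear in order.
    word_pos = defaultdict(list)
    for i, word in enumerate(text.lower().split(" ")):
        word_pos[word].append(i)
    if not all(term in word_pos for term in terms):
        return 0
    starts = set(word_pos[terms[0]])
    for offset, term in enumerate(terms):
        starts &= {pos - offset for pos in word_pos[term]}
    return len(starts)

# A doc with "all too well" beats one with "all well too", as the comment says.
assert consecutive_phrase_hits("i remember it all too well", ["all", "too", "well"]) == 1
assert consecutive_phrase_hits("all well too", ["all", "too", "well"]) == 0
```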
@@ -1032,7 +1043,9 @@ def _recursively_parse_query(self, query_tree: dict) -> set:

if isinstance(query_tree[first_key], dict):
for key, query in query_tree[first_key].items():
query_metadata, query_values = self._recursively_parse_query({key: query})
query_metadata, query_values = self._recursively_parse_query(
{key: query}
)
metadata.append(query_metadata)
values.append(query_values)
else:
@@ -1077,9 +1090,7 @@ def _recursively_parse_query(self, query_tree: dict) -> set:

acc = set.union(getattr(self, func)(query_tree[first_key]))
else:
scores, result_uuids = self._run(
{"query": query_tree}, first_key
)
scores, result_uuids = self._run({"query": query_tree}, first_key)
acc = set.union(acc, result_uuids)

return scores, acc
@@ -1339,7 +1350,7 @@ def _run(self, query: dict, query_field: str) -> List[str]:
else:
for word in str(query_term).split(" "):
word = word.lower()

if gsi.get(word) is None:
continue

@@ -1404,13 +1415,18 @@ def _run(self, query: dict, query_field: str) -> List[str]:
"scores": defaultdict(dict),
"highlights": defaultdict(dict),
}
for doc in matching_documents[:self.match_limit_for_large_result_pages]:
advanced_query_information["scores"][doc] = matching_document_scores.get(doc, 0) * float(
boost_factor
)

for doc in matching_documents[: self.match_limit_for_large_result_pages]:
advanced_query_information["scores"][doc] = matching_document_scores.get(
doc, 0
) * float(boost_factor)

if matching_highlights:
advanced_query_information["highlights"][doc] = matching_highlights.get(doc, {})
advanced_query_information["highlights"][doc] = matching_highlights.get(
doc, {}
)

return advanced_query_information, matching_documents[:self.match_limit_for_large_result_pages]
return (
advanced_query_information,
matching_documents[: self.match_limit_for_large_result_pages],
)
2 changes: 0 additions & 2 deletions jamesql/rewriter.py
@@ -243,8 +243,6 @@ def word_query(self, items):
}
}



if self.boosts.get(field):
results[field]["boost"] = self.boosts.get(field, boost)

5 changes: 4 additions & 1 deletion jamesql/script_lang.py
@@ -54,7 +54,10 @@ def start(self, items):
def decay(self, items):
# decay by half for every 30 days
# items[0] is an ISO-format datetime string
days_since_post = (datetime.datetime.now() - datetime.datetime.strptime(items[0], "%Y-%m-%dT%H:%M:%S")).days
days_since_post = (
datetime.datetime.now()
- datetime.datetime.strptime(items[0], "%Y-%m-%dT%H:%M:%S")
).days

return 1.1 ** (days_since_post / 30)
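
Note that 1.1 ** (days_since_post / 30) grows roughly 10% per 30 days rather than halving, so the multiplier as committed rewards age. If the comment's literal half-life behavior were wanted, a sketch would look like this (a hypothetical alternative, not part of this commit):

```python
import datetime

def half_life_decay(posted: str, half_life_days: int = 30) -> float:
    # Halve the score for every `half_life_days` elapsed since the post date.
    days = (
        datetime.datetime.now()
        - datetime.datetime.strptime(posted, "%Y-%m-%dT%H:%M:%S")
    ).days
    return 0.5 ** (days / half_life_days)
```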

Expand Down
