From 804e9fd3d3c7d031ac585dab06fad6a0593b3b73 Mon Sep 17 00:00:00 2001
From: capjamesg
Date: Mon, 18 Nov 2024 11:57:44 +0000
Subject: [PATCH] fix bm25 bugs

---
 jamesql/index.py       | 166 ++++++++++++++++++++++-------------
 jamesql/rewriter.py    |   2 -
 jamesql/script_lang.py |   5 +-
 3 files changed, 95 insertions(+), 78 deletions(-)

diff --git a/jamesql/index.py b/jamesql/index.py
index c5f257b..d7d8de5 100644
--- a/jamesql/index.py
+++ b/jamesql/index.py
@@ -3,24 +3,24 @@
 import math
 import os
 import string
+import threading
 import time
 import uuid
 from collections import defaultdict
 from copy import deepcopy
 from enum import Enum
 from functools import lru_cache
+from operator import itemgetter
 from typing import Dict, List
-from sortedcontainers import SortedDict
-import threading
-from operator import itemgetter
 
 import orjson
 import pybmoore
 import pygtrie
 from BTrees.OOBTree import OOBTree
 from lark import Lark
-from nltk.corpus import stopwords
 from nltk import download
+from nltk.corpus import stopwords
+from sortedcontainers import SortedDict
 
 from jamesql.rewriter import grammar as rewriter_grammar
 from jamesql.rewriter import string_query_to_jamesql
@@ -95,7 +95,7 @@ def get_trigrams(line):
 class JameSQL:
     SELF_METHODS = {"close_to": "_close_to"}
 
-    def __init__(self, match_limit_for_large_result_pages = 1000) -> None:
+    def __init__(self, match_limit_for_large_result_pages=1000) -> None:
         self.global_index = {}
         self.uuids_to_position_in_global_index = {}
         self.gsis = {}
@@ -230,11 +230,13 @@ def _create_reverse_index(
             # Compute term frequency (TF) for each word in the document
             total_words_in_document = len(words)
             for word, count in word_count.items():
-                self.tf[document["uuid"]][word] = count # / total_words_in_document
+                self.tf[document["uuid"]][word] = count  # / total_words_in_document
 
         # Compute inverse document frequency (IDF) for each word in the corpus
         for word, doc_count in document_frequencies.items():
-            self.idf[word] =math.log((total_documents - doc_count + 0.5) / (doc_count + 0.5) + 1)
+            self.idf[word] = math.log(
+                (total_documents - doc_count + 0.5) / (doc_count + 0.5) + 1
+            )
 
         # Compute TF-IDF for each document
         for document in documents:
@@ -244,7 +246,7 @@ def _create_reverse_index(
                     self.tf_idf[word][score].append(document["uuid"])
                 else:
                     self.tf_idf[word][score] = [document["uuid"]]
-            
+
             for w in document[index_by].split(" "):
                 if self.reverse_tf_idf[w].get(index_by) is None:
                     self.reverse_tf_idf[w][index_by] = {}
@@ -492,15 +494,17 @@ def add(
             )
 
         if self.autosuggest_on and document.get(self.autosuggest_on):
-            self.autosuggest_index[document[self.autosuggest_on].lower()] = document[
-                self.autosuggest_on
-            ]
+            self.autosuggest_index[
+                document[self.autosuggest_on].lower()
+            ] = document[self.autosuggest_on]
 
         # add to GSI
         for key, value in document.items():
             if isinstance(value, str):
                 self.doc_lengths[document["uuid"]][key] = len(value.split(" "))
-                self.document_length_words[document["uuid"]] += len(value.split(" "))
+                self.document_length_words[document["uuid"]] += len(
+                    value.split(" ")
+                )
 
             if key not in self.gsis:
                 if key == "uuid":
@@ -558,7 +562,9 @@ def add(
                         self.gsis[key]["gsi"][value] = []
                 self.gsis[key]["gsi"][value].append(document["uuid"])
 
-            elif self.gsis[key]["strategy"] == GSI_INDEX_STRATEGIES.TRIGRAM_CODE.name:
+            elif (
+                self.gsis[key]["strategy"] == GSI_INDEX_STRATEGIES.TRIGRAM_CODE.name
+            ):
                 code_lines = value.split("\n")
                 total_lines = len(code_lines)
                 file_name = document.get("file_name")
@@ -578,12 +584,17 @@ def add(
                         )
                         self.gsis[key]["id2line"][f"{file_name}:{line_num}"] = line
                 self.gsis[key]["doc_lengths"][file_name] = total_lines
 
-            elif self.gsis[key]["strategy"] == GSI_INDEX_STRATEGIES.NOT_INDEXABLE.name:
+            elif (
+                self.gsis[key]["strategy"]
+                == GSI_INDEX_STRATEGIES.NOT_INDEXABLE.name
+            ):
                 pass
             else:
                 raise ValueError(
                     "Invalid GSI strategy. Must be one of: "
-                    + ", ".join([strategy.name for strategy in GSI_INDEX_STRATEGIES])
+                    + ", ".join(
+                        [strategy.name for strategy in GSI_INDEX_STRATEGIES]
+                    )
                     + "."
                 )
@@ -599,7 +610,7 @@ def update(self, uuid: str, document: dict) -> Dict[str, dict]:
         """
         Accepts a UUID and a document and updates the document associated with that key.
         """
-        
+
         with self.write_lock:
             if uuid not in self.uuids_to_position_in_global_index:
                 return {"error": "Document not found"}
@@ -819,7 +830,9 @@ def create_gsi(
 
         return gsi
 
-    def search(self, query: dict) -> List[str]:
+    def search(
+        self, query: dict, term_queries: list = [], fields: list = []
+    ) -> List[str]:
         # with self.write_lock:
         start_time = time.time()
 
@@ -852,7 +865,11 @@ def search(self, query: dict) -> List[str]:
 
         metadata, result_ids = self._recursively_parse_query(query["query"])
 
-        results = [self.global_index.get(doc_id) for doc_id in result_ids if doc_id in self.global_index]
+        results = [
+            self.global_index.get(doc_id)
+            for doc_id in result_ids
+            if doc_id in self.global_index
+        ]
 
         results = orjson.loads(orjson.dumps(results))
 
@@ -871,32 +888,22 @@ def search(self, query: dict) -> List[str]:
             results_sort_by = query["sort_by"]
 
             if query.get("sort_order") == "asc":
-                results = sorted(
-                    results, key=itemgetter(results_sort_by), reverse=False
-                )
+                results = sorted(results, key=itemgetter(results_sort_by), reverse=False)
             else:
-                results = sorted(
-                    results, key=itemgetter(results_sort_by), reverse=True
-                )
+                results = sorted(results, key=itemgetter(results_sort_by), reverse=True)
 
         if self.enable_experimental_bm25_ranker:
             # TODO: Make sure this code can process boosts.
-            self.avgdl = sum(self.document_length_words.values()) / len(self.document_length_words)
-
-            if query["query"].get("or"):
-                operator = "or"
-                term_queries = [term.get("or")[0][list(term.get("or")[0].keys())[0]]["contains"] for term in query["query"]["or"]]
-                fields = [list(term.get("or")[0].keys()) for term in query["query"]["or"]]
-            else:
-                operator = "and"
-                term_queries = [term[list(term.keys())[0]]["contains"] for term in query["query"]["and"][0]["or"]]
-                fields = [list(term[list(term.keys())[0]].keys()) for term in query["query"]["and"][0]["or"]]
-
-            term_queries = list(set(term_queries))
-            fields = [field for sublist in fields for field in sublist]
+            self.avgdl = sum(self.document_length_words.values()) / len(
+                self.document_length_words
+            )
 
-            gsis = {field: self.gsis[field]["gsi"] for field in fields if self.gsis.get(field)}
+            gsis = {
+                field: self.gsis[field]["gsi"]
+                for field in fields
+                if self.gsis.get(field)
+            }
 
             for doc in results:
                 doc["_score"] = 0
@@ -905,44 +912,48 @@ def search(self, query: dict) -> List[str]:
                     tf = self.tf.get(doc["uuid"], {}).get(term, 0)
                     idf = self.idf.get(term, 0)
 
-                    term_score = (tf * (self.k1 + 1)) / (tf + self.k1 * (1 - self.b + self.b * (self.document_length_words[doc["uuid"]] / self.avgdl)))
+                    term_score = (tf * (self.k1 + 1)) / (
+                        tf
+                        + self.k1
+                        * (
+                            1
+                            - self.b
+                            + self.b
+                            * (self.document_length_words[doc["uuid"]] / self.avgdl)
+                        )
+                    )
                     term_score *= idf
 
                     doc["_score"] += term_score
 
                 for field in fields:
-                    if not gsis.get(field):
-                        continue
-                    word_pos = gsis[field][term]["documents"]["uuid"][doc["uuid"]]
-                    # give a boost if all terms are within 1 word of each other
-                    # so a doc with "all too well" would do btter than "all well too"
-                    if all([w in word_pos for w in term_queries]):
-                        first_word_pos = set(word_pos[term_queries[0]])
-                        for i, term in enumerate(term_queries):
-                            positions = set([x - i for x in word_pos[term]])
-                            first_word_pos &= positions
-
-                        if first_word_pos:
-                            doc["_score"] += (len(first_word_pos) + 1) * len(first_word_pos)
-
-                        if field != "title_lower":
-                            # TODO: Run only if query len > 1 word
+                    word_pos = defaultdict(list)
+
+                    for i, term in enumerate(doc[field].lower().split(" ")):
+                        word_pos[term].append(i)
+
+                    for term in term_queries:
+                        # give a boost if all terms are within 1 word of each other
+                        # so a doc with "all too well" would do better than "all well too"
+
+                        if (
+                            all([w in word_pos for w in term_queries])
+                            and len(term_queries) > 1
+                        ):
                             first_word_pos = set(word_pos[term_queries[0]])
                             for i, term in enumerate(term_queries):
                                 positions = set([x - i for x in word_pos[term]])
                                 first_word_pos &= positions
 
-                            if first_word_pos:
+                            if first_word_pos and field != "title_lower":
+                                doc["_score"] += (
+                                    len(first_word_pos) + 1
+                                )  # * len(set(word_pos[term_queries[0]]))
+                            elif first_word_pos and field == "title_lower":
                                 doc["_score"] *= 2 + len(first_word_pos)
 
-            # if "title_lower" in fields:
-            #     # TODO: Make this more dynamic
-            #     doc["_score"] *= len([term.get("or")[0].get("title_lower", {}).get("contains") in doc["title"].lower() for term in query["query"]["or"] if str(term.get("or")[0].get("title_lower", {}).get("contains")).lower() in doc["title"].lower()]) + 1
-
             # sort by doc score
-            results = sorted(
-                results, key=lambda x: x.get("_score", 0), reverse=True
-            )
+            results = sorted(results, key=lambda x: x.get("_score", 0), reverse=True)
 
             if query.get("query_score"):
                 tree = parse_script_score(query["query_score"])
@@ -1032,7 +1043,9 @@ def _recursively_parse_query(self, query_tree: dict) -> set:
 
         if isinstance(query_tree[first_key], dict):
             for key, query in query_tree[first_key].items():
-                query_metadata, query_values = self._recursively_parse_query({key: query})
+                query_metadata, query_values = self._recursively_parse_query(
+                    {key: query}
+                )
                 metadata.append(query_metadata)
                 values.append(query_values)
         else:
@@ -1077,9 +1090,7 @@ def _recursively_parse_query(self, query_tree: dict) -> set:
 
             acc = set.union(getattr(self, func)(query_tree[first_key]))
         else:
-            scores, result_uuids = self._run(
-                {"query": query_tree}, first_key
-            )
+            scores, result_uuids = self._run({"query": query_tree}, first_key)
             acc = set.union(acc, result_uuids)
 
         return scores, acc
@@ -1339,7 +1350,7 @@ def _run(self, query: dict, query_field: str) -> List[str]:
                 else:
                     for word in str(query_term).split(" "):
                         word = word.lower()
-                        
+
                         if gsi.get(word) is None:
                             continue
 
@@ -1404,13 +1415,18 @@ def _run(self, query: dict, query_field: str) -> List[str]:
            "scores": defaultdict(dict),
            "highlights": defaultdict(dict),
        }
-        
-        for doc in matching_documents[:self.match_limit_for_large_result_pages]:
-            advanced_query_information["scores"][doc] = matching_document_scores.get(doc, 0) * float(
-                boost_factor
-            )
+
+        for doc in matching_documents[: self.match_limit_for_large_result_pages]:
+            advanced_query_information["scores"][doc] = matching_document_scores.get(
+                doc, 0
+            ) * float(boost_factor)
             if matching_highlights:
-                advanced_query_information["highlights"][doc] = matching_highlights.get(doc, {})
+                advanced_query_information["highlights"][doc] = matching_highlights.get(
+                    doc, {}
+                )
 
-        return advanced_query_information, matching_documents[:self.match_limit_for_large_result_pages]
+        return (
+            advanced_query_information,
+            matching_documents[: self.match_limit_for_large_result_pages],
+        )
 
diff --git a/jamesql/rewriter.py b/jamesql/rewriter.py
index 287208c..63ef02d 100644
--- a/jamesql/rewriter.py
+++ b/jamesql/rewriter.py
@@ -243,8 +243,6 @@ def word_query(self, items):
             }
         }
 
-
-
         if self.boosts.get(field):
             results[field]["boost"] = self.boosts.get(field, boost)
 
diff --git a/jamesql/script_lang.py b/jamesql/script_lang.py
index 4ed7579..6dd3f9f 100644
--- a/jamesql/script_lang.py
+++ b/jamesql/script_lang.py
@@ -54,7 +54,10 @@ def start(self, items):
 
     def decay(self, items):
         # decay by half for every 30 days
        # item is datetime.dateime object
-        days_since_post = (datetime.datetime.now() - datetime.datetime.strptime(items[0], "%Y-%m-%dT%H:%M:%S")).days
+        days_since_post = (
+            datetime.datetime.now()
+            - datetime.datetime.strptime(items[0], "%Y-%m-%dT%H:%M:%S")
+        ).days
 
         return 1.1 ** (days_since_post / 30)
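
For reference, the ranking this patch settles on can be reproduced outside the index. The sketch below is a minimal standalone version of the BM25 math visible in _create_reverse_index (the IDF) and search (the per-term score), assuming the conventional defaults k1 = 1.5 and b = 0.75 (the patch reads self.k1 and self.b, whose actual values are set elsewhere in index.py); bm25_score and its arguments are hypothetical names, not part of jamesql:

    import math

    def bm25_score(term_queries, doc_tokens, corpus, k1=1.5, b=0.75):
        # corpus: list of token lists, one per indexed document
        n_docs = len(corpus)
        # average document length, as in self.avgdl
        avgdl = sum(len(tokens) for tokens in corpus) / n_docs

        score = 0.0
        for term in term_queries:
            tf = doc_tokens.count(term)  # raw count, as stored in self.tf
            doc_count = sum(term in tokens for tokens in corpus)
            # the IDF computed in _create_reverse_index
            idf = math.log((n_docs - doc_count + 0.5) / (doc_count + 0.5) + 1)
            # the per-term score computed in search()
            score += idf * (tf * (k1 + 1)) / (
                tf + k1 * (1 - b + b * (len(doc_tokens) / avgdl))
            )
        return score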
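The phrase boost that replaces the old gsis-based word_pos lookup works by position shifting: record every token's positions in the field, shift each query term's positions back by that term's offset in the query, and intersect the sets. Any position that survives the intersection is the start of an exact consecutive run of the query terms. A minimal sketch of that check, under hypothetical names (consecutive_starts is not a jamesql function):

    from collections import defaultdict

    def consecutive_starts(term_queries, field_text):
        # token -> positions, as the new code builds word_pos per field
        word_pos = defaultdict(list)
        for i, token in enumerate(field_text.lower().split(" ")):
            word_pos[token].append(i)

        if not all(term in word_pos for term in term_queries):
            return set()

        # shift each term's positions back by its offset and intersect
        starts = set(word_pos[term_queries[0]])
        for i, term in enumerate(term_queries):
            starts &= {pos - i for pos in word_pos[term]}
        return starts

Here consecutive_starts(["all", "too", "well"], "all too well all well too") returns {0}: the terms run consecutively once, so the document earns the boost, while a field containing only "all well too" yields the empty set and earns nothing, which is the behaviour the comment in the patch describes.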
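Because the patch deletes the block that derived term_queries and fields from the query tree (along with the unused operator variable), callers of search() are now expected to pass both explicitly. A hypothetical call, assuming documents indexed with a title_lower field; the query shape follows the contains clauses visible in the removed code, not a schema this patch defines:

    index = JameSQL()
    # ... index.add(...) documents with a title_lower field ...

    results = index.search(
        {
            "query": {
                "or": [
                    {"title_lower": {"contains": "all"}},
                    {"title_lower": {"contains": "too"}},
                ]
            },
        },
        term_queries=["all", "too"],
        fields=["title_lower"],
    )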