From 3e299c1c3610cb3086a8aef292975cf324b9ddbf Mon Sep 17 00:00:00 2001
From: capjamesg
Date: Tue, 19 Nov 2024 22:53:35 +0000
Subject: [PATCH] speed improvements

---
 jamesql/index.py | 476 ++++++++++++++++++++++-------------------------
 1 file changed, 218 insertions(+), 258 deletions(-)

diff --git a/jamesql/index.py b/jamesql/index.py
index bfdeff8..5b64af3 100644
--- a/jamesql/index.py
+++ b/jamesql/index.py
@@ -3,38 +3,36 @@ import math
 import os
 import string
-import threading
 import time
 import uuid
 from collections import defaultdict
+from copy import deepcopy
 from enum import Enum
 from functools import lru_cache
-from operator import itemgetter
 from typing import Dict, List
+from sortedcontainers import SortedDict
+import threading
+from operator import itemgetter
 
 import orjson
 import pybmoore
 import pygtrie
 from BTrees.OOBTree import OOBTree
 from lark import Lark
-from nltk import download
 from nltk.corpus import stopwords
-from sortedcontainers import SortedDict
+from nltk import download
 
 from jamesql.rewriter import grammar as rewriter_grammar
 from jamesql.rewriter import string_query_to_jamesql
 
 from .script_lang import JameSQLScriptTransformer, grammar
 
-if not os.path.exists(os.path.expanduser("~") + "/nltk_data"):
-    download("stopwords")
+download("stopwords")
 
 INDEX_STORE = os.path.join(os.path.expanduser("~"), ".jamesql")
 JOURNAL_FILE = os.path.join(os.getcwd(), "journal.jamesql")
 INDEX_DATA_FILE = os.path.join(os.getcwd(), "index.jamesql")
 
-END_OF_SENTENCE_TOKEN = "eos"
-
 if not os.path.exists(INDEX_STORE):
     os.makedirs(INDEX_STORE)
@@ -232,13 +230,11 @@ def _create_reverse_index(
         # Compute term frequency (TF) for each word in the document
         total_words_in_document = len(words)
         for word, count in word_count.items():
-            self.tf[document["uuid"]][word] = count # / total_words_in_document
+            self.tf[document["uuid"]][word] = count  # / total_words_in_document
 
         # Compute inverse document frequency (IDF) for each word in the corpus
         for word, doc_count in document_frequencies.items():
-            self.idf[word] = math.log(
-                (total_documents - doc_count + 0.5) / (doc_count + 0.5) + 1
-            )
+            self.idf[word] = math.log((total_documents - doc_count + 0.5) / (doc_count + 0.5) + 1)
 
         # Compute TF-IDF for each document
         for document in documents:
@@ -248,7 +244,7 @@ def _create_reverse_index(
                     self.tf_idf[word][score].append(document["uuid"])
                 else:
                     self.tf_idf[word][score] = [document["uuid"]]
-
+
             for w in document[index_by].split(" "):
                 if self.reverse_tf_idf[w].get(index_by) is None:
                     self.reverse_tf_idf[w][index_by] = {}
@@ -496,17 +492,15 @@ def add(
             )
 
         if self.autosuggest_on and document.get(self.autosuggest_on):
-            self.autosuggest_index[document[self.autosuggest_on].lower()] = (
-                document[self.autosuggest_on]
-            )
+            self.autosuggest_index[document[self.autosuggest_on].lower()] = document[
+                self.autosuggest_on
+            ]
 
         # add to GSI
         for key, value in document.items():
             if isinstance(value, str):
                 self.doc_lengths[document["uuid"]][key] = len(value.split(" "))
-                self.document_length_words[document["uuid"]] += len(
-                    value.split(" ")
-                )
+                self.document_length_words[document["uuid"]] += len(value.split(" "))
 
             if key not in self.gsis:
                 if key == "uuid":
@@ -564,9 +558,7 @@ def add(
self.gsis[key]["gsi"][value] = [] self.gsis[key]["gsi"][value].append(document["uuid"]) - elif ( - self.gsis[key]["strategy"] == GSI_INDEX_STRATEGIES.TRIGRAM_CODE.name - ): + elif self.gsis[key]["strategy"] == GSI_INDEX_STRATEGIES.TRIGRAM_CODE.name: code_lines = value.split("\n") total_lines = len(code_lines) file_name = document.get("file_name") @@ -586,17 +578,12 @@ def add( ) self.gsis[key]["id2line"][f"{file_name}:{line_num}"] = line self.gsis[key]["doc_lengths"][file_name] = total_lines - elif ( - self.gsis[key]["strategy"] - == GSI_INDEX_STRATEGIES.NOT_INDEXABLE.name - ): + elif self.gsis[key]["strategy"] == GSI_INDEX_STRATEGIES.NOT_INDEXABLE.name: pass else: raise ValueError( "Invalid GSI strategy. Must be one of: " - + ", ".join( - [strategy.name for strategy in GSI_INDEX_STRATEGIES] - ) + + ", ".join([strategy.name for strategy in GSI_INDEX_STRATEGIES]) + "." ) @@ -606,17 +593,13 @@ def add( os.remove(JOURNAL_FILE) - self.avgdl = sum(self.document_length_words.values()) / len( - self.document_length_words - ) - return document def update(self, uuid: str, document: dict) -> Dict[str, dict]: """ Accepts a UUID and a document and updates the document associated with that key. """ - + with self.write_lock: if uuid not in self.uuids_to_position_in_global_index: return {"error": "Document not found"} @@ -761,6 +744,10 @@ def create_gsi( ] ): strategy = GSI_INDEX_STRATEGIES.DATE + # if word count < 10, use prefix + # elif isinstance(index_by, str) and sum([len(item.split(" ")) for item in documents_in_indexed_by]) / len(documents_in_indexed_by) < 10: + # strategy = GSI_INDEX_STRATEGIES.PREFIX + # if average contains more than one word, use contains elif isinstance(documents_in_indexed_by[0], str) and sum( [len(item.split(" ")) for item in documents_in_indexed_by] ) / len(documents_in_indexed_by): @@ -832,9 +819,7 @@ def create_gsi( return gsi - def search( - self, query: dict, term_queries: list = [], fields: list = [] - ) -> List[str]: + def search(self, query: dict, term_queries: list = [], fields: list = []) -> List[str]: # with self.write_lock: start_time = time.time() @@ -867,11 +852,7 @@ def search( metadata, result_ids = self._recursively_parse_query(query["query"]) - results = [ - self.global_index.get(doc_id) - for doc_id in result_ids - if doc_id in self.global_index - ] + results = [self.global_index.get(doc_id) for doc_id in result_ids if doc_id in self.global_index] results = orjson.loads(orjson.dumps(results)) @@ -890,13 +871,21 @@ def search( results_sort_by = query["sort_by"] if query.get("sort_order") == "asc": - results = sorted(results, key=itemgetter(results_sort_by), reverse=False) + results = sorted( + results, key=itemgetter(results_sort_by), reverse=False + ) else: - results = sorted(results, key=itemgetter(results_sort_by), reverse=True) + results = sorted( + results, key=itemgetter(results_sort_by), reverse=True + ) if self.enable_experimental_bm25_ranker: # TODO: Make sure this code can process boosts. 
+            self.avgdl = sum(self.document_length_words.values()) / max(len(self.document_length_words), 1)
+
             for doc in results:
                 doc["_score"] = 0
 
@@ -904,45 +891,36 @@ def search(
                     tf = self.tf.get(doc["uuid"], {}).get(term, 0)
                     idf = self.idf.get(term, 0)
 
-                    term_score = (tf * (self.k1 + 1)) / (
-                        tf
-                        + self.k1
-                        * (
-                            1
-                            - self.b
-                            + self.b
-                            * (self.document_length_words[doc["uuid"]] / self.avgdl)
-                        )
-                    )
+                    term_score = (tf * (self.k1 + 1)) / (tf + self.k1 * (1 - self.b + self.b * (self.document_length_words[doc["uuid"]] / self.avgdl)))
 
                     term_score *= idf
 
                     doc["_score"] += term_score
 
                 for field in fields:
-                    gsi_index = self.gsis.get(field)["gsi"]
+                    word_pos = defaultdict(list)
 
-                    word_pos = {word: gsi_index.get(word, {}).get("documents", {}).get("uuid", {}).get(doc["uuid"], []) for word in term_queries}
+                    for i, term in enumerate(doc.get(field, "").lower().split(" ")):
+                        word_pos[term].append(i)
+
                     for term in term_queries:
                         # give a boost if all terms are within 1 word of each other
                         # so a doc with "all too well" would do better than "all well too"
-                        if len(term_queries) < 2:
-                            continue
-
-                        first_word_pos = set(word_pos[term_queries[0]])
-                        for i, term in enumerate(term_queries):
-                            positions = set([x - i for x in word_pos[term]])
-                            first_word_pos &= positions
-
-                        if first_word_pos and field != "title_lower":
-                            doc["_score"] += (
-                                len(first_word_pos) + 1
-                            )  # * len(set(word_pos[term_queries[0]]))
-                        elif first_word_pos and field == "title_lower":
-                            doc["_score"] *= 2 + len(first_word_pos)
+                        if all([w in word_pos for w in term_queries]):
+                            first_word_pos = set(word_pos[term_queries[0]])
+                            for i, term in enumerate(term_queries):
+                                positions = set([x - i for x in word_pos[term]])
+                                first_word_pos &= positions
+
+                            if first_word_pos and field != "title_lower":
+                                doc["_score"] += (len(first_word_pos) + 1) * len(set(word_pos[term_queries[0]]))
+                            elif first_word_pos and field == "title_lower":
+                                doc["_score"] *= 2 + len(first_word_pos)
 
         # sort by doc score
-        results = sorted(results, key=lambda x: x.get("_score", 0), reverse=True)
+        results = sorted(
+            results, key=lambda x: x.get("_score", 0), reverse=True
+        )
 
         if query.get("query_score"):
             tree = parse_script_score(query["query_score"])
@@ -1032,9 +1010,7 @@ def _recursively_parse_query(self, query_tree: dict) -> set:
 
         if isinstance(query_tree[first_key], dict):
             for key, query in query_tree[first_key].items():
-                query_metadata, query_values = self._recursively_parse_query(
-                    {key: query}
-                )
+                query_metadata, query_values = self._recursively_parse_query({key: query})
                 metadata.append(query_metadata)
                 values.append(query_values)
         else:
@@ -1079,7 +1055,9 @@ def _recursively_parse_query(self, query_tree: dict) -> set:
 
             acc = set.union(getattr(self, func)(query_tree[first_key]))
         else:
-            scores, result_uuids = self._run({"query": query_tree}, first_key)
+            scores, result_uuids = self._run(
+                {"query": query_tree}, first_key
+            )
 
             acc = set.union(acc, result_uuids)
 
         return scores, acc
@@ -1124,138 +1102,6 @@ def _turn_query_into_fuzzy_options(self, query_term: str) -> dict:
 
         return query_terms
 
-    def _run_trigram_code(self, query_term, query_field):
-        matching_highlights = {}
-        trigrams = get_trigrams(query_term)
-
-        contexts = []
-
-        candidates = set(self.gsis[query_field]["gsi"].get(trigrams[0], []))
-
-        for trigram in trigrams:
-            candidates = candidates.intersection(
-                set(self.gsis[query_field]["gsi"].get(trigram, []))
-            )
-
-        # candidate[2] is the document uuid
-        matching_documents = [candidate[2] for candidate in candidates]
-
-        # get line numbers
-        for candidate in candidates:
-            contexts.append(
-                {
-                    "line": candidate[1],
-                    "code": self.gsis[query_field]["id2line"][
-                        f"{candidate[0]}:{candidate[1]}"
-                    ],
-                }
-            )
-            matching_highlights[candidate[2]] = contexts
-
-        return matching_documents, matching_highlights
-
-    def _run_get_strict_matches(self, query_term, gsi):
-        matching_documents = []
-        matching_positions = {}
-        words = query_term.split()
-        uuids = set(gsi.get(words[0], {}).get("documents", {}).get("uuid", []))
-        # only look at documents that contain all words, for efficiency
-        for w in query_term.split(" "):
-            uuids = uuids.intersection(
-                gsi.get(w, {}).get("documents", {}).get("uuid", [])
-            )
-
-        for document in uuids:
-            first_word_pos = set(gsi[words[0]]["documents"]["uuid"][document])
-            for i, word in enumerate(words):
-                word_uuids = gsi[word]["documents"]["uuid"][document]
-                # subtract i from each position to account for the fact that the first word is at position 0
-                word_positions = set([x - i for x in word_uuids])
-
-                first_word_pos &= word_positions
-            # print position of all "." in text
-            if len(first_word_pos) > 0:
-                matching_documents.append(document)
-                matching_positions[document] = first_word_pos
-
-        return matching_documents, matching_positions
-
-    def _run_get_highlights(
-        self, gsi, query_field, matching_documents, matching_positions, highlight_stride
-    ):
-        matching_highlights = {}
-        for document in matching_documents:
-            highlights = []
-            # get pos of .EOS in text
-            eos_token_positions = gsi[END_OF_SENTENCE_TOKEN]["documents"]["uuid"][
-                document
-            ]
-            if len(eos_token_positions) == 0:
-                continue
-            # get first token position before each match and after
-            # highlight_stride
-            for match in matching_positions[document]:
-                before = [pos for pos in eos_token_positions if pos < match]
-                after = [pos for pos in eos_token_positions if pos > match]
-
-                if len(before) == 0 or len(after) == 0:
-                    continue
-
-                original_before = max(before)
-                original_after = min(after)
-
-                doc = self.global_index[document][query_field].split(" ")
-
-                if highlight_stride == 1:
-                    highlights.append(
-                        " ".join(
-                            [
-                                doc[pos]
-                                # + 1 ensures we ignore EOS
-                                for pos in range(original_before + 1, original_after)
-                            ]
-                        )
-                    )
-                    continue
-
-                before = [pos for pos in eos_token_positions if pos < original_before]
-                after = [pos for pos in eos_token_positions if pos > original_after - 1]
-
-                if before and after:
-                    highlights.append(
-                        " ".join(
-                            [
-                                doc[pos]
-                                # + 1 ensures we ignore EOS
-                                for pos in range(max(before) + 1, max(after))
-                            ]
-                        )
-                    )
-                elif before:
-                    highlights.append(
-                        " ".join(
-                            [
-                                doc[pos]
-                                # + 1 ensures we ignore EOS
-                                for pos in range(max(before) + 1, original_after)
-                            ]
-                        )
-                    )
-                elif after:
-                    highlights.append(
-                        " ".join(
-                            [
-                                doc[pos]
-                                # + 1 ensures we ignore EOS
-                                for pos in range(original_before + 1, min(after))
-                            ]
-                        )
-                    )
-
-            matching_highlights[document] = highlights
-
-        return matching_highlights
-
     def _run(self, query: dict, query_field: str) -> List[str]:
         """
         Accept a query and return a list of matching documents.
@@ -1285,7 +1131,8 @@ def _run(self, query: dict, query_field: str) -> List[str]:
 
         enforce_strict = query["query"][query_field].get("strict", False)
         highlight_terms = query["query"][query_field].get("highlight", False)
-        highlight_stride = query["query"][query_field].get("highlight_stride", 1)
+
+        highlight_stride = query["query"][query_field].get("highlight_stride", 10)
 
         if not self.gsis.get(query_field):
             self.create_gsi(query_field, GSI_INDEX_STRATEGIES.INFER)
@@ -1315,52 +1162,170 @@ def _run(self, query: dict, query_field: str) -> List[str]:
             query_terms = [query_term.replace("*", c) for c in string.ascii_lowercase]
 
         for query_term in query_terms:
-            if gsi_type == GSI_INDEX_STRATEGIES.TRIGRAM_CODE:
-                matching_documents, matching_highlights = self._run_trigram_code(
-                    query_term, query_field
-                )
-            elif (
-                query_type == "starts_with" and gsi_type == GSI_INDEX_STRATEGIES.PREFIX
-            ):
-                matches = gsi.keys(prefix=query_term)
-                matching_documents.extend([gsi[match] for match in matches])
-            elif (
-                query_type in {"contains", "wildcard"}
-                and gsi_type == GSI_INDEX_STRATEGIES.CONTAINS
+            if gsi_type not in (
+                GSI_INDEX_STRATEGIES.FLAT,
+                GSI_INDEX_STRATEGIES.NUMERIC,
+                GSI_INDEX_STRATEGIES.DATE,
             ):
-                if enforce_strict or highlight_terms:
-                    matching_documents, matching_positions = (
-                        self._run_get_strict_matches(query_term, gsi)
-                    )
-                    if highlight_terms:
-                        matching_highlights = self._run_get_highlights(
-                            gsi,
-                            query_field,
-                            matching_documents,
-                            matching_positions,
-                            highlight_stride,
-                        )
-                else:
-                    for word in str(query_term).split(" "):
-                        word = word.lower()
-
-                        if gsi.get(word) is None:
-                            continue
-
-                        results = self.reverse_tf_idf[word].get(query_field, {})
-
-                        matching_document_scores = results
-                        matching_documents.extend(results.keys())
-
-            elif gsi_type not in (
-                GSI_INDEX_STRATEGIES.FLAT,
-                GSI_INDEX_STRATEGIES.NUMERIC,
-                GSI_INDEX_STRATEGIES.DATE,
-            ):
-                if query_type == "starts_with":
+                if gsi_type == GSI_INDEX_STRATEGIES.TRIGRAM_CODE:
+                    trigrams = get_trigrams(query_term)
+
+                    contexts = []
+
+                    candidates = set(self.gsis[query_field]["gsi"].get(trigrams[0], []))
+
+                    for trigram in trigrams:
+                        candidates = candidates.intersection(
+                            set(self.gsis[query_field]["gsi"].get(trigram, []))
+                        )
+
+                    # candidate[2] is the document uuid
+                    matching_documents.extend(
+                        [candidate[2] for candidate in candidates]
+                    )
+
+                    # get line numbers
+                    for candidate in candidates:
+                        contexts.append(
+                            {
+                                "line": candidate[1],
+                                "code": self.gsis[query_field]["id2line"][
+                                    f"{candidate[0]}:{candidate[1]}"
+                                ],
+                            }
+                        )
+                        matching_highlights[candidate[2]] = contexts
+                if (
+                    query_type == "starts_with"
+                    and gsi_type == GSI_INDEX_STRATEGIES.PREFIX
+                ):
+                    matches = gsi.keys(prefix=query_term)
+                    matching_documents.extend([gsi[match] for match in matches])
+                elif query_type == "starts_with":
                     for document in self.global_index.values():
                         if document.get(query_field).startswith(query_term):
                             matching_documents.append(document["uuid"])
+                if (
+                    query_type in {"contains", "wildcard"}
+                    and gsi_type == GSI_INDEX_STRATEGIES.CONTAINS
+                ):
+                    if enforce_strict:
+                        words = query_term.split()
+
+                        all_matches = {}
+                        all_match_positions = {}
+
+                        if len(words) == 1:
+                            all_matches[words[0]] = list(
+                                set(
+                                    gsi.get(words[0], {})
+                                    .get("documents", {})
+                                    .get("uuid", [])
+                                )
+                            )
+                            all_match_positions[words[0]] = (
+                                gsi.get(words[0], {})
+                                .get("documents", {})
+                                .get("uuid", {})
+                            )
+
+                        for word_index in range(0, len(words)):
+                            current_word = words[word_index]
+                            if word_index + 1 == len(words):
+                                next_word = current_word
+                            else:
+                                next_word = words[word_index + 1]
+
+                            # break if on last word
+                            if word_index + 1 == len(words):
+                                break
+
+                            current_word_positions = (
+                                gsi.get(current_word, {})
+                                .get("documents", {})
+                                .get("uuid", {})
+                            )
+                            next_word_positions = (
+                                gsi.get(next_word, {})
+                                .get("documents", {})
+                                .get("uuid", {})
+                            )
+
+                            matches_for_this_word = []
+                            match_positions = defaultdict(list)
+
+                            for doc_id, positions in current_word_positions.items():
+                                if doc_id not in next_word_positions:
+                                    continue
+
+                                for position in set(positions):
+                                    if (
+                                        position + 1 in next_word_positions[doc_id]
+                                    ) or len(words) == 1:
+                                        matches_for_this_word.append(doc_id)
+                                        match_positions[doc_id].append(position)
+                                        break
+
+                            if word_index + 1 == len(words) and len(words) == 1:
+                                all_matches[current_word] = matches_for_this_word
+                                all_match_positions[current_word] = match_positions
+                            else:
+                                all_matches[
+                                    current_word + " " + next_word
+                                ] = matches_for_this_word
+                                all_match_positions[
+                                    current_word + " " + next_word
+                                ] = match_positions
+
+                        if all_matches:
+                            matching_documents.extend(
+                                set.intersection(
+                                    *[set(matches) for matches in all_matches.values()]
+                                )
+                            )
+                            # score for each matching document is the # of matches
+                            matching_document_scores.update(
+                                {
+                                    doc_id: len(all_matches)
+                                    for doc_id in matching_documents
+                                }
+                            )
+                            if highlight_terms:
+                                matches_with_context = defaultdict(list)
+
+                                for doc_occurrences in all_match_positions.values():
+                                    for doc_id, positions in doc_occurrences.items():
+                                        for position in positions:
+                                            start = max(0, position - highlight_stride)
+                                            end = min(
+                                                position + highlight_stride,
+                                                len(
+                                                    self.global_index[doc_id][
+                                                        query_field
+                                                    ].split()
+                                                ),
+                                            )
+                                            matches_with_context[doc_id].append(
+                                                " ".join(
+                                                    self.global_index[doc_id][
+                                                        query_field
+                                                    ].split()[start:end]
+                                                )
+                                            )
+
+                                matching_highlights.update(matches_with_context)
+                    else:
+                        for word in str(query_term).split(" "):
+                            word = word.lower()
+
+                            if gsi.get(word) is None:
+                                continue
+
+                            results = self.reverse_tf_idf[word].get(query_field, {})
+
+                            matching_document_scores = results
+                            matching_documents.extend(results.keys())
+
             elif query_type == "equals":
                 matching_documents.extend(
                     gsi.get(query_term, {}).get("documents", {}).get("uuid", [])
@@ -1402,7 +1367,7 @@ def _run(self, query: dict, query_field: str) -> List[str]:
                 if query_term is None or key is None:
                     continue
 
-                matches = pybmoore.search(str(query_term), key)
+                matches = pybmoore.search(query_term, key)
 
                 if query_type in {"contains", "wildcard"} and len(matches) > 0:
                     matching_documents.extend(value)
@@ -1417,18 +1382,13 @@ def _run(self, query: dict, query_field: str) -> List[str]:
             "scores": defaultdict(dict),
             "highlights": defaultdict(dict),
         }
-
-        for doc in matching_documents[: self.match_limit_for_large_result_pages]:
-            advanced_query_information["scores"][doc] = matching_document_scores.get(
-                doc, 0
-            ) * float(boost_factor)
+
+        for doc in matching_documents[:self.match_limit_for_large_result_pages]:
+            advanced_query_information["scores"][doc] = matching_document_scores.get(doc, 0) * float(
+                boost_factor
+            )
 
             if matching_highlights:
-                advanced_query_information["highlights"][doc] = matching_highlights.get(
-                    doc, {}
-                )
+                advanced_query_information["highlights"][doc] = matching_highlights.get(doc, {})
 
-        return (
-            advanced_query_information,
-            matching_documents[: self.match_limit_for_large_result_pages],
-        )
+        return advanced_query_information, matching_documents[:self.match_limit_for_large_result_pages]
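
Note on the BM25 change in search(): avgdl is now computed once per search instead of on every add(), and each query term is scored with the standard Okapi BM25 formula this patch inlines. A minimal standalone sketch of the per-term math follows; the k1 = 1.5 and b = 0.75 defaults here are illustrative only, not necessarily the values JameSQL configures on the index.

    import math
    from collections import Counter

    def bm25_score(query_terms, doc_tokens, doc_freqs, total_docs, avgdl, k1=1.5, b=0.75):
        """Score one document (a list of tokens) against a list of query terms."""
        tf = Counter(doc_tokens)  # term frequency within this document
        dl = len(doc_tokens)      # document length in words
        score = 0.0
        for term in query_terms:
            n = doc_freqs.get(term, 0)  # number of documents containing the term
            idf = math.log((total_docs - n + 0.5) / (n + 0.5) + 1)
            t = tf[term]
            score += idf * (t * (k1 + 1)) / (t + k1 * (1 - b + b * dl / avgdl))
        return score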
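
Note on the strict "contains" rewrite: the removed _run_get_strict_matches() helper is replaced by intersecting positional postings for each adjacent word pair, so a phrase matches only where every word appears exactly one position after the previous one. A simplified sketch of the same idea; the phrase_match name and the postings map of word -> {doc_id: [positions]} are illustrative, not JameSQL's internal GSI layout.

    def phrase_match(words, postings):
        """Return doc ids in which `words` occur as a contiguous phrase."""
        if not words:
            return set()
        # Seed with the first word's positions, then keep only start positions
        # where each later word appears exactly `offset` positions further on.
        candidates = {doc: set(pos) for doc, pos in postings.get(words[0], {}).items()}
        for offset, word in enumerate(words[1:], start=1):
            word_docs = postings.get(word, {})
            surviving = {}
            for doc, starts in candidates.items():
                if doc in word_docs:
                    keep = starts & {p - offset for p in word_docs[doc]}
                    if keep:
                        surviving[doc] = keep
            candidates = surviving
        return set(candidates)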