
Commit

performance improvements
capjamesg committed Nov 18, 2024
1 parent 0e6879f commit d3cc451
Showing 1 changed file with 12 additions and 24 deletions.
36 changes: 12 additions & 24 deletions jamesql/index.py
@@ -893,50 +893,42 @@ def search(self, query: dict) -> List[str]:
 term_queries = [term[list(term.keys())[0]]["contains"] for term in query["query"]["and"][0]["or"]]
 fields = [list(term[list(term.keys())[0]].keys()) for term in query["query"]["and"][0]["or"]]

+term_queries = list(set(term_queries))
 fields = [field for sublist in fields for field in sublist]

-for doc in results:
-    word_pos = defaultdict(list)
-    for i, word in enumerate(doc["post"].lower().split(" ")):
-        word_pos[word].append(i)
-    word_pos_title = defaultdict(list)
-    for i, word in enumerate(doc["title"].lower().split(" ")):
-        word_pos_title[word].append(i)
+gsis = {field: self.gsis[field]["gsi"] for field in fields}

+for doc in results:
     doc["_score"] = 0

     for term in term_queries:
         tf = self.tf.get(doc["uuid"], {}).get(term, 0)
         idf = self.idf.get(term, 0)

         # bm25
         term_score = (tf * (self.k1 + 1)) / (tf + self.k1 * (1 - self.b + self.b * (self.document_length_words[doc["uuid"]] / self.avgdl)))
         term_score *= idf

         doc["_score"] += term_score

         for field in fields:
+            word_pos = gsis[field][term]["documents"]["uuid"][doc["uuid"]]
             # give a boost if all terms are within 1 word of each other
             # so a doc with "all too well" would do better than "all well too"
-            if all([word_pos.get(w) for w in term_queries]):
+            if all([w in word_pos for w in term_queries]):
                 first_word_pos = set(word_pos[term_queries[0]])
-                total = first_word_pos.copy()
                 for i, term in enumerate(term_queries):
                     positions = set([x - i for x in word_pos[term]])
-                    first_word_pos = first_word_pos.intersection(positions)
-                    total = total.union(positions)
+                    first_word_pos &= positions

                 if first_word_pos:
-                    print(doc["title"])
-                    doc["_score"] += (len(first_word_pos) + 1) * len(total)
+                    doc["_score"] += (len(first_word_pos) + 1) * len(first_word_pos)

-            if field != "title_lower":
-                # TODO: Run only if query len > 1 word
-                if all([word_pos.get(w) for w in term_queries]):
-                    first_word_pos = set(word_pos_title[term_queries[0]])
+                if field != "title_lower":
+                    # TODO: Run only if query len > 1 word
+                    first_word_pos = set(word_pos[term_queries[0]])
                     for i, term in enumerate(term_queries):
-                        positions = set([x - i for x in word_pos_title[term]])
-                        first_word_pos = first_word_pos.intersection(positions)
+                        positions = set([x - i for x in word_pos[term]])
+                        first_word_pos &= positions

                     if first_word_pos:
                         doc["_score"] *= 2 + len(first_word_pos)
@@ -1257,8 +1249,6 @@ def _run(self, query: dict, query_field: str) -> List[str]:
     .get("uuid", {})
 )

-print(all_matches)
-
 for word_index in range(0, len(words)):
     current_word = words[word_index]
     if word_index + 1 == len(words):
@@ -1307,8 +1297,6 @@ def _run(self, query: dict, query_field: str) -> List[str]:
     current_word + " " + next_word
 ] = match_positions

-print(all_match_positions)
-
 if all_matches:
     matching_documents.extend(
         set.intersection(
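
The two _run hunks only drop leftover debug print calls, but they sit in the phrase-matching path visible above: each query word (and each adjacent word pair) is looked up in the GSI, and the final candidates are the documents that appear in every per-word match set. A rough sketch of that intersection step, with illustrative names and UUIDs (the real all_matches entries also carry match positions):

```python
# Per-word match sets: word -> UUIDs of documents containing that word.
# In jamesql these come from the GSI for the queried field.
all_matches = {
    "all":  {"doc-1", "doc-2", "doc-3"},
    "too":  {"doc-1", "doc-3"},
    "well": {"doc-1"},
}

if all_matches:
    # Only documents present in every per-word set can contain the whole
    # query; set.intersection(*sets) computes that in one call.
    matching_documents = list(set.intersection(*all_matches.values()))
    print(matching_documents)  # ['doc-1']
```
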

0 comments on commit d3cc451
