
Commit

work on AND operator
capjamesg committed Nov 23, 2024
1 parent 6e9ca8c commit cebad5f
Showing 2 changed files with 41 additions and 52 deletions.
85 changes: 35 additions & 50 deletions jamesql/index.py
@@ -13,6 +13,7 @@
 from typing import Dict, List

 import orjson
+import numpy
 import pybmoore
 import pygtrie
 from BTrees.OOBTree import OOBTree
@@ -873,8 +874,6 @@ def search(

highlights = metadata.get("highlights", {})

end_time = time.time()

if query.get("sort_by") is None:
query["sort_by"] = "_score"

@@ -890,8 +889,13 @@
 )

 for doc in results:
+    doc_uuid = doc["uuid"]
+    doc_value = self.tf.get(doc_uuid, {})
+    doc_word_length = self.document_length_words[doc_uuid]
+    doc_score = 0
+
     for term in term_queries:
-        tf = self.tf.get(doc["uuid"], {}).get(term, 0)
+        tf = doc_value.get(term, 0)
         idf = self.idf.get(term, 0)

         term_score = (tf * (self.k1 + 1)) / (
@@ -901,63 +905,44 @@
                 1
                 - self.b
                 + self.b
-                * (self.document_length_words[doc["uuid"]] / self.avgdl)
+                * (doc_word_length / self.avgdl)
             )
         )
         term_score *= idf

-        doc_scores[doc["uuid"]] += term_score
+        doc_score += term_score

     for field in fields:
-        # word_pos = defaultdict(list)
-
-        # for i, term in enumerate(doc[field].lower().split(" ")):
-        #     word_pos[term].append(i)
-
         # give a boost if all terms are within 1 word of each other
         # so a doc with "all too well" would do better than "all well too"
         if (
-            # all([w in word_pos for w in term_queries])
             len(term_queries) > 0
         ):
-            starting_word_pos = self.gsis[field]["gsi"][term_queries[0]]["documents"]["uuid"][doc["uuid"]]
-            # print(term_queries[0], word_pos)
-            first_word_pos = set(starting_word_pos)
-
-            for i, term in enumerate(term_queries):
-                word_pos = self.gsis[field]["gsi"][term]["documents"]["uuid"][doc["uuid"]]
-
-                positions = set([x - i for x in word_pos])  # | set([x + i for x in word_pos])
-                first_word_pos &= positions
-                # if len(first_word_pos.intersection(positions)) > 0:  # and i != len(term_queries) - 1:
-                #     first_word_pos &= positions
-
-            if first_word_pos and field != "title_lower":
-                doc_scores[doc["uuid"]] += len(first_word_pos)
-                # * len(set(word_pos[term_queries[0]]))
-            if field == "title_lower":
-                # get word overlap between title and terms
-                overlap = set(term_queries).intersection(set(doc["title_lower"].split(" ")))
-                # calculate overlap ratio
-                overlap_ratio = len(overlap) / len(set(doc["title_lower"].split(" ")))
-                # print((2 * (len(overlap) / overlap_ratio)))
-                # print(overlap_ratio, doc["title"], field, "overlap")
-                # if "shake up" in doc["title_lower"]:
-                # if "shake up" in doc["title_lower"] or "random aeropress" in doc["title_lower"]:
-                #     print(overlap_ratio, doc["title"])
-                doc_scores[doc["uuid"]] *= (50 / (1 - overlap_ratio + 1))
-
-            # # add weight for the first time the term is mentioned
-            # the closer the mention is to the beginning of the document, the higher the weight
-            if first_word_pos and field == "title_lower":
-                # print(first_word_pos, doc["title"], field)
-                min_pos = min(first_word_pos)
-                # print(min_pos)
-                doc_scores[doc["uuid"]] *= 1 + (1 / (min_pos + 1))
+            field_gsi = self.gsis[field]["gsi"]
+            starting_word_pos = field_gsi[term_queries[0]]["documents"]["uuid"][doc_uuid]
+            first_word_pos = set(starting_word_pos)
+
+            for i, term in enumerate(term_queries[1:]):
+                word_pos = field_gsi[term]["documents"]["uuid"][doc_uuid]
+                first_word_pos &= set(x - i for x in word_pos)
+
+            if first_word_pos and field != "title_lower":
+                doc_score += len(first_word_pos)
+
+            if field == "title_lower":
+                title_terms = set(doc["title_lower"].split())
+                overlap = title_terms & set(term_queries)
+                overlap_ratio = len(overlap) / len(title_terms) if title_terms else 0
+                doc_score *= (50 / (1 - overlap_ratio + 1))
+
+            # add weight for the first time the term is mentioned
+            # the closer the mention is to the beginning of the document, the higher the weight
+            if first_word_pos and field == "title_lower":
+                min_pos = min(first_word_pos)
+                doc_score *= 1 + (1 / (min_pos + 1))
+
+    doc_scores[doc["uuid"]] = doc_score
+
+end_time = time.time()
 # add _score key to all results; create new object
 for doc in results:
     doc["_score"] = doc_scores.get(doc["uuid"], 0)

 if query.get("sort_order") == "asc":
     results = sorted(results, key=itemgetter(results_sort_by), reverse=False)
 else:
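For context on the scoring loop above: each matching term contributes a BM25-style score that is accumulated into doc_score. A sketch of the formula as the code appears to implement it, where k_1, b, and avgdl correspond to self.k1, self.b, and self.avgdl, tf and idf come from self.tf and self.idf, and |D| is the document's length in words; the portion of the denominator elided between the hunks is assumed to follow the standard BM25 form:

\mathrm{score}(t, D) = \mathrm{idf}(t) \cdot \frac{\mathrm{tf}(t, D)\,(k_1 + 1)}{\mathrm{tf}(t, D) + k_1\left(1 - b + b \cdot \frac{|D|}{\mathrm{avgdl}}\right)}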
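The field loop then applies a phrase-proximity boost. As a worked illustration (a standalone sketch with invented word positions; JameSQL reads positions from its GSI rather than from a literal dict): shifting each query term's positions left by its offset in the query and intersecting the sets leaves only the positions where the terms appear consecutively, so a document containing "all too well" earns the boost while one containing only "all well too" does not.

# Standalone sketch: invented positions for a document whose words are
# "all too well ... all well too" (word indices 0-2 and 7-9).
positions = {
    "all":  {0, 7},
    "too":  {1, 9},
    "well": {2, 8},
}
term_queries = ["all", "too", "well"]

first_word_pos = set(positions[term_queries[0]])
for offset, term in enumerate(term_queries[1:], start=1):
    # shift each occurrence left by its offset in the query, then intersect
    first_word_pos &= {p - offset for p in positions[term]}

print(first_word_pos)  # {0}: only the consecutive "all too well" run survives

len(first_word_pos) is what gets added to doc_score for non-title fields.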
8 changes: 6 additions & 2 deletions jamesql/rewriter.py
@@ -9,7 +9,8 @@
 start: (query)+ sort_component?
 or_query: (query ("OR ") query)*
-query: or_query | query_component
+and_query: (query ("AND ") query)*
+query: and_query | or_query | query_component
 query_component: (negate_query | range_query | strict_search_query | word_query | field_query | comparison)+
 sort_component: "sort:" TERM (ORDER)?
@@ -27,7 +28,7 @@
 TERM: /[a-zA-Z0-9_]+/
 ORDER: "ASC" | "DESC" | "asc" | "desc"
-%import common.WS
+%import common.W
 %ignore WS
 """

@@ -131,6 +132,9 @@ def FLOAT(self, items):

     def or_query(self, items):
         return {"or": items}
+
+    def and_query(self, items):
+        return {"and": items}

     def negate_query(self, items):
         return {"not": items[0]}
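A minimal, self-contained sketch of how a Lark rule named and_query pairs with a Transformer method of the same name to produce an {"and": [...]} node, mirroring the or_query handling above. The toy grammar below is an assumption for illustration only, not JameSQL's actual grammar:

from lark import Lark, Transformer

TOY_GRAMMAR = r"""
start: and_query | WORD
and_query: WORD ("AND" WORD)+
WORD: /[a-z0-9_]+/
%import common.WS
%ignore WS
"""

class ToyTransformer(Transformer):
    def and_query(self, items):
        # same shape as the or_query handler: a dict keyed by the operator
        return {"and": [str(token) for token in items]}

    def start(self, items):
        item = items[0]
        return item if isinstance(item, dict) else str(item)

parser = Lark(TOY_GRAMMAR, start="start")
print(ToyTransformer().transform(parser.parse("coffee AND aeropress")))
# {'and': ['coffee', 'aeropress']}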
