
Commit

performance improvements
capjamesg committed Nov 18, 2024
1 parent 18f09b3 commit c16853b
Showing 1 changed file with 118 additions and 118 deletions.
236 changes: 118 additions & 118 deletions jamesql/index.py
@@ -820,174 +820,174 @@ def create_gsi(
        return gsi

    def search(self, query: dict) -> List[str]:
        # with self.write_lock:
        start_time = time.time()

        results_limit = query.get("limit", 10)

        metadata = {}

        if not query.get("query"):
            return {
                "documents": [],
                "error": "No query provided",
                "query_time": str(round(time.time() - start_time, 4)),
            }

        if query["query"] == {}:  # empty query
            results = []
        elif query["query"] == "*":  # all query
            results = list(self.global_index.values())
        else:
            number_of_query_conditions = self._get_query_conditions(query["query"])
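            # The condition count above lets search() fail fast on oversized
            # queries instead of recursively parsing a huge query tree.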

            if len(number_of_query_conditions) > MAXIMUM_QUERY_STATEMENTS:
                return {
                    "documents": [],
                    "error": "Too many query conditions. Maximum is "
                    + str(MAXIMUM_QUERY_STATEMENTS)
                    + ".",
                    "query_time": str(round(time.time() - start_time, 4)),
                }

            metadata, result_ids = self._recursively_parse_query(query["query"])

            results = [
                self.global_index.get(doc_id)
                for doc_id in result_ids
                if doc_id in self.global_index
            ]

            results = orjson.loads(orjson.dumps(results))

            for r in results:
                r["_score"] = 0
                if r["uuid"] in metadata.get("scores", {}):
                    r["_score"] = metadata["scores"][r["uuid"]]
                if r["uuid"] in metadata.get("highlights", {}):
                    r["_context"] = metadata["highlights"][r["uuid"]]

        end_time = time.time()

        if query.get("sort_by") is None:
            query["sort_by"] = "_score"

        results_sort_by = query["sort_by"]

        if query.get("sort_order") == "asc":
            results = sorted(results, key=itemgetter(results_sort_by), reverse=False)
        else:
            results = sorted(results, key=itemgetter(results_sort_by), reverse=True)

        if self.enable_experimental_bm25_ranker:
            # TODO: Make sure this code can process boosts.

            self.avgdl = sum(self.document_length_words.values()) / len(self.document_length_words)
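            # avgdl above is the average document length in words across the
            # index; BM25 uses it below to normalize for document length.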

            if query["query"].get("or"):
                operator = "or"
                term_queries = [
                    term.get("or")[0][list(term.get("or")[0].keys())[0]]["contains"]
                    for term in query["query"]["or"]
                ]
                fields = [list(term.get("or")[0].keys()) for term in query["query"]["or"]]
            else:
                operator = "and"
                term_queries = [
                    term[list(term.keys())[0]]["contains"]
                    for term in query["query"]["and"][0]["or"]
                ]
                fields = [
                    list(term[list(term.keys())[0]].keys())
                    for term in query["query"]["and"][0]["or"]
                ]

            term_queries = list(set(term_queries))
            fields = [field for sublist in fields for field in sublist]

            gsis = {field: self.gsis[field]["gsi"] for field in fields if self.gsis.get(field)}
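
            # Score each document with BM25. For a term t in a document d
            # (standard BM25 form, shown here for reference):
            #
            #   score(t, d) = idf(t) * tf(t, d) * (k1 + 1)
            #                 / (tf(t, d) + k1 * (1 - b + b * |d| / avgdl))
            #
            # where |d| is the document's length in words. This is the
            # term_score expression computed below.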
            for doc in results:
                doc["_score"] = 0

                for term in term_queries:
                    tf = self.tf.get(doc["uuid"], {}).get(term, 0)
                    idf = self.idf.get(term, 0)

                    term_score = (tf * (self.k1 + 1)) / (
                        tf
                        + self.k1
                        * (1 - self.b + self.b * (self.document_length_words[doc["uuid"]] / self.avgdl))
                    )
                    term_score *= idf

                    doc["_score"] += term_score

                    for field in fields:
                        word_pos = gsis[field][term]["documents"]["uuid"][doc["uuid"]]
                        # give a boost if all terms are within 1 word of each other,
                        # so a doc with "all too well" would do better than "all well too"
                        if all([w in word_pos for w in term_queries]):
                            first_word_pos = set(word_pos[term_queries[0]])
                            for i, term in enumerate(term_queries):
                                positions = set([x - i for x in word_pos[term]])
                                first_word_pos &= positions

                            if first_word_pos:
                                doc["_score"] += (len(first_word_pos) + 1) * len(first_word_pos)

                            if field != "title_lower":
                                # TODO: Run only if query len > 1 word
                                first_word_pos = set(word_pos[term_queries[0]])
                                for i, term in enumerate(term_queries):
                                    positions = set([x - i for x in word_pos[term]])
                                    first_word_pos &= positions

                                if first_word_pos:
                                    doc["_score"] *= 2 + len(first_word_pos)
# if "title_lower" in fields:
# # TODO: Make this more dynamic
# doc["_score"] *= len([term.get("or")[0].get("title_lower", {}).get("contains") in doc["title"].lower() for term in query["query"]["or"] if str(term.get("or")[0].get("title_lower", {}).get("contains")).lower() in doc["title"].lower()]) + 1

# if "title_lower" in fields:
# # TODO: Make this more dynamic
# doc["_score"] *= len([term.get("or")[0].get("title_lower", {}).get("contains") in doc["title"].lower() for term in query["query"]["or"] if str(term.get("or")[0].get("title_lower", {}).get("contains")).lower() in doc["title"].lower()]) + 1
# sort by doc score
results = sorted(
results, key=lambda x: x.get("_score", 0), reverse=True
)

# sort by doc score
results = sorted(
results, key=lambda x: x.get("_score", 0), reverse=True
)
if query.get("query_score"):
tree = parse_script_score(query["query_score"])

if query.get("query_score"):
tree = parse_script_score(query["query_score"])
for document in results:
if document.get("_score") is None:
document["_score"] = 0

for document in results:
if document.get("_score") is None:
document["_score"] = 0
transformer = JameSQLScriptTransformer(document)

transformer = JameSQLScriptTransformer(document)
document["_score"] = transformer.transform(tree)

document["_score"] = transformer.transform(tree)
results = sorted(results, key=lambda x: x.get("_score", 0), reverse=True)

results = sorted(results, key=lambda x: x.get("_score", 0), reverse=True)
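
        # The script score pass above parses query_score once into a tree,
        # then evaluates it against each document to rewrite _score before
        # the final re-sort.
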
if query.get("skip"):
results = results[int(query["skip"]) :]

if query.get("skip"):
results = results[int(query["skip"]) :]
total_results = len(results)

total_results = len(results)
if results_limit:
results = results[:results_limit]

if results_limit:
results = results[:results_limit]
if results_limit == 0:
results = []

if results_limit == 0:
results = []
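
        # Note: total_results above is counted after skip but before the
        # limit is applied, so it reflects remaining matches, not page size.
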
        result = {
            "documents": results,
            "query_time": str(round(end_time - start_time, 4)),
            "total_results": total_results,
        }

        if query.get("metrics") and "aggregate" in query["metrics"]:
            result["metrics"] = {
                "unique_record_values": self._get_unique_record_count(results),
            }

        if query.get("group_by"):
            result["groups"] = defaultdict(list)

            for doc in results:
                if isinstance(doc.get(query["group_by"]), list):
                    for item in doc.get(query["group_by"]):
                        result["groups"][item].append(doc)
                else:
                    result["groups"][doc.get(query["group_by"])].append(doc)

        return result
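
    # Example call below is a hypothetical sketch of the query shape search()
    # accepts, based on how the method reads query above; it is not a
    # confirmed API reference:
    #
    #   index.search({
    #       "query": {"and": [{"title": {"contains": "tolerate"}}]},
    #       "limit": 10,
    #       "sort_by": "_score",
    #   })
    #   # -> {"documents": [...], "query_time": "...", "total_results": ...}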

    def _get_query_conditions(self, query_tree):
        first_key = list(query_tree.keys())[0]
