
Commit

set higher performance expectations in tests
capjamesg committed Oct 17, 2024
1 parent afd02a6 commit 1e1b218
Showing 12 changed files with 89 additions and 73 deletions.
112 changes: 64 additions & 48 deletions jamesql/index.py
@@ -10,21 +10,22 @@
from enum import Enum
from functools import lru_cache
from typing import Dict, List
from sortedcontainers import SortedDict

import nltk
import orjson
import pybmoore
import pygtrie
from BTrees.OOBTree import OOBTree
from lark import Lark
from nltk.corpus import stopwords
from nltk import download

from jamesql.rewriter import grammar as rewriter_grammar
from jamesql.rewriter import string_query_to_jamesql

from .script_lang import JameSQLScriptTransformer, grammar

nltk.download("stopwords")
download("stopwords")

INDEX_STORE = os.path.join(os.path.expanduser("~"), ".jamesql")
JOURNAL_FILE = os.path.join(os.getcwd(), "journal.jamesql")
@@ -108,6 +109,9 @@ def __init__(self, match_limit_for_large_result_pages = 1000) -> None:
maybe_placeholders=False,
)
self.match_limit_for_large_result_pages = match_limit_for_large_result_pages
self.tf = defaultdict(dict)
self.idf = {}
self.tf_idf = defaultdict(lambda: SortedDict())

def __len__(self):
return len(self.global_index)
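
The three attributes added to `__init__` hold the precomputed scoring state: `self.tf` maps document UUID → word → term frequency, `self.idf` maps word → inverse document frequency, and `self.tf_idf` maps word → a `SortedDict` keyed by TF-IDF score, with lists of matching UUIDs as values. A sketch of the intended shapes, with invented documents and numbers:

```python
from collections import defaultdict
from sortedcontainers import SortedDict

# Hypothetical state after indexing a small corpus:
tf = defaultdict(dict)
tf["doc-1"] = {"tolerate": 0.25, "it": 0.25}    # word -> count / words in document
idf = {"tolerate": 0.693, "it": 0.0}            # word -> log(total docs / docs containing word)
tf_idf = defaultdict(lambda: SortedDict())
tf_idf["tolerate"][0.25 * 0.693] = ["doc-1"]    # TF-IDF score -> UUIDs, kept sorted by score
```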
@@ -184,19 +188,50 @@ def _create_reverse_index(
}
)

total_documents = len(documents)
document_frequencies = defaultdict(int)

for document in documents:
word_count = defaultdict(int) # Track word counts in this document
words = document[index_by].split() # Tokenize the document

unique_words_in_document = set()

index[document[index_by]]["documents"]["uuid"][document["uuid"]].append(0)

for pos, word in enumerate(document[index_by].split()):
index[word]["count"] += 1
index[word]["documents"]["uuid"][document["uuid"]].append(pos)
index[word]["documents"]["count"][document["uuid"]] += 1
self.word_counts[word.lower()] += 1
word_lower = word.lower()

# Update index
index[word_lower]["count"] += 1
index[word_lower]["documents"]["uuid"][document["uuid"]].append(pos)
index[word_lower]["documents"]["count"][document["uuid"]] += 1

self.word_counts[word_lower] += 1
self.word_counts[word] += 1
word_count[word_lower] += 1

index[document[index_by]]["count"] += 1
index[document[index_by]]["documents"]["uuid"][document["uuid"]].append(
pos
)
index[document[index_by]]["documents"]["count"][document["uuid"]] += 1
# Track first occurrence of the word in the document for document frequencies
if word_lower not in unique_words_in_document:
document_frequencies[word_lower] += 1
unique_words_in_document.add(word_lower)

# Compute term frequency (TF) for each word in the document
total_words_in_document = len(words)
for word, count in word_count.items():
self.tf[document["uuid"]][word] = count / total_words_in_document

# Compute inverse document frequency (IDF) for each word in the corpus
for word, doc_count in document_frequencies.items():
self.idf[word] = math.log(total_documents / doc_count)

# Compute TF-IDF for each document
for document in documents:
for word, tf_value in self.tf[document["uuid"]].items():
if self.tf_idf[word].get(tf_value * self.idf[word]):
self.tf_idf[word][tf_value * self.idf[word]].append(document["uuid"])
else:
self.tf_idf[word][tf_value * self.idf[word]] = [document["uuid"]]

return index
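
As a sanity check on the new scoring pass, here is a self-contained sketch of the same TF and IDF computation over a toy two-document corpus (the document contents are invented for illustration):

```python
import math
from collections import defaultdict

documents = [
    {"uuid": "doc-1", "title": "tolerate it"},
    {"uuid": "doc-2", "title": "happiness"},
]

tf = defaultdict(dict)
document_frequencies = defaultdict(int)

for document in documents:
    words = document["title"].split()
    for word in {w.lower() for w in words}:
        document_frequencies[word] += 1           # count each word once per document
    counts = defaultdict(int)
    for word in words:
        counts[word.lower()] += 1
    for word, count in counts.items():
        tf[document["uuid"]][word] = count / len(words)

idf = {w: math.log(len(documents) / df) for w, df in document_frequencies.items()}

# "tolerate" appears in 1 of 2 documents: idf = log(2/1) ≈ 0.693,
# tf("tolerate", doc-1) = 1/2, so TF-IDF ≈ 0.347.
assert round(tf["doc-1"]["tolerate"] * idf["tolerate"], 3) == 0.347
```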

@@ -690,7 +725,6 @@ def create_gsi(
elif all([isinstance(item, dict) for item in documents_in_indexed_by]):
strategy = GSI_INDEX_STRATEGIES.NOT_INDEXABLE
else:
print(documents_in_indexed_by)
strategy = GSI_INDEX_STRATEGIES.FLAT

if strategy == GSI_INDEX_STRATEGIES.PREFIX:
@@ -804,7 +838,6 @@ def search(self, query: dict) -> List[str]:

document["_score"] = transformer.transform(tree)

print(document["_score"])
results = sorted(results, key=lambda x: x.get("_score", 1), reverse=True)

if query.get("skip"):
@@ -1181,48 +1214,31 @@ def _run(self, query: dict, query_field: str) -> List[str]:
matching_highlights.update(matches_with_context)
else:
for word in query_term.split(" "):
if gsi.get(word) is None:
if gsi.get(word.lower()) is None:
continue

uuid_of_documents = gsi[word]["documents"]["uuid"]
if len(matching_documents) == 0:
matching_documents.extend(uuid_of_documents)
else:
matching_documents.extend(
list(
set(matching_documents).intersection(
set(uuid_of_documents)
)
)
)

matching_documents_count = len(uuid_of_documents)
index_length = len(self.global_index)
results = self.tf_idf[word.lower()]
count = 0

inverse_document_frequency = math.log(
index_length / 1 + matching_documents_count
)

for uuid_of_document in uuid_of_documents:
document_term_frequency = (
gsi[word]["documents"]["count"][uuid_of_document]
/ self.doc_lengths[uuid_of_document][query_field]
)

tf_idf = (
document_term_frequency * inverse_document_frequency
)
for k, v in results.items():
if count > self.match_limit_for_large_result_pages:
break

for item in v[:self.match_limit_for_large_result_pages]:
# skip if word not in document
if self.gsis[query_field]["gsi"].get(
word.lower(), {}
).get("documents", {}).get("uuid", {}).get(item) is None:
continue

matching_document_scores[uuid_of_document] = tf_idf
matching_document_scores.update({item: k})
matching_documents.append(item)

elif query_type == "equals":
print(query_term)
print(gsi.get(query_term, {}))
matching_documents.extend(
[
doc_uuid
for doc_uuid in gsi.get(query_term, {})
.get("documents", {})
.get("uuid", [])
]
gsi.get(query_term, {}).get("documents", {}).get("uuid", [])
)
elif (
query_type == "contains" and gsi_type == GSI_INDEX_STRATEGIES.PREFIX
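
Instead of computing TF-IDF per query, the rewritten term branch reads the scores precomputed in `_create_reverse_index` from `self.tf_idf[word]` and bounds the scan with `match_limit_for_large_result_pages`. It also retires the removed inline IDF expression, `math.log(index_length / 1 + matching_documents_count)`, which, since `/` binds tighter than `+`, evaluated to `log(index_length + matching_documents_count)`. A minimal sketch of draining such a score-keyed `SortedDict` (the scores and UUIDs here are invented; note that `SortedDict` iterates keys in ascending order, so this sketch walks it in reverse to surface top scores first):

```python
from sortedcontainers import SortedDict

# Hypothetical precomputed postings for one word:
# TF-IDF score -> list of document UUIDs with that score.
scores = SortedDict({0.12: ["doc-3"], 0.35: ["doc-1", "doc-4"], 0.80: ["doc-2"]})

match_limit = 2
matches = {}

# Walk scores from highest to lowest, stopping at the match limit.
for score in reversed(scores):
    for uuid in scores[score]:
        matches[uuid] = score
        if len(matches) >= match_limit:
            break
    if len(matches) >= match_limit:
        break

print(matches)  # {'doc-2': 0.8, 'doc-1': 0.35}
```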
4 changes: 2 additions & 2 deletions tests/aggregation.py
@@ -129,10 +129,10 @@ def test_search(
if number_of_documents_expected > 0:
assert response["documents"][0]["title"] == top_result_value

assert float(response["query_time"]) < 0.1
assert float(response["query_time"]) < 0.06

# run if --benchmark is passed
if "--benchmark" in sys.argv:
response = large_index.search(query)

assert float(response["query_time"]) < 0.1
assert float(response["query_time"]) < 0.06
4 changes: 2 additions & 2 deletions tests/code_search.py
@@ -104,10 +104,10 @@ def test_code_search(
if number_of_documents_expected > 0:
assert response["documents"][0]["file_name"] == top_result_value

assert float(response["query_time"]) < 0.1
assert float(response["query_time"]) < 0.06

# run if --benchmark is passed
if "--benchmark" in sys.argv:
response = large_index.search(query)

assert float(response["query_time"]) < 0.1
assert float(response["query_time"]) < 0.06
4 changes: 2 additions & 2 deletions tests/data_types.py
@@ -131,10 +131,10 @@ def test_search(
if number_of_documents_expected > 0:
assert response["documents"][0]["title"] == top_result_value

assert float(response["query_time"]) < 0.1
assert float(response["query_time"]) < 0.06

# run if --benchmark is passed
if "--benchmark" in sys.argv:
response = large_index.search(query)

assert float(response["query_time"]) < 0.1
assert float(response["query_time"]) < 0.06
4 changes: 2 additions & 2 deletions tests/group_by.py
@@ -149,10 +149,10 @@ def test_search(
if number_of_documents_expected > 0:
assert response["documents"][0]["title"] == top_result_value

assert float(response["query_time"]) < 0.1
assert float(response["query_time"]) < 0.06

# run if --benchmark is passed
if "--benchmark" in sys.argv:
response = large_index.search(query)

assert float(response["query_time"]) < 0.1
assert float(response["query_time"]) < 0.06
4 changes: 2 additions & 2 deletions tests/highlight.py
@@ -104,10 +104,10 @@ def test_search(
if number_of_documents_expected > 0:
assert response["documents"][0]["title"] == top_result_value

assert float(response["query_time"]) < 0.1
assert float(response["query_time"]) < 0.06

# run if --benchmark is passed
if "--benchmark" in sys.argv:
response = large_index.search(query)

assert float(response["query_time"]) < 0.1
assert float(response["query_time"]) < 0.06
4 changes: 2 additions & 2 deletions tests/query_simplification.py
@@ -131,10 +131,10 @@ def test_simplification_then_search(
else:
assert response["documents"][0]["title"] == top_result_value

assert float(response["query_time"]) < 0.1
assert float(response["query_time"]) < 0.06

# run if --benchmark is passed
if "--benchmark" in sys.argv:
response = large_index.string_query_search(query)

assert float(response["query_time"]) < 0.1
assert float(response["query_time"]) < 0.06
4 changes: 2 additions & 2 deletions tests/range_queries.py
@@ -134,10 +134,10 @@ def test_search(
if number_of_documents_expected > 0:
assert response["documents"][0]["title"] == top_result_value

assert float(response["query_time"]) < 0.1
assert float(response["query_time"]) < 0.06

# run if --benchmark is passed
if "--benchmark" in sys.argv:
response = large_index.search(query)

assert float(response["query_time"]) < 0.1
assert float(response["query_time"]) < 0.06
4 changes: 2 additions & 2 deletions tests/sort_by.py
@@ -176,10 +176,10 @@ def test_search(
if number_of_documents_expected > 0:
assert response["documents"][0]["title"] == top_result_title

assert float(response["query_time"]) < 0.1
assert float(response["query_time"]) < 0.06

# run if --benchmark is passed
if "--benchmark" in sys.argv:
response = large_index.search(query)

assert float(response["query_time"]) < 0.1
assert float(response["query_time"]) < 0.06
4 changes: 2 additions & 2 deletions tests/string_queries_categorical_and_range.py
@@ -205,10 +205,10 @@ def test_search(
if number_of_documents_expected > 0:
assert response["documents"][0]["title"] == top_result_value

assert float(response["query_time"]) < 0.1
assert float(response["query_time"]) < 0.06

# run if --benchmark is passed
if "--benchmark" in sys.argv:
response = large_index.string_query_search(query)

assert float(response["query_time"]) < 0.1
assert float(response["query_time"]) < 0.06
4 changes: 2 additions & 2 deletions tests/string_query.py
@@ -310,10 +310,10 @@ def test_search(
assert response["documents"][0]["title"] == top_result_value

if response.get("query_time"):
assert float(response["query_time"]) < 0.1
assert float(response["query_time"]) < 0.06

# run if --benchmark is passed
if "--benchmark" in sys.argv:
response = large_index.string_query_search(query)

assert float(response["query_time"]) < 0.1
assert float(response["query_time"]) < 0.06
10 changes: 5 additions & 5 deletions tests/test.py
@@ -365,13 +365,13 @@ def test_search(
if number_of_documents_expected > 0:
assert response["documents"][0]["title"] == top_result_value

assert float(response["query_time"]) < 0.1
assert float(response["query_time"]) < 0.06

# run if --benchmark is passed
if "--benchmark" in sys.argv:
response = large_index.search(query)

assert float(response["query_time"]) < 0.1
assert float(response["query_time"]) < 0.06


@pytest.mark.parametrize(
@@ -384,7 +384,7 @@ def test_search(
"query_score": "(_score + 2)",
},
"tolerate it",
2.6931471805599454,
4.19722457733622,
DoesNotRaise(),
),
(
@@ -395,7 +395,7 @@
"sort_by": "_score",
},
"tolerate it",
1.3862943611198906,
6.591673732008659,
DoesNotRaise(),
),
(
@@ -405,7 +405,7 @@
"sort_by": "title",
},
"tolerate it",
10.014280344034402,
17.660258042044497,
DoesNotRaise(), # test searching TF/IDF indexed field
),
],
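
The expected `_score` values in these cases shift because ranking now flows through the precomputed corpus-wide IDF, `math.log(total_documents / doc_count)`, rather than the inline expression removed above. A quick comparison with invented corpus figures shows how far apart the two formulas can land:

```python
import math

total_documents = 10   # hypothetical corpus size
doc_count = 2          # hypothetical number of documents containing the term

old_idf = math.log(total_documents / 1 + doc_count)  # log(12), due to the precedence quirk
new_idf = math.log(total_documents / doc_count)      # log(5)

print(round(old_idf, 3), round(new_idf, 3))  # 2.485 1.609
```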
