Skip to content

Commit

Permalink
fix broken tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
capjamesg committed Oct 17, 2024
1 parent 3d2f916 commit 9296969
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 12 deletions.
25 changes: 18 additions & 7 deletions jamesql/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,17 @@ def _create_reverse_index(
self.tf_idf[word][score].append(document["uuid"])
else:
self.tf_idf[word][score] = [document["uuid"]]

for w in document[index_by].split(" "):
if self.reverse_tf_idf[w].get(index_by) is None:
self.reverse_tf_idf[w][index_by] = {}

self.reverse_tf_idf[word][document["uuid"]] = score
self.reverse_tf_idf[w][index_by][document["uuid"]] = score

if self.reverse_tf_idf[w.lower()].get(index_by) is None:
self.reverse_tf_idf[w.lower()][index_by] = {}

self.reverse_tf_idf[w.lower()][index_by][document["uuid"]] = score

return index

Expand Down Expand Up @@ -823,7 +832,7 @@ def search(self, query: dict) -> List[str]:

metadata, result_ids = self._recursively_parse_query(query["query"])

results = [self.global_index.get(doc_id) for doc_id in result_ids]
results = [self.global_index.get(doc_id) for doc_id in result_ids if doc_id in self.global_index]

results = orjson.loads(orjson.dumps(results))

Expand Down Expand Up @@ -855,13 +864,15 @@ def search(self, query: dict) -> List[str]:

for document in results:
if document.get("_score") is None:
document["_score"] = 1
document["_score"] = 0

transformer = JameSQLScriptTransformer(document)

document["_score"] = transformer.transform(tree)

results = sorted(results, key=lambda x: x.get("_score", 1), reverse=True)
print(document["_score"])

results = sorted(results, key=lambda x: x.get("_score", 0), reverse=True)

if query.get("skip"):
results = results[int(query["skip"]) :]
Expand Down Expand Up @@ -1249,7 +1260,7 @@ def _run(self, query: dict, query_field: str) -> List[str]:
if gsi.get(word) is None:
continue

results = self.reverse_tf_idf[word]
results = self.reverse_tf_idf[word].get(query_field, {})

matching_document_scores = results
matching_documents.extend(results.keys())
Expand Down Expand Up @@ -1308,7 +1319,7 @@ def _run(self, query: dict, query_field: str) -> List[str]:

advanced_query_information = {
"scores": defaultdict(dict),
"contexts": defaultdict(dict),
"highlights": defaultdict(dict),
}

for doc in matching_documents[:self.match_limit_for_large_result_pages]:
Expand All @@ -1317,6 +1328,6 @@ def _run(self, query: dict, query_field: str) -> List[str]:
)

if matching_highlights:
advanced_query_information["contexts"][doc] = matching_highlights.get(doc, {})
advanced_query_information["highlights"][doc] = matching_highlights.get(doc, {})

return advanced_query_information, matching_documents[:self.match_limit_for_large_result_pages]
3 changes: 0 additions & 3 deletions tests/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ def create_indices(request):
"category": ["pop", "acoustic"],
"uuid": "18fbe44e19a24153b0a22841261db61c",
"_score": 1,
"_context": {},
}
]
},
Expand All @@ -97,7 +96,6 @@ def create_indices(request):
"category": ["pop", "acoustic"],
"uuid": "eb11180b16e34467a5d457f7115fda38",
"_score": 1,
"_context": {},
}
],
"acoustic": [
Expand All @@ -107,7 +105,6 @@ def create_indices(request):
"category": ["pop", "acoustic"],
"uuid": "eb11180b16e34467a5d457f7115fda38",
"_score": 1,
"_context": {},
}
],
},
Expand Down
4 changes: 2 additions & 2 deletions tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def test_search(
"query_score": "(_score + 2)",
},
"tolerate it",
2.549306144334055,
2.0,
DoesNotRaise(),
),
(
Expand All @@ -395,7 +395,7 @@ def test_search(
"sort_by": "_score",
},
"tolerate it",
1.09861228866810989,
0.09010335735736986,
DoesNotRaise(),
),
(
Expand Down

0 comments on commit 9296969

Please sign in to comment.