Commit: run black and isort on code

capjamesg committed Oct 17, 2024
1 parent af61899 commit 06c2692
Showing 15 changed files with 240 additions and 214 deletions.
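
The exact commands are not recorded in the commit; a typical invocation of both formatters from the repository root, assuming default configuration, is:

    black .
    isort .

Both tools only adjust formatting (line wrapping, trailing commas, import ordering), so the diffs below are not expected to change behavior.
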
34 changes: 8 additions & 26 deletions tests/autosuggest.py
@@ -61,31 +61,13 @@ def create_indices(request):
@pytest.mark.parametrize(
"query, suggestion",
[
(
"tolerat",
"tolerate"
),
(
"toler",
"tolerate"
),
(
"th",
"the"
),
(
"b",
"bolter"
),
(
"he",
""
), # not in index; part of another word
(
"cod",
""
), # not in index
]
("tolerat", "tolerate"),
("toler", "tolerate"),
("th", "the"),
("b", "bolter"),
("he", ""), # not in index; part of another word
("cod", ""), # not in index
],
)
def test_autosuggest(create_indices, query, suggestion):
index = create_indices[0]
@@ -95,4 +77,4 @@ def test_autosuggest(create_indices, query, suggestion):
assert index.autosuggest(query)[0] == suggestion

if large_index and suggestion != "":
assert large_index.autosuggest(query)[0] == suggestion
assert large_index.autosuggest(query)[0] == suggestion
5 changes: 4 additions & 1 deletion tests/code_search.py
@@ -11,6 +11,7 @@

CODE_BASE_DIR = "tests/fixtures/code"


def pytest_addoption(parser):
parser.addoption("--benchmark", action="store")

@@ -94,7 +95,9 @@ def test_code_search(
response = index.search(query)

# sort response by documents[0]["title"] to make it easier to compare
response["documents"] = sorted(response["documents"], key=lambda x: x["file_name"])
response["documents"] = sorted(
response["documents"], key=lambda x: x["file_name"]
)

assert len(response["documents"]) == number_of_documents_expected

1 change: 0 additions & 1 deletion tests/data_types.py
@@ -13,7 +13,6 @@ def pytest_addoption(parser):
parser.addoption("--benchmark", action="store")



@pytest.fixture(scope="session")
def create_indices(request):
with open("tests/fixtures/documents_with_varied_data_types.json") as f:
35 changes: 19 additions & 16 deletions tests/fixtures/code/index.py
@@ -1,4 +1,6 @@
import hashlib
import json
import math
import os
import string
import time
@@ -10,11 +12,8 @@
import orjson
import pybmoore
import pygtrie
import hashlib
import math
from BTrees.OOBTree import OOBTree
from lark import Lark

from lark import Lark

from jamesql.rewriter import string_query_to_jamesql

@@ -70,9 +69,11 @@ class RANKING_STRATEGIES(Enum):
],
}


def get_trigrams(line):
return [line[i : i + 3] for i in range(len(line) - 2)]


class JameSQL:
SELF_METHODS = {"close_to": "_close_to"}

@@ -238,17 +239,19 @@ def enable_autosuggest(self, field):

self.autosuggest_on = field

def autosuggest(self, query: str, match_full_record = False, limit = 5) -> List[str]:
def autosuggest(self, query: str, match_full_record=False, limit=5) -> List[str]:
"""
Accepts a query and returns a list of suggestions.
"""
if not self.autosuggest_index or not query:
return []

if match_full_record:
results = []

for i in self.autosuggest_index.itervalues(prefix=query.lower(), shallow = False):

for i in self.autosuggest_index.itervalues(
prefix=query.lower(), shallow=False
):
if self.autosuggest_index.has_subtrie(i):
continue
results.append(i)
@@ -257,7 +260,6 @@ def autosuggest(self, query: str, match_full_record = False, limit = 5) -> List[
else:
return self.autosuggest_index.keys(prefix=query.lower())[0:limit]


def _compute_string_query(self, query: str, query_keys: list = []) -> List[str]:
"""
Accepts a string query and returns a list of matching documents.
@@ -359,7 +361,7 @@ def add(
for key, value in document.items():
if key not in self.gsis:
continue

if self.gsis[key]["strategy"] == GSI_INDEX_STRATEGIES.CONTAINS.name:
if not self.gsis[key]["gsi"].get(value):
self.gsis[key]["gsi"][value] = {
@@ -425,15 +427,15 @@ def add(
if not self.gsis[key]["gsi"].get(trigram):
self.gsis[key]["gsi"][trigram] = []

self.gsis[key]["gsi"][trigram].append((file_name, line_num, document["uuid"]))
self.gsis[key]["gsi"][trigram].append(
(file_name, line_num, document["uuid"])
)
self.gsis[key]["id2line"][f"{file_name}:{line_num}"] = line
self.gsis[key]["doc_lengths"][file_name] = total_lines
else:
raise ValueError(
"Invalid GSI strategy. Must be one of: "
+ ", ".join(
[strategy.name for strategy in GSI_INDEX_STRATEGIES]
)
+ ", ".join([strategy.name for strategy in GSI_INDEX_STRATEGIES])
+ "."
)

@@ -896,7 +898,9 @@ def _run(self, query: dict, query_field: str) -> List[str]:
)

# candidate[2] is the document uuid
matching_documents.extend([candidate[2] for candidate in candidates])
matching_documents.extend(
[candidate[2] for candidate in candidates]
)

# get line numbers
for candidate in candidates:
@@ -1090,5 +1094,4 @@ def _run(self, query: dict, query_field: str) -> List[str]:

doc["_context"].extend(matching_highlights.get(doc["uuid"], {}))


return {}, matching_documents
5 changes: 2 additions & 3 deletions tests/fixtures/code/simplifier.py
@@ -1,7 +1,7 @@
def normalize_operator_query(t):
if isinstance(t, str):
return t

return "_".join(t)


@@ -43,6 +43,5 @@ def simplifier(terms):
):
if t[1] in outer_terms:
to_remove.add(t[1])

return [i for i in new_terms if normalize_operator_query(i) not in to_remove]

return [i for i in new_terms if normalize_operator_query(i) not in to_remove]
10 changes: 7 additions & 3 deletions tests/fixtures/code/simplifier_demo.py
@@ -1,10 +1,12 @@
from collections import defaultdict
import math
import os
from collections import defaultdict


def get_trigrams(line):
return [line[i : i + 3] for i in range(len(line) - 2)]


index = defaultdict(list)

# read all python files in .
@@ -45,8 +47,10 @@ def get_trigrams(line):

for file, line_num in candidates:
print(f"{file.name}:{line_num}")
for i in range(max(0, line_num - context), min(doc_lengths[file.name], line_num + context + 1)):
for i in range(
max(0, line_num - context), min(doc_lengths[file.name], line_num + context + 1)
):
line = id2line[f"{file.name}:{i}"]
print(f"{i}: {line}")

print()
print()
36 changes: 28 additions & 8 deletions tests/gsi_type_inference.py
@@ -9,6 +9,7 @@
def pytest_addoption(parser):
parser.addoption("--benchmark", action="store")


@pytest.mark.timeout(20)
def test_gsi_type_inference(request):
with open("tests/fixtures/documents_with_varied_data_types.json") as f:
@@ -26,7 +27,9 @@ def test_gsi_type_inference(request):
assert index.gsis["album_in_stock"]["strategy"] == GSI_INDEX_STRATEGIES.FLAT.name
assert index.gsis["rating"]["strategy"] == GSI_INDEX_STRATEGIES.NUMERIC.name
assert index.gsis["metadata"]["strategy"] == GSI_INDEX_STRATEGIES.NOT_INDEXABLE.name
assert index.gsis["record_last_updated"]["strategy"] == GSI_INDEX_STRATEGIES.DATE.name
assert (
index.gsis["record_last_updated"]["strategy"] == GSI_INDEX_STRATEGIES.DATE.name
)

with open("tests/fixtures/documents.json") as f:
documents = json.load(f)
@@ -51,10 +54,27 @@ def test_gsi_type_inference(request):
large_index.create_gsi("title", strategy=GSI_INDEX_STRATEGIES.CONTAINS)
large_index.create_gsi("lyric", strategy=GSI_INDEX_STRATEGIES.CONTAINS)

assert large_index.gsis["title"]["strategy"] == GSI_INDEX_STRATEGIES.CONTAINS.name
assert large_index.gsis["lyric"]["strategy"] == GSI_INDEX_STRATEGIES.CONTAINS.name
assert large_index.gsis["listens"]["strategy"] == GSI_INDEX_STRATEGIES.NUMERIC.name
assert large_index.gsis["album_in_stock"]["strategy"] == GSI_INDEX_STRATEGIES.FLAT.name
assert large_index.gsis["rating"]["strategy"] == GSI_INDEX_STRATEGIES.NUMERIC.name
assert large_index.gsis["metadata"]["strategy"] == GSI_INDEX_STRATEGIES.NOT_INDEXABLE.name
assert large_index.gsis["record_last_updated"]["strategy"] == GSI_INDEX_STRATEGIES.DATE.name
assert (
large_index.gsis["title"]["strategy"] == GSI_INDEX_STRATEGIES.CONTAINS.name
)
assert (
large_index.gsis["lyric"]["strategy"] == GSI_INDEX_STRATEGIES.CONTAINS.name
)
assert (
large_index.gsis["listens"]["strategy"] == GSI_INDEX_STRATEGIES.NUMERIC.name
)
assert (
large_index.gsis["album_in_stock"]["strategy"]
== GSI_INDEX_STRATEGIES.FLAT.name
)
assert (
large_index.gsis["rating"]["strategy"] == GSI_INDEX_STRATEGIES.NUMERIC.name
)
assert (
large_index.gsis["metadata"]["strategy"]
== GSI_INDEX_STRATEGIES.NOT_INDEXABLE.name
)
assert (
large_index.gsis["record_last_updated"]["strategy"]
== GSI_INDEX_STRATEGIES.DATE.name
)
6 changes: 3 additions & 3 deletions tests/query_simplification.py
@@ -3,11 +3,11 @@
from contextlib import ExitStack as DoesNotRaise

import pytest
from lark import Lark

from jamesql import JameSQL
from jamesql.index import GSI_INDEX_STRATEGIES
from jamesql.rewriter import simplify_string_query, grammar
from lark import Lark
from jamesql.rewriter import grammar, simplify_string_query


def pytest_addoption(parser):
@@ -102,7 +102,7 @@ def create_indices(request):
"",
DoesNotRaise(),
), # test double negation of in clause
]
],
)
@pytest.mark.timeout(20)
def test_simplification_then_search(
1 change: 1 addition & 0 deletions tests/range_queries.py
@@ -11,6 +11,7 @@
def pytest_addoption(parser):
parser.addoption("--benchmark", action="store")


@pytest.fixture(scope="session")
def create_indices(request):
with open("tests/fixtures/documents_with_numeric_values.json") as f:
2 changes: 1 addition & 1 deletion tests/save_and_load.py
@@ -17,6 +17,6 @@ def test_load_from_local_index():

assert len(index.global_index) == len(documents)
assert index.global_index
assert len(index.gsis) == 2 # indexing two fields
assert len(index.gsis) == 2 # indexing two fields
assert index.gsis["title"]
assert len(index.uuids_to_position_in_global_index) == len(documents)
3 changes: 1 addition & 2 deletions tests/script_lang.py
@@ -2,8 +2,7 @@
from contextlib import ExitStack as DoesNotRaise

import pytest
from lark import Lark

from lark import Lark
from pytest import raises

from jamesql import JameSQL
44 changes: 12 additions & 32 deletions tests/spelling_correction.py
@@ -12,6 +12,7 @@
def pytest_addoption(parser):
parser.addoption("--benchmark", action="store")


@pytest.fixture(scope="session")
def create_indices(request):
with open("tests/fixtures/documents.json") as f:
@@ -56,39 +57,18 @@ def create_indices(request):
@pytest.mark.parametrize(
"query, corrected_query",
[
(
"tolerat",
"tolerate"
),
(
"tolerateit",
"tolerate it"
), # test segmentation
("tolerat", "tolerate"),
("tolerateit", "tolerate it"), # test segmentation
(
"startedwith",
"started with"
), # query word that appears uppercase in corpus of text
(
"toleratt",
"tolerate"
),
(
"toleratt",
"tolerate"
),
(
"tolerate",
"tolerate"
),
(
"toler",
"toler"
), # not in index
(
"cod",
"cod"
), # not in index
]
"started with",
), # query word that appears uppercase in corpus of text
("toleratt", "tolerate"),
("toleratt", "tolerate"),
("tolerate", "tolerate"),
("toler", "toler"), # not in index
("cod", "cod"), # not in index
],
)
def test_spelling_correction(create_indices, query, corrected_query):
index = create_indices[0]
@@ -97,4 +77,4 @@ def test_spelling_correction(create_indices, query, corrected_query):
assert index.spelling_correction(query) == corrected_query

if large_index:
assert large_index.spelling_correction(query) == corrected_query
assert large_index.spelling_correction(query) == corrected_query