diff --git a/integration-tests/corpus/test_basic_corpus.py b/integration-tests/corpus/test_basic_corpus.py index 318eb63..91cd3a0 100644 --- a/integration-tests/corpus/test_basic_corpus.py +++ b/integration-tests/corpus/test_basic_corpus.py @@ -25,8 +25,8 @@ def test_save_then_search_one_corpus(ctx): output = test_corpus.search("It is sunny") # print("OUTPUT IS : ") # print(output) - assert "sunshine" in output[1][0] - assert "weather" in output[0][0] + assert "sunshine" in output[1][1] + assert "weather" in output[0][1] def test_delete_all_content(ctx): @@ -45,7 +45,7 @@ def test_delete_all_content(ctx): time.sleep(1) output = test_corpus.search("It is sunny") - assert "sunshine" in output[1][0] + assert "sunshine" in output[1][1] test_corpus.delete_all_content() time.sleep(1) diff --git a/integration-tests/corpus/test_multicorpus.py b/integration-tests/corpus/test_multicorpus.py new file mode 100644 index 0000000..8b8860b --- /dev/null +++ b/integration-tests/corpus/test_multicorpus.py @@ -0,0 +1,42 @@ +import numpy as np +import uuid +import time +from memas.corpus import basic_corpus +from memas.corpus.corpus_searching import multi_corpus_search +from memas.interface.corpus import Citation, CorpusInfo, CorpusType + +corpus_name = "test corpus1" + + +def test_multicorpus_search(ctx, test_client): + namespace_id = uuid.uuid4() + corpus_id1 = uuid.uuid4() + corpus_id2 = uuid.uuid4() + corpus_id3 = uuid.uuid4() + corpus_info1 = CorpusInfo("test_corpus1", namespace_id, corpus_id1, CorpusType.CONVERSATION) + corpus_info2 = CorpusInfo("test_corpus2", namespace_id, corpus_id2, CorpusType.KNOWLEDGE) + corpus_info3 = CorpusInfo("test_corpus3", namespace_id, corpus_id3, CorpusType.CONVERSATION) + test_corpus1 = basic_corpus.BasicCorpus(corpus_info1, ctx.corpus_metadata, ctx.corpus_doc, ctx.corpus_vec) + test_corpus2 = basic_corpus.BasicCorpus(corpus_info2, ctx.corpus_metadata, ctx.corpus_doc, ctx.corpus_vec) + test_corpus3 = basic_corpus.BasicCorpus(corpus_info3, ctx.corpus_metadata, ctx.corpus_doc, ctx.corpus_vec) + + text1 = "The sun is high. California sunshine is great. " + text2 = "I picked up my phone and then dropped it again. I cant seem to get a good grip on things these days. It persists into my everyday tasks" + text3 = "The weather is great today, but I worry that tomorrow it won't be. My umbrella is in the repair shop." + + assert test_corpus1.store_and_index(text1, Citation("www.docsource1", "SSSdoc1", "", "doc1")) + assert test_corpus2.store_and_index(text2, Citation("were.docsource2", "SSSdoc2", "", "doc2")) + assert test_corpus3.store_and_index(text3, Citation("docsource3.ai", "SSSdoc3", "", "doc3")) + + time.sleep(1) + + corpus_dict = {} + corpus_dict[CorpusType.CONVERSATION] = [test_corpus1, test_corpus3] + corpus_dict[CorpusType.KNOWLEDGE] = [test_corpus2] + + output = multi_corpus_search(corpus_dict, "It is sunny", ctx, 5) + # Check that text was retrieved from all 3 corpuses upon searching + assert len(output) == 3 + + assert "sunshine" in output[1][1] + assert "weather" in output[0][1] diff --git a/memas/corpus/corpus_searching.py b/memas/corpus/corpus_searching.py index fd2727e..044ee2c 100644 --- a/memas/corpus/corpus_searching.py +++ b/memas/corpus/corpus_searching.py @@ -1,17 +1,55 @@ # from search_redirect import SearchSettings from uuid import UUID from functools import reduce -from memas.interface.corpus import Corpus, CorpusFactory +from memas.interface.corpus import Corpus, CorpusFactory, CorpusType from memas.interface.corpus import Citation +from collections import defaultdict from memas.interface.storage_driver import DocumentEntity from memas.interface.exceptions import SentenceLengthOverflowException -def corpora_search(corpus_ids: list[UUID], clue: str) -> list[tuple[float, str, Citation]]: - vector_search_count: int = 10 +def multi_corpus_search(corpus_sets: dict[CorpusType, list[Corpus]], clue: str, ctx, result_limit: int) -> list[tuple[float, str, Citation]]: + results = defaultdict(list) + + # Direct each multicorpus search to the right algorithm + for corpus_type, corpora_list in corpus_sets.items(): + # Default basic corpus handling + if corpus_type == CorpusType.KNOWLEDGE or corpus_type == CorpusType.CONVERSATION: + corpus_type_results = basic_corpora_search(corpora_list, clue, ctx) + results["BASIC_SCORING"].extend(corpus_type_results) + + sorted_results_matrix = [] + # Sort results with compareable scoring schemes + for scored_results in results.values(): + # Sort by descending scoring so best results come first + sorted_scored_results = sorted(scored_results, key=lambda x: x[0], reverse=True) + sorted_results_matrix.append(sorted_scored_results) + + # To combine results for corpora that don't have compareable scoring take equal sized subsets of each Corpus type + # TODO : Consider changing this at some point in the future to have better searching of corpus sets with non-comparable scoring + combined_results = [] + for j in range(max([len(x) for x in sorted_results_matrix])): + for i in range(len(sorted_results_matrix)): + if j >= len(sorted_results_matrix[i]) or len(combined_results) >= result_limit: + break + combined_results.append(sorted_results_matrix[i][j]) + if len(combined_results) >= result_limit: + break + + return combined_results + + +""" +All corpora here should be of the same CorpusType implementation (basic_corpus) +""" + + +def basic_corpora_search(corpora: list[Corpus], clue: str, ctx) -> list[tuple[float, str, Citation]]: + # Extract information needed for a search + corpus_ids = [x.corpus_id for x in corpora] doc_store_results: list[tuple[float, str, Citation]] = [] - temp_res = ctx.corpus_doc.multi_corpus_search(corpus_ids, clue) + temp_res = ctx.corpus_doc.search_corpora(corpus_ids, clue) # Search the document store for score, doc_entity in temp_res: document_text = doc_entity.document @@ -21,7 +59,7 @@ def corpora_search(corpus_ids: list[UUID], clue: str) -> list[tuple[float, str, # Search for the vectors vec_store_results: list[tuple[float, str, Citation]] = [] - temp_res2 = ctx.corpus_vec.multi_corpus_search(corpus_ids, clue) + temp_res2 = ctx.corpus_vec.search_corpora(corpus_ids, clue) for score, doc_entity, start_index, end_index in temp_res2: # Verify that the text recovered from the vectors fits the maximum sentence criteria @@ -33,10 +71,6 @@ def corpora_search(corpus_ids: list[UUID], clue: str) -> list[tuple[float, str, vec_store_results.append([score, doc_entity.document, citation]) - # print("Docs then Vecs : ") - # print(doc_store_results) - # print(vec_store_results) - # If any of the searches returned no results combine and return if len(vec_store_results) == 0: doc_store_results.sort(key=lambda x: x[0], reverse=True) @@ -52,10 +86,6 @@ def corpora_search(corpus_ids: list[UUID], clue: str) -> list[tuple[float, str, def normalize_and_combine(doc_results: list, vec_results: list): - # print("Docs then Vecs : ") - # print(doc_results) - # print(vec_results) - # normalization with assumption that top score matches are approximately equal # Vec scores are based on distance, so smaller is better. Need to inverse the @@ -117,7 +147,4 @@ def normalize_and_combine(doc_results: list, vec_results: list): doc_results_normalized.extend(unique_vectors) - # Sort by descending scoring so best results come first - doc_results_normalized.sort(key=lambda x: x[0], reverse=True) - - return [(y, z) for [x, y, z] in doc_results_normalized] + return doc_results_normalized diff --git a/memas/dataplane.py b/memas/dataplane.py index 893765f..b1699a8 100644 --- a/memas/dataplane.py +++ b/memas/dataplane.py @@ -1,7 +1,10 @@ from dataclasses import asdict from flask import Blueprint, current_app, request from memas.context_manager import ctx +from memas.corpus.corpus_searching import multi_corpus_search from memas.interface.corpus import Citation, Corpus, CorpusType +from collections import defaultdict +from memas.interface.namespace import CORPUS_SEPARATOR dataplane = Blueprint("dp", __name__, url_prefix="/dp") @@ -17,18 +20,27 @@ def recall(): corpus_infos = ctx.memas_metadata.get_query_corpora(namespace_pathname) current_app.logger.debug(f"Querying corpuses: {corpus_infos}") - search_results: list[tuple[str, Citation]] = [] + # search_results: list[tuple[str, Citation]] = [] + # for corpus_info in corpus_infos: + # corpus: Corpus = ctx.corpus_provider.get_corpus_by_info(corpus_info) + # search_results.extend(corpus.search(clue=clue)) + + # Group the corpora to search into sets based on their CorpusType + corpora_grouped_by_type = defaultdict(list) for corpus_info in corpus_infos: + corpus_type = corpus_info.corpus_type corpus: Corpus = ctx.corpus_provider.get_corpus_by_info(corpus_info) - search_results.extend(corpus.search(clue=clue)) + corpora_grouped_by_type[corpus_type].append(corpus) - # Combine the results and only take the top ones - search_results.sort(key=lambda x: x[0], reverse=True) + # Execute a multicorpus search + # TODO : Should look into refactor to remove ctx later and have a cleaner solution + search_results = multi_corpus_search(corpora_grouped_by_type, clue, ctx, 4) + current_app.logger.debug(f"Search Results are: {search_results}") # TODO : It will improve Query speed significantly to fetch citations after determining which documents to send to user # Take only top few scores and remove scoring element before sending - return [{"document": doc, "citation": asdict(citation)} for doc, citation in search_results[0:5]] + return [{"document": doc, "citation": asdict(citation)} for score, doc, citation in search_results[0:5]] @dataplane.route('/memorize', methods=["POST"])