diff --git a/api/routers/query.py b/api/routers/query.py index 8bd217a..aee734f 100644 --- a/api/routers/query.py +++ b/api/routers/query.py @@ -42,7 +42,21 @@ def new_query(input_data: InputModel) -> OutputModel: query = get_most_recent_query_by_sha(lecture, Query.make_sha(input_data.query_string)) cached = True - if query is None or query.response is None or input_data.override_cache is True: + should_create_new_query = False + + if query is None: + should_create_new_query = True + + elif query.response is None: + should_create_new_query = True + + elif input_data.override_cache: + should_create_new_query = True + + elif query.cache_is_valid is False: + should_create_new_query = True + + if should_create_new_query: cached = False query = create_query(lecture, input_data.query_string) diff --git a/db/cmd.py b/db/cmd.py new file mode 100644 index 0000000..346bb07 --- /dev/null +++ b/db/cmd.py @@ -0,0 +1,17 @@ + + +from db.crud import get_all_lectures + + +def invalidate_all_query_caches(): + print('invalidating all query caches') + + lectures = get_all_lectures() + for lecture in lectures: + print(f'invalidating cache for lecture {lecture}: ', end='') + queries = lecture.queries() + print(f'found {len(queries)} queries', end='') + for query in queries: + query.cache_is_valid = False + query.save() + print(' done.') diff --git a/db/crud.py b/db/crud.py index e8b3e52..2d30f04 100644 --- a/db/crud.py +++ b/db/crud.py @@ -102,7 +102,15 @@ def delete_all_except_last_message_in_analysis(analysis_id: int): # Query def get_most_recent_query_by_sha(lecture, sha: str): from db.models.query import Query - return Query.filter(Query.lecture_id == lecture.id).filter(Query.query_hash == sha).order_by(Query.modified_at.desc()).first() # noqa: E501 + return Query.filter( + Query.lecture_id == lecture.id + ).filter( + Query.query_hash == sha + ).filter( + Query.cache_is_valid == True + ).order_by( + Query.modified_at.desc() + ).first() def create_query(lecture, query_string: str): @@ -112,6 +120,11 @@ def create_query(lecture, query_string: str): return query +def find_all_queries_for_lecture(lecture): + from db.models.query import Query + return Query.select().where(Query.lecture_id == lecture.id) + + # Message def save_message_for_analysis(analysis, title: str, body: Union[str, None] = None): from db.models.message import Message diff --git a/db/migrations/011_add_more_cache_is_valid_column_to_query_table.py b/db/migrations/011_add_more_cache_is_valid_column_to_query_table.py new file mode 100644 index 0000000..b711bb3 --- /dev/null +++ b/db/migrations/011_add_more_cache_is_valid_column_to_query_table.py @@ -0,0 +1,20 @@ +"""Peewee migrations -- 011_add_more_cache_is_valid_column_to_query_table.py.""" +import peewee as pw +from peewee_migrate import Migrator + +from db.models import Query + + +def migrate(migrator: Migrator, database: pw.Database, fake=False, **kwargs): + """Write your migrations here.""" + field = pw.BooleanField(null=False, default=True) + migrator.add_fields( + Query, + cache_is_valid=field, + ) + migrator.run() + + +def rollback(migrator: Migrator, database: pw.Database, fake=False, **kwargs): + """Write your rollback migrations here.""" + migrator.remove_fields(Query, 'cache_is_valid', cascade=True) diff --git a/db/models/lecture.py b/db/models/lecture.py index 2e61145..9ff1fc1 100644 --- a/db/models/lecture.py +++ b/db/models/lecture.py @@ -16,6 +16,7 @@ from db.crud import ( find_all_courses_relations_for_lecture_id, find_all_courses_for_lecture_id, + find_all_queries_for_lecture, ) @@ -150,6 +151,10 @@ def get_last_analysis(self): .order_by(Analysis.modified_at.desc()) .first()) + def queries(self): + queries = find_all_queries_for_lecture(self) + return queries + def courses(self): out = [] diff --git a/db/models/query.py b/db/models/query.py index 8299977..7b412d7 100644 --- a/db/models/query.py +++ b/db/models/query.py @@ -10,6 +10,7 @@ class Query(Base): query_hash = peewee.CharField(index=True, null=False) query_string = peewee.TextField(null=False) count = peewee.IntegerField(null=False, default=0) + cache_is_valid = peewee.BooleanField(null=False, default=True) response = peewee.TextField(null=True) @staticmethod diff --git a/setup.py b/setup.py index c31920a..37358de 100644 --- a/setup.py +++ b/setup.py @@ -35,6 +35,7 @@ 'create_migration = db.migrations:create_migration', 'migrate_up = db.migrations:run_migrations', 'migrate_down = db.migrations:rollback', + 'invalidate_query_cache = db.cmd:invalidate_all_query_caches', 'analysis_queues_restart = jobs:analysis_queues_restart', 'dispatch_fetch_metadata_for_all_lectures = jobs.cmd:fetch_metadata_for_all_lectures', 'dispatch_capture_preview_for_all_lectures = jobs.cmd:capture_preview_for_all_lectures', diff --git a/tests/conftest.py b/tests/conftest.py index 2e5aa58..51e3abc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ from fastapi.testclient import TestClient +from db.models import Lecture, Analysis from fakeredis import FakeStrictRedis -from random import randbytes from rq import Queue import subprocess import tempfile @@ -88,6 +88,36 @@ def api_client(): return client +@pytest.fixture +def analysed_lecture(): + id = 'some_id' + + lecture = Lecture( + public_id=id, + language='sv', + ) + lecture.save() + + analysis = Analysis(lecture_id=lecture.id) + analysis.save() + + save_dummy_summary_for_lecture(lecture) + + return lecture + + +def save_dummy_summary_for_lecture(lecture: Lecture): + summary_filename = lecture.summary_filename() + if os.path.isfile(summary_filename): + os.unlink(summary_filename) + + with open(summary_filename, 'w+') as file: + file.write('some summary') + + lecture.summary_filepath = summary_filename + lecture.save() + + @pytest.fixture def mp4_file(): tf = tempfile.NamedTemporaryFile( diff --git a/tests/feature/api/test_query.py b/tests/feature/api/test_query.py new file mode 100644 index 0000000..9908105 --- /dev/null +++ b/tests/feature/api/test_query.py @@ -0,0 +1,77 @@ + + +def test_query_can_be_made_about_lecture(mocker, api_client, analysed_lecture): + mocker.patch('tools.text.ai.gpt3', return_value='gpt-3 response') + + response = api_client.post('/query', json={ + 'lecture_id': analysed_lecture.public_id, + 'language': analysed_lecture.language, + 'query_string': 'some interesting question', + }) + + assert response.json()['response'] == 'gpt-3 response' + + +def test_query_response_is_cached(mocker, api_client, analysed_lecture): + gpt3 = mocker.patch('tools.text.ai.gpt3', return_value='gpt-3 response') + + def request(): + return api_client.post('/query', json={ + 'lecture_id': analysed_lecture.public_id, + 'language': analysed_lecture.language, + 'query_string': 'some interesting question', + }) + + response = request() + response = request() + + assert response.json()['response'] == 'gpt-3 response' + assert gpt3.call_count == 1 + + +def test_query_response_cache_can_be_overridden(mocker, api_client, analysed_lecture): + gpt3 = mocker.patch('tools.text.ai.gpt3', return_value='gpt-3 response') + + def request(): + return api_client.post('/query', json={ + 'lecture_id': analysed_lecture.public_id, + 'language': analysed_lecture.language, + 'query_string': 'some interesting question', + 'override_cache': True, + }) + + response = request() + response = request() + + assert response.json()['response'] == 'gpt-3 response' + assert gpt3.call_count == 2 + + +def test_query_response_cache_can_be_invalidated(mocker, api_client, analysed_lecture): + gpt3 = mocker.patch('tools.text.ai.gpt3', return_value='gpt-3 response') + + def make_query(query_string: str): + return api_client.post('/query', json={ + 'lecture_id': analysed_lecture.public_id, + 'language': analysed_lecture.language, + 'query_string': query_string, + }) + + def make_requests(): + make_query('some query') + make_query('some query') + make_query('some other query') + make_query('some third query') + + make_requests() + assert gpt3.call_count == 3 + + # invalidate the cache + queries = analysed_lecture.queries() + + for query in queries: + query.cache_is_valid = False + query.save() + + make_requests() + assert gpt3.call_count == 6