diff --git a/api/routers/query.py b/api/routers/query.py index b21ae06..aee734f 100644 --- a/api/routers/query.py +++ b/api/routers/query.py @@ -53,6 +53,9 @@ def new_query(input_data: InputModel) -> OutputModel: elif input_data.override_cache: should_create_new_query = True + elif query.cache_is_valid is False: + should_create_new_query = True + if should_create_new_query: cached = False query = create_query(lecture, input_data.query_string) diff --git a/db/crud.py b/db/crud.py index e8b3e52..43cbd46 100644 --- a/db/crud.py +++ b/db/crud.py @@ -102,7 +102,15 @@ def delete_all_except_last_message_in_analysis(analysis_id: int): # Query def get_most_recent_query_by_sha(lecture, sha: str): from db.models.query import Query - return Query.filter(Query.lecture_id == lecture.id).filter(Query.query_hash == sha).order_by(Query.modified_at.desc()).first() # noqa: E501 + return Query.filter( + Query.lecture_id == lecture.id + ).filter( + Query.query_hash == sha + ).filter( + Query.cache_is_valid == True + ).order_by( + Query.modified_at.desc() + ).first() def create_query(lecture, query_string: str): diff --git a/tests/feature/api/test_query.py b/tests/feature/api/test_query.py index 771f907..9908105 100644 --- a/tests/feature/api/test_query.py +++ b/tests/feature/api/test_query.py @@ -45,3 +45,33 @@ def request(): assert response.json()['response'] == 'gpt-3 response' assert gpt3.call_count == 2 + + +def test_query_response_cache_can_be_invalidated(mocker, api_client, analysed_lecture): + gpt3 = mocker.patch('tools.text.ai.gpt3', return_value='gpt-3 response') + + def make_query(query_string: str): + return api_client.post('/query', json={ + 'lecture_id': analysed_lecture.public_id, + 'language': analysed_lecture.language, + 'query_string': query_string, + }) + + def make_requests(): + make_query('some query') + make_query('some query') + make_query('some other query') + make_query('some third query') + + make_requests() + assert gpt3.call_count == 3 + + # invalidate the cache + queries = analysed_lecture.queries() + + for query in queries: + query.cache_is_valid = False + query.save() + + make_requests() + assert gpt3.call_count == 6