Skip to content

Commit

Permalink
Merge pull request #29 from nattvara/invalidate-query-caches
Browse files Browse the repository at this point in the history
Add support for invalidating query caches
  • Loading branch information
nattvara authored Mar 8, 2023
2 parents 3c20177 + 4383d99 commit f1d0e4c
Show file tree
Hide file tree
Showing 9 changed files with 181 additions and 3 deletions.
16 changes: 15 additions & 1 deletion api/routers/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,21 @@ def new_query(input_data: InputModel) -> OutputModel:
query = get_most_recent_query_by_sha(lecture, Query.make_sha(input_data.query_string))

cached = True
if query is None or query.response is None or input_data.override_cache is True:
should_create_new_query = False

if query is None:
should_create_new_query = True

elif query.response is None:
should_create_new_query = True

elif input_data.override_cache:
should_create_new_query = True

elif query.cache_is_valid is False:
should_create_new_query = True

if should_create_new_query:
cached = False
query = create_query(lecture, input_data.query_string)

Expand Down
17 changes: 17 additions & 0 deletions db/cmd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@


from db.crud import get_all_lectures


def invalidate_all_query_caches():
    """Flag every stored query response as stale so it is regenerated on next use."""
    print('invalidating all query caches')

    for lecture in get_all_lectures():
        print(f'invalidating cache for lecture {lecture}: ', end='')
        lecture_queries = lecture.queries()
        print(f'found {len(lecture_queries)} queries', end='')
        for lecture_query in lecture_queries:
            lecture_query.cache_is_valid = False
            lecture_query.save()
        print(' done.')
15 changes: 14 additions & 1 deletion db/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,15 @@ def delete_all_except_last_message_in_analysis(analysis_id: int):
# Query
def get_most_recent_query_by_sha(lecture, sha: str):
    """Return the newest cache-valid Query for *lecture* matching *sha*, or None.

    Rows whose ``cache_is_valid`` flag has been cleared are excluded, so an
    invalidated cache entry is never served to callers.
    """
    from db.models.query import Query
    return Query.filter(
        Query.lecture_id == lecture.id
    ).filter(
        Query.query_hash == sha
    ).filter(
        Query.cache_is_valid == True  # noqa: E712 -- peewee needs the explicit comparison to build SQL
    ).order_by(
        Query.modified_at.desc()
    ).first()


def create_query(lecture, query_string: str):
Expand All @@ -112,6 +120,11 @@ def create_query(lecture, query_string: str):
return query


def find_all_queries_for_lecture(lecture):
    """Fetch every Query row recorded against the given lecture."""
    from db.models.query import Query
    belongs_to_lecture = Query.lecture_id == lecture.id
    return Query.select().where(belongs_to_lecture)


# Message
def save_message_for_analysis(analysis, title: str, body: Union[str, None] = None):
from db.models.message import Message
Expand Down
20 changes: 20 additions & 0 deletions db/migrations/011_add_more_cache_is_valid_column_to_query_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Peewee migrations -- 011_add_more_cache_is_valid_column_to_query_table.py."""
import peewee as pw
from peewee_migrate import Migrator

from db.models import Query


def migrate(migrator: Migrator, database: pw.Database, fake=False, **kwargs):
    """Add the ``cache_is_valid`` flag to Query; existing rows default to valid."""
    migrator.add_fields(
        Query,
        cache_is_valid=pw.BooleanField(null=False, default=True),
    )
    migrator.run()


def rollback(migrator: Migrator, database: pw.Database, fake=False, **kwargs):
    """Drop the ``cache_is_valid`` column added by this migration."""
    migrator.remove_fields(Query, 'cache_is_valid', cascade=True)
5 changes: 5 additions & 0 deletions db/models/lecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from db.crud import (
find_all_courses_relations_for_lecture_id,
find_all_courses_for_lecture_id,
find_all_queries_for_lecture,
)


Expand Down Expand Up @@ -150,6 +151,10 @@ def get_last_analysis(self):
.order_by(Analysis.modified_at.desc())
.first())

def queries(self):
    """Return every Query row stored for this lecture."""
    return find_all_queries_for_lecture(self)

def courses(self):
out = []

Expand Down
1 change: 1 addition & 0 deletions db/models/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class Query(Base):
query_hash = peewee.CharField(index=True, null=False)
query_string = peewee.TextField(null=False)
count = peewee.IntegerField(null=False, default=0)
cache_is_valid = peewee.BooleanField(null=False, default=True)
response = peewee.TextField(null=True)

@staticmethod
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
'create_migration = db.migrations:create_migration',
'migrate_up = db.migrations:run_migrations',
'migrate_down = db.migrations:rollback',
'invalidate_query_cache = db.cmd:invalidate_all_query_caches',
'analysis_queues_restart = jobs:analysis_queues_restart',
'dispatch_fetch_metadata_for_all_lectures = jobs.cmd:fetch_metadata_for_all_lectures',
'dispatch_capture_preview_for_all_lectures = jobs.cmd:capture_preview_for_all_lectures',
Expand Down
32 changes: 31 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from fastapi.testclient import TestClient
from db.models import Lecture, Analysis
from fakeredis import FakeStrictRedis
from random import randbytes
from rq import Queue
import subprocess
import tempfile
Expand Down Expand Up @@ -88,6 +88,36 @@ def api_client():
return client


@pytest.fixture
def analysed_lecture():
    """Persist a Swedish lecture with an analysis and a dummy summary file.

    Returns the saved Lecture so tests can query it through the API.
    """
    # inline the public id rather than binding it to a local named `id`,
    # which shadowed the builtin
    lecture = Lecture(
        public_id='some_id',
        language='sv',
    )
    lecture.save()

    analysis = Analysis(lecture_id=lecture.id)
    analysis.save()

    save_dummy_summary_for_lecture(lecture)

    return lecture


def save_dummy_summary_for_lecture(lecture: Lecture):
    """Write a placeholder summary file for *lecture* and record its path.

    Any pre-existing summary file is removed first so repeated fixture
    setups start from a clean state.
    """
    summary_filename = lecture.summary_filename()
    if os.path.isfile(summary_filename):
        os.unlink(summary_filename)

    # 'w' is sufficient: the file is only written, never read back here
    with open(summary_filename, 'w') as file:
        file.write('some summary')

    lecture.summary_filepath = summary_filename
    lecture.save()


@pytest.fixture
def mp4_file():
tf = tempfile.NamedTemporaryFile(
Expand Down
77 changes: 77 additions & 0 deletions tests/feature/api/test_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@


def test_query_can_be_made_about_lecture(mocker, api_client, analysed_lecture):
    # Stub out the LLM so the endpoint returns a deterministic answer.
    mocker.patch('tools.text.ai.gpt3', return_value='gpt-3 response')

    payload = {
        'lecture_id': analysed_lecture.public_id,
        'language': analysed_lecture.language,
        'query_string': 'some interesting question',
    }
    response = api_client.post('/query', json=payload)

    assert response.json()['response'] == 'gpt-3 response'


def test_query_response_is_cached(mocker, api_client, analysed_lecture):
    gpt3 = mocker.patch('tools.text.ai.gpt3', return_value='gpt-3 response')

    payload = {
        'lecture_id': analysed_lecture.public_id,
        'language': analysed_lecture.language,
        'query_string': 'some interesting question',
    }

    api_client.post('/query', json=payload)
    response = api_client.post('/query', json=payload)

    # The second identical query is served from the cache,
    # so the model is invoked exactly once.
    assert response.json()['response'] == 'gpt-3 response'
    assert gpt3.call_count == 1


def test_query_response_cache_can_be_overridden(mocker, api_client, analysed_lecture):
    gpt3 = mocker.patch('tools.text.ai.gpt3', return_value='gpt-3 response')

    payload = {
        'lecture_id': analysed_lecture.public_id,
        'language': analysed_lecture.language,
        'query_string': 'some interesting question',
        'override_cache': True,
    }

    api_client.post('/query', json=payload)
    response = api_client.post('/query', json=payload)

    # override_cache forces regeneration, so both requests hit the model.
    assert response.json()['response'] == 'gpt-3 response'
    assert gpt3.call_count == 2


def test_query_response_cache_can_be_invalidated(mocker, api_client, analysed_lecture):
    gpt3 = mocker.patch('tools.text.ai.gpt3', return_value='gpt-3 response')

    def make_query(query_string: str):
        return api_client.post('/query', json={
            'lecture_id': analysed_lecture.public_id,
            'language': analysed_lecture.language,
            'query_string': query_string,
        })

    def make_requests():
        for query_string in (
            'some query',
            'some query',
            'some other query',
            'some third query',
        ):
            make_query(query_string)

    make_requests()
    # Three distinct query strings -> three model calls; the duplicate hits cache.
    assert gpt3.call_count == 3

    # invalidate the cache
    for query in analysed_lecture.queries():
        query.cache_is_valid = False
        query.save()

    make_requests()
    # Each distinct string regenerates once more after invalidation.
    assert gpt3.call_count == 6

0 comments on commit f1d0e4c

Please sign in to comment.