Skip to content

Commit 34abeda

Browse files
Doma1612 (Dominik Martens)
authored and committed
Integrate recommendation backend (#504)
* implement tag-recommendation data model * implement crud and dto * implement dummy classification * document tag recommendation * update alembic version * remove faulty alembic version * Address Tims findings * update None value query * revert changes * remove dto transformation --------- Co-authored-by: Dominik Martens <[email protected]>
1 parent dd0ca1d commit 34abeda

24 files changed

+1228
-33
lines changed
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
"""Add document tag recommendation
2+
3+
Revision ID: 523e193d91a0
4+
Revises: 241cfa625db2
5+
Create Date: 2025-02-19 15:02:00.342999
6+
7+
"""
8+
9+
from typing import Sequence, Union
10+
11+
import sqlalchemy as sa
12+
13+
from alembic import op
14+
15+
# revision identifiers, used by Alembic.
# This migration applies on top of 241cfa625db2 on the default branch.
revision: str = "523e193d91a0"
down_revision: Union[str, None] = "241cfa625db2"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
22+
def upgrade() -> None:
    """Create the document tag recommendation tables and their indexes.

    Adds two tables:
      * ``documenttagrecommendationjob`` — one row per recommendation run,
        owned by a user within a project (both FKs cascade on delete).
      * ``documenttagrecommendationlink`` — one row per predicted tag for a
        source document, referencing the job that produced it.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "documenttagrecommendationjob",
        sa.Column("task_id", sa.Integer(), nullable=False),
        sa.Column("model_name", sa.String(), nullable=True),
        # DB-side default so rows get a creation timestamp even when inserted
        # outside the ORM.
        sa.Column(
            "created", sa.DateTime(), server_default=sa.text("now()"), nullable=False
        ),
        sa.Column("user_id", sa.Integer(), nullable=False),
        sa.Column("project_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(["project_id"], ["project.id"], ondelete="CASCADE"),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("task_id"),
    )
    op.create_index(
        op.f("ix_documenttagrecommendationjob_created"),
        "documenttagrecommendationjob",
        ["created"],
        unique=False,
    )
    op.create_index(
        op.f("ix_documenttagrecommendationjob_model_name"),
        "documenttagrecommendationjob",
        ["model_name"],
        unique=False,
    )
    op.create_index(
        op.f("ix_documenttagrecommendationjob_project_id"),
        "documenttagrecommendationjob",
        ["project_id"],
        unique=False,
    )
    # NOTE(review): task_id is already the primary key, so this secondary
    # index is redundant on most backends — likely an artifact of the
    # auto-generation; confirm before removing.
    op.create_index(
        op.f("ix_documenttagrecommendationjob_task_id"),
        "documenttagrecommendationjob",
        ["task_id"],
        unique=False,
    )
    op.create_index(
        op.f("ix_documenttagrecommendationjob_user_id"),
        "documenttagrecommendationjob",
        ["user_id"],
        unique=False,
    )
    op.create_table(
        "documenttagrecommendationlink",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("recommendation_task_id", sa.Integer(), nullable=False),
        sa.Column("source_document_id", sa.Integer(), nullable=False),
        sa.Column("predicted_tag_id", sa.Integer(), nullable=False),
        sa.Column("prediction_score", sa.Float(), nullable=True),
        # Tri-state: NULL = not yet reviewed, True/False = user decision.
        sa.Column("is_accepted", sa.Boolean(), nullable=True),
        sa.ForeignKeyConstraint(
            ["predicted_tag_id"], ["documenttag.id"], ondelete="CASCADE"
        ),
        sa.ForeignKeyConstraint(
            ["recommendation_task_id"],
            ["documenttagrecommendationjob.task_id"],
            ondelete="CASCADE",
        ),
        sa.ForeignKeyConstraint(
            ["source_document_id"], ["sourcedocument.id"], ondelete="CASCADE"
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        op.f("ix_documenttagrecommendationlink_id"),
        "documenttagrecommendationlink",
        ["id"],
        unique=False,
    )
    op.create_index(
        op.f("ix_documenttagrecommendationlink_is_accepted"),
        "documenttagrecommendationlink",
        ["is_accepted"],
        unique=False,
    )
    op.create_index(
        op.f("ix_documenttagrecommendationlink_prediction_score"),
        "documenttagrecommendationlink",
        ["prediction_score"],
        unique=False,
    )
    # ### end Alembic commands ###
107+
108+
109+
def downgrade() -> None:
    """Drop the document tag recommendation tables and indexes.

    The link table is dropped before the job table because its
    ``recommendation_task_id`` column references ``documenttagrecommendationjob``.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index(
        op.f("ix_documenttagrecommendationlink_prediction_score"),
        table_name="documenttagrecommendationlink",
    )
    op.drop_index(
        op.f("ix_documenttagrecommendationlink_is_accepted"),
        table_name="documenttagrecommendationlink",
    )
    op.drop_index(
        op.f("ix_documenttagrecommendationlink_id"),
        table_name="documenttagrecommendationlink",
    )
    op.drop_table("documenttagrecommendationlink")
    op.drop_index(
        op.f("ix_documenttagrecommendationjob_user_id"),
        table_name="documenttagrecommendationjob",
    )
    op.drop_index(
        op.f("ix_documenttagrecommendationjob_task_id"),
        table_name="documenttagrecommendationjob",
    )
    op.drop_index(
        op.f("ix_documenttagrecommendationjob_project_id"),
        table_name="documenttagrecommendationjob",
    )
    op.drop_index(
        op.f("ix_documenttagrecommendationjob_model_name"),
        table_name="documenttagrecommendationjob",
    )
    op.drop_index(
        op.f("ix_documenttagrecommendationjob_created"),
        table_name="documenttagrecommendationjob",
    )
    op.drop_table("documenttagrecommendationjob")
    # ### end Alembic commands ###
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
from typing import List
2+
3+
from fastapi import APIRouter, Depends, HTTPException
4+
from sqlalchemy.orm import Session
5+
6+
from api.dependencies import get_current_user, get_db_session
7+
from app.celery.background_jobs import (
8+
prepare_and_start_document_classification_job_async,
9+
)
10+
from app.core.authorization.authz_user import AuthzUser
11+
from app.core.data.classification.document_classification_service import (
12+
DocumentClassificationService,
13+
)
14+
from app.core.data.crud.document_tag_recommendation import (
15+
crud_document_tag_recommendation,
16+
)
17+
from app.core.data.dto.document_tag_recommendation import (
18+
DocumentTagRecommendationJobCreate,
19+
DocumentTagRecommendationJobCreateIntern,
20+
DocumentTagRecommendationJobRead,
21+
DocumentTagRecommendationSummary,
22+
)
23+
24+
# Module-level service instance shared by all endpoints in this router.
dcs: DocumentClassificationService = DocumentClassificationService()

# Every route below requires an authenticated user via get_current_user.
router = APIRouter(
    prefix="/doctagrecommendationjob",
    dependencies=[Depends(get_current_user)],
    tags=["documentTagRecommendationJob"],
)
31+
32+
33+
@router.put(
    "",
    response_model=DocumentTagRecommendationJobRead,
    summary="Creates a new Document Tag Recommendation Task and returns it.",
)
def create_new_doc_tag_rec_task(
    *,
    db: Session = Depends(get_db_session),
    doc_tag_rec: DocumentTagRecommendationJobCreate,
    authz_user: AuthzUser = Depends(),
) -> DocumentTagRecommendationJobRead:
    # Only members of the target project may start a recommendation job.
    authz_user.assert_in_project(doc_tag_rec.project_id)

    # Persist the job row first so the background task has a task_id to work on.
    create_dto = DocumentTagRecommendationJobCreateIntern(
        project_id=doc_tag_rec.project_id, user_id=authz_user.user.id
    )
    job_orm = crud_document_tag_recommendation.create(db=db, create_dto=create_dto)
    job_read = DocumentTagRecommendationJobRead.model_validate(job_orm)

    # Hand the classification work off to the background job machinery.
    prepare_and_start_document_classification_job_async(
        job_orm.task_id,
        doc_tag_rec.project_id,
    )
    return job_read
59+
60+
61+
@router.get(
    "/{task_id}",
    response_model=List[DocumentTagRecommendationSummary],
    summary="Retrieve all document tag recommendations for the given task ID.",
)
def get_recommendations_from_task_endpoint(
    task_id: int,
) -> List[DocumentTagRecommendationSummary]:
    """
    Retrieves document tag recommendations based on the specified task ID.

    ### Response Format:
    The endpoint returns a list of recommendations, where each recommendation
    is represented as a DocumentTagRecommendationSummary DTO with the following structure:

    ```python
    {
        "recommendation_id": int, # Unique identifier for the recommendation
        "source_document": str, # Name of the source document
        "predicted_tag_id": int, # ID of the predicted tag
        "predicted_tag": str, # Name of the predicted tag
        "prediction_score": float # Confidence score of the prediction
    }
    ```

    ### Error Handling:
    - Returns HTTP 404 if no recommendations are found for the given task ID.
    """
    # NOTE(review): unlike the PUT endpoint, there is no AuthzUser project
    # membership check here — any authenticated user can read recommendations
    # for any task_id. Confirm whether this is intended.
    recommendations = dcs.get_recommendations_from_task(task_id)
    # An empty result is treated as "task unknown or not yet finished".
    if not recommendations:
        raise HTTPException(status_code=404, detail="No recommendations found.")
    return recommendations
93+
94+
95+
@router.patch(
    "/update_recommendations",
    response_model=int,
    summary="The endpoint receives IDs of wrongly and correctly tagged document recommendations and sets `is_accepted` to `true` or `false`, while setting the corresponding document tags if `true`.",
)
def update_recommendations(
    *,
    accepted_recommendation_ids: List[int],
    declined_recommendation_ids: List[int],
) -> int:
    # Delegate the accept/decline bookkeeping to the classification service.
    modified_count = dcs.validate_recommendations(
        accepted_recommendation_ids=accepted_recommendation_ids,
        declined_recommendation_ids=declined_recommendation_ids,
    )
    # The service signals failure with a -1 sentinel rather than raising.
    if modified_count != -1:
        return modified_count
    raise HTTPException(
        status_code=400, detail="An error occurred while updating recommendations."
    )

backend/src/app/celery/background_jobs/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,14 @@
44
from celery import Task, group
55
from celery.result import GroupResult
66

7+
from app.core.data.classification.document_classification_service import (
8+
DocumentClassificationService as DocumentClassificationService,
9+
)
710
from app.core.data.crawler.crawler_service import CrawlerService
811
from app.core.data.dto.crawler_job import CrawlerJobParameters, CrawlerJobRead
12+
from app.core.data.dto.document_tag_recommendation import (
13+
DocumentTagRecommendationJobRead as DocumentTagRecommendationJobRead,
14+
)
915
from app.core.data.dto.export_job import ExportJobParameters, ExportJobRead
1016
from app.core.data.dto.import_job import ImportJobParameters, ImportJobRead
1117
from app.core.data.dto.llm_job import LLMJobParameters2, LLMJobRead
@@ -172,3 +178,14 @@ def execute_video_preprocessing_pipeline_apply_async(
172178

173179
for cargo in cargos:
174180
execute_video_preprocessing_pipeline_task.apply_async(kwargs={"cargo": cargo})
181+
182+
183+
def prepare_and_start_document_classification_job_async(
    task_id: int, project_id: int
) -> None:
    """Dispatch the document classification Celery task for a recommendation job.

    Args:
        task_id: ID of the document tag recommendation job to process.
        project_id: ID of the project whose documents are to be classified.
    """
    # Imported lazily to avoid a circular import with the tasks module.
    from app.celery.background_jobs.tasks import (
        start_document_classification_job,
    )

    assert isinstance(start_document_classification_job, Task), "Not a Celery Task"
    # BUG FIX: the task was previously called directly, which ran the whole
    # classification synchronously in the caller's (API) process. Enqueue it
    # on the worker instead, matching the other *_async helpers in this module.
    start_document_classification_job.apply_async(
        kwargs={"task_id": task_id, "project_id": project_id}
    )
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from loguru import logger
2+
3+
from app.core.data.classification.document_classification_service import (
4+
DocumentClassificationService,
5+
)
6+
7+
# Module-level service instance reused across task invocations in this worker.
dcs: DocumentClassificationService = DocumentClassificationService()
8+
9+
10+
def start_document_classification_job_(task_id: int, project_id: int) -> None:
    """Run document classification for one tag recommendation job.

    Args:
        task_id: ID of the document tag recommendation job to process.
        project_id: ID of the project whose untagged documents are classified.
    """
    # BUG FIX: the message was previously wrapped in a 1-tuple
    # (logger.info((f"...",))), so loguru rendered "('Starting ...',)"
    # instead of the plain string.
    logger.info(f"Starting classification job with task id {task_id}")
    dcs.classify_untagged_documents(task_id=task_id, project_id=project_id)

    logger.info(f"Classification job {task_id} has finished.")

backend/src/app/celery/background_jobs/tasks.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
from app.celery.background_jobs.cota import start_cota_refinement_job_
55
from app.celery.background_jobs.crawl import start_crawler_job_
6+
from app.celery.background_jobs.document_classification import (
7+
start_document_classification_job_,
8+
)
69
from app.celery.background_jobs.export import start_export_job_
710
from app.celery.background_jobs.import_ import start_import_job_
811
from app.celery.background_jobs.llm import start_llm_job_
@@ -118,3 +121,12 @@ def import_uploaded_archive(archive_file_path_and_project_id: Tuple[Path, int])
118121
# we need a tuple to chain the task since chaining only allows for one return object
119122
archive_file_path, project_id = archive_file_path_and_project_id
120123
import_uploaded_archive_(archive_file_path=archive_file_path, project_id=project_id)
124+
125+
126+
@celery_worker.task(
    acks_late=True,
    autoretry_for=(Exception,),
    retry_kwargs={"max_retries": 5, "countdown": 5},
)
def start_document_classification_job(task_id: int, project_id: int) -> None:
    """Celery task wrapper: classify documents for a tag recommendation job.

    Retries up to 5 times with a 5-second countdown on any exception;
    acks_late re-delivers the message if the worker dies mid-run.
    """
    start_document_classification_job_(task_id=task_id, project_id=project_id)

0 commit comments

Comments
 (0)