Skip to content

Commit accc0d5

Browse files
committed
document tag recommendation
1 parent 203ad44 commit accc0d5

16 files changed

+859
-85
lines changed
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
"""vscode launcher
2+
3+
Revision ID: d7e768a5b23d
4+
Revises: f3108bb5e496
5+
Create Date: 2025-01-22 08:51:22.353360
6+
7+
"""
8+
9+
from typing import Sequence, Union
10+
11+
import sqlalchemy as sa
12+
13+
from alembic import op
14+
15+
# revision identifiers, used by Alembic.
16+
revision: str = "d7e768a5b23d"
17+
down_revision: Union[str, None] = "f3108bb5e496"
18+
branch_labels: Union[str, Sequence[str], None] = None
19+
depends_on: Union[str, Sequence[str], None] = None
20+
21+
22+
def upgrade() -> None:
23+
# ### commands auto generated by Alembic - please adjust! ###
24+
op.create_table(
25+
"documenttagrecommendation",
26+
sa.Column("task_id", sa.Integer(), nullable=False),
27+
sa.Column("model_name", sa.String(), nullable=True),
28+
sa.Column(
29+
"created", sa.DateTime(), server_default=sa.text("now()"), nullable=False
30+
),
31+
sa.Column("user_id", sa.Integer(), nullable=False),
32+
sa.Column("project_id", sa.Integer(), nullable=False),
33+
sa.ForeignKeyConstraint(["project_id"], ["project.id"], ondelete="CASCADE"),
34+
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
35+
sa.PrimaryKeyConstraint("task_id"),
36+
)
37+
op.create_index(
38+
op.f("ix_documenttagrecommendation_created"),
39+
"documenttagrecommendation",
40+
["created"],
41+
unique=False,
42+
)
43+
op.create_index(
44+
op.f("ix_documenttagrecommendation_model_name"),
45+
"documenttagrecommendation",
46+
["model_name"],
47+
unique=False,
48+
)
49+
op.create_index(
50+
op.f("ix_documenttagrecommendation_project_id"),
51+
"documenttagrecommendation",
52+
["project_id"],
53+
unique=False,
54+
)
55+
op.create_index(
56+
op.f("ix_documenttagrecommendation_task_id"),
57+
"documenttagrecommendation",
58+
["task_id"],
59+
unique=False,
60+
)
61+
op.create_index(
62+
op.f("ix_documenttagrecommendation_user_id"),
63+
"documenttagrecommendation",
64+
["user_id"],
65+
unique=False,
66+
)
67+
op.create_table(
68+
"documenttagrecommendationlink",
69+
sa.Column("id", sa.Integer(), nullable=False),
70+
sa.Column("recommendation_task_id", sa.Integer(), nullable=False),
71+
sa.Column("source_document_id", sa.Integer(), nullable=False),
72+
sa.Column("predicted_tag_id", sa.Integer(), nullable=False),
73+
sa.Column("prediction_score", sa.Float(), nullable=True),
74+
sa.Column("is_accepted", sa.Boolean(), nullable=True),
75+
sa.ForeignKeyConstraint(
76+
["predicted_tag_id"], ["documenttag.id"], ondelete="CASCADE"
77+
),
78+
sa.ForeignKeyConstraint(
79+
["recommendation_task_id"],
80+
["documenttagrecommendation.task_id"],
81+
ondelete="CASCADE",
82+
),
83+
sa.ForeignKeyConstraint(
84+
["source_document_id"], ["sourcedocument.id"], ondelete="CASCADE"
85+
),
86+
sa.PrimaryKeyConstraint("id"),
87+
)
88+
op.create_index(
89+
op.f("ix_documenttagrecommendationlink_id"),
90+
"documenttagrecommendationlink",
91+
["id"],
92+
unique=False,
93+
)
94+
op.create_index(
95+
op.f("ix_documenttagrecommendationlink_is_accepted"),
96+
"documenttagrecommendationlink",
97+
["is_accepted"],
98+
unique=False,
99+
)
100+
op.create_index(
101+
op.f("ix_documenttagrecommendationlink_prediction_score"),
102+
"documenttagrecommendationlink",
103+
["prediction_score"],
104+
unique=False,
105+
)
106+
# ### end Alembic commands ###
107+
108+
109+
def downgrade() -> None:
110+
# ### commands auto generated by Alembic - please adjust! ###
111+
op.drop_index(
112+
op.f("ix_documenttagrecommendationlink_prediction_score"),
113+
table_name="documenttagrecommendationlink",
114+
)
115+
op.drop_index(
116+
op.f("ix_documenttagrecommendationlink_is_accepted"),
117+
table_name="documenttagrecommendationlink",
118+
)
119+
op.drop_index(
120+
op.f("ix_documenttagrecommendationlink_id"),
121+
table_name="documenttagrecommendationlink",
122+
)
123+
op.drop_table("documenttagrecommendationlink")
124+
op.drop_index(
125+
op.f("ix_documenttagrecommendation_user_id"),
126+
table_name="documenttagrecommendation",
127+
)
128+
op.drop_index(
129+
op.f("ix_documenttagrecommendation_task_id"),
130+
table_name="documenttagrecommendation",
131+
)
132+
op.drop_index(
133+
op.f("ix_documenttagrecommendation_project_id"),
134+
table_name="documenttagrecommendation",
135+
)
136+
op.drop_index(
137+
op.f("ix_documenttagrecommendation_model_name"),
138+
table_name="documenttagrecommendation",
139+
)
140+
op.drop_index(
141+
op.f("ix_documenttagrecommendation_created"),
142+
table_name="documenttagrecommendation",
143+
)
144+
op.drop_table("documenttagrecommendation")
145+
# ### end Alembic commands ###

backend/src/api/endpoints/document_tag_recommendation.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
1-
from fastapi import APIRouter, BackgroundTasks, Depends
1+
from typing import List
2+
3+
from fastapi import APIRouter, Depends, HTTPException
24
from sqlalchemy.orm import Session
35

46
from api.dependencies import get_current_user, get_db_session
57
from app.celery.background_jobs import (
68
prepare_and_start_document_classification_job_async,
79
)
810
from app.core.authorization.authz_user import AuthzUser
11+
from app.core.data.classification.document_classification_service import (
12+
DocumentClassificationService,
13+
)
914
from app.core.data.crud.document_tag_recommendation import (
1015
crud_document_tag_recommendation,
1116
)
@@ -15,6 +20,8 @@
1520
DocumentTagRecommendationRead,
1621
)
1722

23+
dcs: DocumentClassificationService = DocumentClassificationService()
24+
1825
router = APIRouter(
1926
prefix="/doctagrecommendation",
2027
dependencies=[Depends(get_current_user)],
@@ -32,7 +39,6 @@ def create_new_doc_tag_rec_task(
3239
db: Session = Depends(get_db_session),
3340
doc_tag_rec: DocumentTagRecommendationCreate,
3441
authz_user: AuthzUser = Depends(),
35-
background_tasks: BackgroundTasks,
3642
) -> DocumentTagRecommendationRead:
3743
authz_user.assert_in_project(doc_tag_rec.project_id)
3844

@@ -51,4 +57,53 @@ def create_new_doc_tag_rec_task(
5157
return response
5258

5359

54-
# To-Do: Update of tag recommendation
60+
@router.get(
61+
"/{task_id}",
62+
response_model=List[dict],
63+
summary="Retrieve all document tag recommendations for the given task ID.",
64+
)
65+
def get_recommendations_from_task_endpoint(task_id: int) -> List[dict]:
66+
"""
67+
Retrieves document tag recommendations based on the specified task ID.
68+
69+
### Response Format:
70+
The endpoint returns a list of recommendations, where each recommendation
71+
is represented as a dictionary with the following structure:
72+
73+
```python
74+
{
75+
"recommendation_id": int, # Unique identifier for the recommendation
76+
"source_document": str, # Name of the source document
77+
"predicted_tag_id": int, # ID of the predicted tag
78+
"predicted_tag": str, # Name of the predicted tag
79+
"prediction_score": float # Confidence score of the prediction
80+
}
81+
```
82+
83+
### Error Handling:
84+
- Returns HTTP 404 if no recommendations are found for the given task ID.
85+
"""
86+
recommendations = dcs.get_recommendations_from_task(task_id)
87+
if not recommendations:
88+
raise HTTPException(status_code=404, detail="No recommendations found.")
89+
return recommendations
90+
91+
92+
@router.patch(
93+
"/accept_recommendations",
94+
response_model=int,
95+
summary="The endpoint receives IDs of correctly tagged document recommendations and sets `is_accepted` to `true`, while setting the corresponding document tags.",
96+
)
97+
def update_document_tag_recommendations(
98+
*,
99+
accepted_recommendation_ids: List[int],
100+
) -> int:
101+
modifications = dcs.validate_recommendations(
102+
recommendation_ids=accepted_recommendation_ids
103+
)
104+
if modifications == -1:
105+
raise HTTPException(
106+
status_code=400, detail="An error occurred while updating recommendations."
107+
)
108+
109+
return modifications

backend/src/app/celery/background_jobs/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from celery import Task
55

66
from app.core.data.classification.document_classification_service import (
7-
ClassificationService as ClassificationService,
7+
DocumentClassificationService as DocumentClassificationService,
88
)
99
from app.core.data.crawler.crawler_service import CrawlerService
1010
from app.core.data.dto.crawler_job import CrawlerJobParameters, CrawlerJobRead
@@ -169,4 +169,5 @@ def prepare_and_start_document_classification_job_async(
169169
start_document_classification_job,
170170
)
171171

172+
assert isinstance(start_document_classification_job, Task), "Not a Celery Task"
172173
start_document_classification_job(task_id=task_id, project_id=project_id)
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from loguru import logger
22

33
from app.core.data.classification.document_classification_service import (
4-
ClassificationService,
4+
DocumentClassificationService,
55
)
66

7-
cs: ClassificationService = ClassificationService()
7+
dcs: DocumentClassificationService = DocumentClassificationService()
88

99

1010
def start_document_classification_job_(task_id: int, project_id):
1111
logger.info((f"Starting classification job with task id {task_id}",))
12-
cs.perform_dummy_classification(task_id=task_id, project_id=project_id)
12+
dcs.classify_untagged_documents(task_id=task_id, project_id=project_id)
1313

1414
logger.info(f"Classification job {task_id} has finished.")

0 commit comments

Comments
 (0)