@@ -0,0 +1,145 @@
"""Add document tag recommendation

Revision ID: 434becce1fad
Revises: 241cfa625db2
Create Date: 2025-02-18 14:33:08.794269

"""

from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "434becce1fad"
down_revision: Union[str, None] = "241cfa625db2"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"documenttagrecommendation",
sa.Column("task_id", sa.Integer(), nullable=False),
sa.Column("model_name", sa.String(), nullable=True),
sa.Column(
"created", sa.DateTime(), server_default=sa.text("now()"), nullable=False
),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("project_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(["project_id"], ["project.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("task_id"),
)
op.create_index(
op.f("ix_documenttagrecommendation_created"),
"documenttagrecommendation",
["created"],
unique=False,
)
op.create_index(
op.f("ix_documenttagrecommendation_model_name"),
"documenttagrecommendation",
["model_name"],
unique=False,
)
op.create_index(
op.f("ix_documenttagrecommendation_project_id"),
"documenttagrecommendation",
["project_id"],
unique=False,
)
op.create_index(
op.f("ix_documenttagrecommendation_task_id"),
"documenttagrecommendation",
["task_id"],
unique=False,
)
op.create_index(
op.f("ix_documenttagrecommendation_user_id"),
"documenttagrecommendation",
["user_id"],
unique=False,
)
op.create_table(
"documenttagrecommendationlink",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("recommendation_task_id", sa.Integer(), nullable=False),
sa.Column("source_document_id", sa.Integer(), nullable=False),
sa.Column("predicted_tag_id", sa.Integer(), nullable=False),
sa.Column("prediction_score", sa.Float(), nullable=True),
sa.Column("is_accepted", sa.Boolean(), nullable=True),
sa.ForeignKeyConstraint(
["predicted_tag_id"], ["documenttag.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(
["recommendation_task_id"],
["documenttagrecommendation.task_id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["source_document_id"], ["sourcedocument.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_documenttagrecommendationlink_id"),
"documenttagrecommendationlink",
["id"],
unique=False,
)
op.create_index(
op.f("ix_documenttagrecommendationlink_is_accepted"),
"documenttagrecommendationlink",
["is_accepted"],
unique=False,
)
op.create_index(
op.f("ix_documenttagrecommendationlink_prediction_score"),
"documenttagrecommendationlink",
["prediction_score"],
unique=False,
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(
op.f("ix_documenttagrecommendationlink_prediction_score"),
table_name="documenttagrecommendationlink",
)
op.drop_index(
op.f("ix_documenttagrecommendationlink_is_accepted"),
table_name="documenttagrecommendationlink",
)
op.drop_index(
op.f("ix_documenttagrecommendationlink_id"),
table_name="documenttagrecommendationlink",
)
op.drop_table("documenttagrecommendationlink")
op.drop_index(
op.f("ix_documenttagrecommendation_user_id"),
table_name="documenttagrecommendation",
)
op.drop_index(
op.f("ix_documenttagrecommendation_task_id"),
table_name="documenttagrecommendation",
)
op.drop_index(
op.f("ix_documenttagrecommendation_project_id"),
table_name="documenttagrecommendation",
)
op.drop_index(
op.f("ix_documenttagrecommendation_model_name"),
table_name="documenttagrecommendation",
)
op.drop_index(
op.f("ix_documenttagrecommendation_created"),
table_name="documenttagrecommendation",
)
op.drop_table("documenttagrecommendation")
# ### end Alembic commands ###
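A quick way to check that this revision round-trips is to apply it and immediately roll it back. A minimal sketch, assuming a standard Alembic setup with an `alembic.ini` at the project root (the config path is an assumption, not part of this PR):

```python
# Hypothetical smoke test for this revision; the alembic.ini location is assumed.
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")
command.upgrade(cfg, "434becce1fad")    # create both tables and their indexes
command.downgrade(cfg, "241cfa625db2")  # drop them again in reverse order
```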
backend/src/api/endpoints/document_tag_recommendation.py (109 additions, 0 deletions)
@@ -0,0 +1,109 @@
from typing import List

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session

from api.dependencies import get_current_user, get_db_session
from app.celery.background_jobs import (
prepare_and_start_document_classification_job_async,
)
from app.core.authorization.authz_user import AuthzUser
from app.core.data.classification.document_classification_service import (
DocumentClassificationService,
)
from app.core.data.crud.document_tag_recommendation import (
crud_document_tag_recommendation,
)
from app.core.data.dto.document_tag_recommendation import (
DocumentTagRecommendationCreate,
DocumentTagRecommendationCreateIntern,
DocumentTagRecommendationRead,
)

dcs: DocumentClassificationService = DocumentClassificationService()

router = APIRouter(
prefix="/doctagrecommendation",
dependencies=[Depends(get_current_user)],
tags=["documentTagRecommendation"],
)


@router.put(
"",
response_model=DocumentTagRecommendationRead,
summary="Creates a new Document Tag Recommendation Task and returns it.",
)
def create_new_doc_tag_rec_task(
*,
db: Session = Depends(get_db_session),
doc_tag_rec: DocumentTagRecommendationCreate,
authz_user: AuthzUser = Depends(),
) -> DocumentTagRecommendationRead:
authz_user.assert_in_project(doc_tag_rec.project_id)

db_obj = crud_document_tag_recommendation.create(
db=db,
create_dto=DocumentTagRecommendationCreateIntern(
project_id=doc_tag_rec.project_id, user_id=authz_user.user.id
),
)
response = DocumentTagRecommendationRead.model_validate(db_obj)
prepare_and_start_document_classification_job_async(
db_obj.task_id,
doc_tag_rec.project_id,
)

return response


@router.get(
"/{task_id}",
response_model=List[dict],
summary="Retrieve all document tag recommendations for the given task ID.",
)
def get_recommendations_from_task_endpoint(task_id: int) -> List[dict]:
"""
Retrieves document tag recommendations based on the specified task ID.

### Response Format:
The endpoint returns a list of recommendations, where each recommendation
is represented as a dictionary with the following structure:

```python
{
"recommendation_id": int, # Unique identifier for the recommendation
"source_document": str, # Name of the source document
"predicted_tag_id": int, # ID of the predicted tag
"predicted_tag": str, # Name of the predicted tag
"prediction_score": float # Confidence score of the prediction
}
```

### Error Handling:
- Returns HTTP 404 if no recommendations are found for the given task ID.
"""
recommendations = dcs.get_recommendations_from_task(task_id)
if not recommendations:
raise HTTPException(status_code=404, detail="No recommendations found.")
return recommendations


@router.patch(
"/accept_recommendations",
response_model=int,
summary="The endpoint receives IDs of correctly tagged document recommendations and sets `is_accepted` to `true`, while setting the corresponding document tags.",
)
def update_document_tag_recommendations(
*,
accepted_recommendation_ids: List[int],
) -> int:
modifications = dcs.validate_recommendations(
recommendation_ids=accepted_recommendation_ids
)
if modifications == -1:
raise HTTPException(
status_code=400, detail="An error occurred while updating recommendations."
)

return modifications
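Taken together, the three endpoints form a create → poll → accept flow. A hedged client sketch of that flow — the base URL, port, and token are assumptions, and `httpx` merely stands in for whatever client the frontend actually uses:

```python
import httpx

# Assumed deployment details; adjust to the actual host and auth scheme.
BASE = "http://localhost:8000/doctagrecommendation"
headers = {"Authorization": "Bearer <token>"}

# 1. Start a recommendation task for an example project.
task = httpx.put(BASE, json={"project_id": 1}, headers=headers).json()

# 2. Once the background job has finished, fetch its recommendations.
recs = httpx.get(f"{BASE}/{task['task_id']}", headers=headers).json()

# 3. Accept the high-confidence ones. A bare `List[int]` parameter is
#    inferred as a query parameter by FastAPI, hence `params=` here.
accepted = [r["recommendation_id"] for r in recs if r["prediction_score"] > 0.9]
n = httpx.patch(
    f"{BASE}/accept_recommendations",
    params={"accepted_recommendation_ids": accepted},
    headers=headers,
).json()
print(f"{n} tags applied")
```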
backend/src/app/celery/background_jobs/__init__.py (17 additions, 0 deletions)
@@ -4,8 +4,14 @@
from celery import Task, group
from celery.result import GroupResult

from app.core.data.classification.document_classification_service import (
DocumentClassificationService as DocumentClassificationService,
)
from app.core.data.crawler.crawler_service import CrawlerService
from app.core.data.dto.crawler_job import CrawlerJobParameters, CrawlerJobRead
from app.core.data.dto.document_tag_recommendation import (
DocumentTagRecommendationRead as DocumentTagRecommendationRead,
)
from app.core.data.dto.export_job import ExportJobParameters, ExportJobRead
from app.core.data.dto.import_job import ImportJobParameters, ImportJobRead
from app.core.data.dto.llm_job import LLMJobParameters2, LLMJobRead
@@ -172,3 +178,14 @@ def execute_video_preprocessing_pipeline_apply_async(

for cargo in cargos:
execute_video_preprocessing_pipeline_task.apply_async(kwargs={"cargo": cargo})


def prepare_and_start_document_classification_job_async(
task_id: int, project_id: int
) -> None:
from app.celery.background_jobs.tasks import (
start_document_classification_job,
)

assert isinstance(start_document_classification_job, Task), "Not a Celery Task"
    start_document_classification_job.apply_async(
        kwargs={"task_id": task_id, "project_id": project_id}
    )
backend/src/app/celery/background_jobs/document_classification.py (14 additions, 0 deletions)
@@ -0,0 +1,14 @@
from loguru import logger

from app.core.data.classification.document_classification_service import (
DocumentClassificationService,
)

dcs: DocumentClassificationService = DocumentClassificationService()


def start_document_classification_job_(task_id: int, project_id: int) -> None:
    logger.info(f"Starting classification job with task id {task_id}")
dcs.classify_untagged_documents(task_id=task_id, project_id=project_id)

logger.info(f"Classification job {task_id} has finished.")
backend/src/app/celery/background_jobs/tasks.py (12 additions, 0 deletions)
@@ -3,6 +3,9 @@

from app.celery.background_jobs.cota import start_cota_refinement_job_
from app.celery.background_jobs.crawl import start_crawler_job_
from app.celery.background_jobs.document_classification import (
start_document_classification_job_,
)
from app.celery.background_jobs.export import start_export_job_
from app.celery.background_jobs.import_ import start_import_job_
from app.celery.background_jobs.llm import start_llm_job_
@@ -118,3 +121,12 @@ def import_uploaded_archive(archive_file_path_and_project_id: Tuple[Path, int])
# we need a tuple to chain the task since chaining only allows for one return object
archive_file_path, project_id = archive_file_path_and_project_id
import_uploaded_archive_(archive_file_path=archive_file_path, project_id=project_id)


@celery_worker.task(
acks_late=True,
autoretry_for=(Exception,),
retry_kwargs={"max_retries": 5, "countdown": 5},
)
def start_document_classification_job(task_id: int, project_id: int) -> None:
start_document_classification_job_(task_id=task_id, project_id=project_id)
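With `acks_late=True` the broker message is acknowledged only after the task body returns, so a worker crash mid-classification re-delivers the job, and `autoretry_for=(Exception,)` turns any exception into up to five retries spaced five seconds apart. A hedged dispatch sketch (the IDs are placeholders; real values come from the CRUD layer):

```python
# Hypothetical caller-side dispatch of the task defined above.
result = start_document_classification_job.apply_async(
    kwargs={"task_id": 42, "project_id": 7}
)
result.get(timeout=600)  # optional: block until the worker (and retries) finish
```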