Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions backend/src/app/core/data/crud/span_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ def link_groups_spans_batch(
# insert links (group <-> span)
from sqlalchemy.dialects.postgresql import insert

if len(links) == 0:
return 0

insert_values = [
{"span_group_id": str(group_id), "span_annotation_id": str(span_id)}
for group_id, span_ids in links.items()
Expand Down
10 changes: 7 additions & 3 deletions backend/src/app/core/data/orm/span_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,12 @@ def sdoc_id(self):

class SpanAnnotationSpanGroupLinkTable(ORMBase):
    """Link table between span annotations and span groups.

    Each row pairs one span annotation id with one span group id. Both
    columns are flagged ``primary_key=True``, so together they form the
    table's composite primary key and a given (annotation, group) pair can
    be stored at most once.
    """

    # Id of the linked span annotation ("spanannotation.id"). The foreign key
    # is declared with ON DELETE CASCADE, so the link row is removed by the
    # database when the referenced annotation row is deleted.
    span_annotation_id: Mapped[int] = mapped_column(
        Integer,
        ForeignKey("spanannotation.id", ondelete="CASCADE"),
        primary_key=True,
    )
    # Id of the linked span group ("spangroup.id"), declared with the same
    # ON DELETE CASCADE behavior and the same Mapped[int] annotation as the
    # annotation side.
    span_group_id: Mapped[int] = mapped_column(
        Integer,
        ForeignKey("spangroup.id", ondelete="CASCADE"),
        primary_key=True,
    )
9 changes: 7 additions & 2 deletions backend/src/app/core/ml/coref_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime
from typing import Dict, List, Tuple

from sqlalchemy import ColumnElement
from sqlalchemy import ColumnElement, and_
from sqlalchemy.orm import Session

from app.core.data.crud.annotation_document import crud_adoc
Expand Down Expand Up @@ -85,7 +85,11 @@ def _process_batch(
)
.outerjoin(
SourceDocumentJobStatusORM,
SourceDocumentJobStatusORM.id == SourceDocumentDataORM.id,
and_(
SourceDocumentJobStatusORM.id == SourceDocumentDataORM.id,
SourceDocumentJobStatusORM.type
== JobType.COREFERENCE_RESOLUTION,
),
full=True,
)
.filter(filter_criterion)
Expand All @@ -97,6 +101,7 @@ def _process_batch(
.limit(100)
)
sdoc_data = query.all()
sdoc_data = [doc for doc in sdoc_data if doc is not None]
num_docs = len(sdoc_data)

if num_docs == 0:
Expand Down
23 changes: 2 additions & 21 deletions backend/src/app/core/ml/ml_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
)
from app.core.data.orm.source_document_job_status import (
JobStatus,
JobType,
SourceDocumentJobStatusORM,
)
from app.core.db.redis_service import RedisService
Expand Down Expand Up @@ -86,25 +85,16 @@ def start_ml_job_sync(self, ml_job_id: str) -> MLJobRead:
QuotationAttributionParams,
), "QuotationAttributionParams expected"
recompute = mlj.parameters.specific_ml_job_parameters.recompute
valid_type = or_(
SourceDocumentJobStatusORM.type
== JobType.QUOTATION_ATTRIBUTION,
SourceDocumentJobStatusORM.type == None, # noqa: E711
)
filter_criterion = (
and_(
valid_type,
inactive_status,
or_(
timestamp_column < start_time,
timestamp_column == None, # noqa: E711
),
)
if recompute
else and_(
valid_type,
or_(unfinished_status, timestamp_column == None), # noqa: E711
) # noqa: E711
else or_(unfinished_status, timestamp_column == None) # noqa: E711
)

QuoteService().perform_quotation_detection(
Expand All @@ -126,25 +116,16 @@ def start_ml_job_sync(self, ml_job_id: str) -> MLJobRead:
CoreferenceResolutionParams,
):
recompute = mlj.parameters.specific_ml_job_parameters.recompute
valid_type = or_(
SourceDocumentJobStatusORM.type
== JobType.COREFERENCE_RESOLUTION,
SourceDocumentJobStatusORM.type == None, # noqa: E711
)
filter_criterion = (
and_(
valid_type,
inactive_status,
or_(
timestamp_column < start_time,
timestamp_column == None, # noqa: E711
),
)
if recompute
else and_(
valid_type,
or_(unfinished_status, timestamp_column == None), # noqa: E711
) # noqa: E711
else or_(unfinished_status, timestamp_column == None) # noqa: E711
)

CorefService().perform_coreference_resolution(
Expand Down
41 changes: 38 additions & 3 deletions backend/src/app/core/ml/quote_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,29 @@
from typing import Dict, List, NamedTuple, Tuple

from loguru import logger
from sqlalchemy import ColumnElement
from sqlalchemy import ColumnElement, and_
from sqlalchemy.orm import Session

from app.core.data.crud.annotation_document import crud_adoc
from app.core.data.crud.code import crud_code
from app.core.data.crud.project_metadata import crud_project_meta
from app.core.data.crud.source_document_job_status import crud_sdoc_job_status
from app.core.data.crud.span_annotation import crud_span_anno
from app.core.data.crud.span_group import crud_span_group
from app.core.data.crud.user import SYSTEM_USER_ID
from app.core.data.doc_type import DocType
from app.core.data.dto.source_document_job_status import SourceDocumentJobStatusCreate
from app.core.data.dto.span_annotation import SpanAnnotationCreateIntern
from app.core.data.dto.span_group import SpanGroupCreateIntern
from app.core.data.meta_type import MetaType
from app.core.data.orm.annotation_document import AnnotationDocumentORM
from app.core.data.orm.source_document_data import SourceDocumentDataORM
from app.core.data.orm.source_document_job_status import (
JobStatus,
JobType,
SourceDocumentJobStatusORM,
)
from app.core.data.orm.source_document_metadata import SourceDocumentMetadataORM
from app.core.data.orm.span_annotation import SpanAnnotationORM
from app.core.db.sql_service import SQLService
from app.preprocessing.ray_model_service import RayModelService
Expand Down Expand Up @@ -68,12 +72,27 @@ def perform_quotation_detection(
addr=self._get_code_id(db, "ADDRESSEE", project_id),
cue=self._get_code_id(db, "CUE", project_id),
)
language_metadata = (
crud_project_meta.read_by_project_and_key_and_metatype_and_doctype(
db,
project_id,
"language",
MetaType.STRING.value,
DocType.text.value,
)
)
if language_metadata is None:
raise ValueError("error with project, no language metadata available")

total_processed = 0
num_processed = -1
while num_processed != 0:
num_processed = self._process_batch(
filter_criterion, project_id, codes, recompute
filter_criterion,
project_id,
codes,
language_metadata.id,
recompute,
)
total_processed = +num_processed
return total_processed
Expand All @@ -83,20 +102,36 @@ def _process_batch(
filter_criterion: ColumnElement,
project_id: int,
code: _CodeQuoteId,
language_metadata_id: int,
recompute: bool = False,
):
with self.sqls.db_session() as db:
query = (
db.query(SourceDocumentDataORM)
.join(
SourceDocumentMetadataORM,
SourceDocumentMetadataORM.source_document_id
== SourceDocumentDataORM.id,
)
.outerjoin(
SourceDocumentJobStatusORM,
SourceDocumentJobStatusORM.id == SourceDocumentDataORM.id,
and_(
SourceDocumentJobStatusORM.id == SourceDocumentDataORM.id,
SourceDocumentJobStatusORM.type
== JobType.QUOTATION_ATTRIBUTION,
),
full=True,
)
.filter(filter_criterion)
.filter(
SourceDocumentMetadataORM.project_metadata_id
== language_metadata_id,
SourceDocumentMetadataORM.str_value == "de",
)
.limit(10)
)
sdoc_data = query.all()
sdoc_data = [doc for doc in sdoc_data if doc is not None]
num_docs = len(sdoc_data)

if num_docs == 0:
Expand Down
2 changes: 1 addition & 1 deletion backend/src/app/preprocessing/ray_model_worker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ RUN --mount=type=cache,target=/root/.cache pip install uv
COPY requirements.txt /tmp/requirements.txt

# install and cache dependencies via uv (this drastically (!) reduces build time)
RUN --mount=type=cache,target=/root/.cache uv pip install -r /tmp/requirements.txt --system
RUN --mount=type=cache,target=/root/.cache uv pip install -r /tmp/requirements.txt --system && uv pip install --system --no-build-isolation flash-attn==2.7.4.post1

# copy source code into the image
WORKDIR /dats_code_ray
Expand Down
4 changes: 2 additions & 2 deletions frontend/src/views/tools/MlAutomation/MlAutomation.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ function MlAutomation() {
requestBody: {
ml_job_type: MLJobType.QUOTATION_ATTRIBUTION,
project_id: projectId,
specific_ml_job_parameters: { recompute: false, ml_job_type: MLJobType.QUOTATION_ATTRIBUTION },
specific_ml_job_parameters: { recompute: true, ml_job_type: MLJobType.QUOTATION_ATTRIBUTION },
},
});
},
Expand Down Expand Up @@ -101,7 +101,7 @@ function MlAutomation() {
requestBody: {
ml_job_type: MLJobType.COREFERENCE_RESOLUTION,
project_id: projectId,
specific_ml_job_parameters: { recompute: false, ml_job_type: MLJobType.COREFERENCE_RESOLUTION },
specific_ml_job_parameters: { recompute: true, ml_job_type: MLJobType.COREFERENCE_RESOLUTION },
},
});
},
Expand Down
Loading