Skip to content

Commit 65c13b7

Browse files
authored
Fix broken coreference resolution and quotation detection (#526)
1 parent c6b01ba commit 65c13b7

File tree

7 files changed

+60
-32
lines changed

7 files changed

+60
-32
lines changed

backend/src/app/core/data/crud/span_group.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ def link_groups_spans_batch(
6262
# insert links (group <-> span)
6363
from sqlalchemy.dialects.postgresql import insert
6464

65+
if len(links) == 0:
66+
return 0
67+
6568
insert_values = [
6669
{"span_group_id": str(group_id), "span_annotation_id": str(span_id)}
6770
for group_id, span_ids in links.items()

backend/src/app/core/data/orm/span_group.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,12 @@ def sdoc_id(self):
5959

6060
class SpanAnnotationSpanGroupLinkTable(ORMBase):
6161
span_annotation_id: Mapped[int] = mapped_column(
62-
Integer, ForeignKey("spanannotation.id", ondelete="CASCADE"), primary_key=True
62+
Integer,
63+
ForeignKey("spanannotation.id", ondelete="CASCADE"),
64+
primary_key=True,
6365
)
64-
span_group_id = mapped_column(
65-
Integer, ForeignKey("spangroup.id", ondelete="CASCADE"), primary_key=True
66+
span_group_id: Mapped[int] = mapped_column(
67+
Integer,
68+
ForeignKey("spangroup.id", ondelete="CASCADE"),
69+
primary_key=True,
6670
)

backend/src/app/core/ml/coref_service.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from datetime import datetime
22
from typing import Dict, List, Tuple
33

4-
from sqlalchemy import ColumnElement
4+
from sqlalchemy import ColumnElement, and_
55
from sqlalchemy.orm import Session
66

77
from app.core.data.crud.annotation_document import crud_adoc
@@ -85,7 +85,11 @@ def _process_batch(
8585
)
8686
.outerjoin(
8787
SourceDocumentJobStatusORM,
88-
SourceDocumentJobStatusORM.id == SourceDocumentDataORM.id,
88+
and_(
89+
SourceDocumentJobStatusORM.id == SourceDocumentDataORM.id,
90+
SourceDocumentJobStatusORM.type
91+
== JobType.COREFERENCE_RESOLUTION,
92+
),
8993
full=True,
9094
)
9195
.filter(filter_criterion)
@@ -97,6 +101,7 @@ def _process_batch(
97101
.limit(100)
98102
)
99103
sdoc_data = query.all()
104+
sdoc_data = [doc for doc in sdoc_data if doc is not None]
100105
num_docs = len(sdoc_data)
101106

102107
if num_docs == 0:

backend/src/app/core/ml/ml_service.py

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
)
1717
from app.core.data.orm.source_document_job_status import (
1818
JobStatus,
19-
JobType,
2019
SourceDocumentJobStatusORM,
2120
)
2221
from app.core.db.redis_service import RedisService
@@ -86,25 +85,16 @@ def start_ml_job_sync(self, ml_job_id: str) -> MLJobRead:
8685
QuotationAttributionParams,
8786
), "QuotationAttributionParams expected"
8887
recompute = mlj.parameters.specific_ml_job_parameters.recompute
89-
valid_type = or_(
90-
SourceDocumentJobStatusORM.type
91-
== JobType.QUOTATION_ATTRIBUTION,
92-
SourceDocumentJobStatusORM.type == None, # noqa: E711
93-
)
9488
filter_criterion = (
9589
and_(
96-
valid_type,
9790
inactive_status,
9891
or_(
9992
timestamp_column < start_time,
10093
timestamp_column == None, # noqa: E711
10194
),
10295
)
10396
if recompute
104-
else and_(
105-
valid_type,
106-
or_(unfinished_status, timestamp_column == None), # noqa: E711
107-
) # noqa: E711
97+
else or_(unfinished_status, timestamp_column == None) # noqa: E711
10898
)
10999

110100
QuoteService().perform_quotation_detection(
@@ -126,25 +116,16 @@ def start_ml_job_sync(self, ml_job_id: str) -> MLJobRead:
126116
CoreferenceResolutionParams,
127117
):
128118
recompute = mlj.parameters.specific_ml_job_parameters.recompute
129-
valid_type = or_(
130-
SourceDocumentJobStatusORM.type
131-
== JobType.COREFERENCE_RESOLUTION,
132-
SourceDocumentJobStatusORM.type == None, # noqa: E711
133-
)
134119
filter_criterion = (
135120
and_(
136-
valid_type,
137121
inactive_status,
138122
or_(
139123
timestamp_column < start_time,
140124
timestamp_column == None, # noqa: E711
141125
),
142126
)
143127
if recompute
144-
else and_(
145-
valid_type,
146-
or_(unfinished_status, timestamp_column == None), # noqa: E711
147-
) # noqa: E711
128+
else or_(unfinished_status, timestamp_column == None) # noqa: E711
148129
)
149130

150131
CorefService().perform_coreference_resolution(

backend/src/app/core/ml/quote_service.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,29 @@
22
from typing import Dict, List, NamedTuple, Tuple
33

44
from loguru import logger
5-
from sqlalchemy import ColumnElement
5+
from sqlalchemy import ColumnElement, and_
66
from sqlalchemy.orm import Session
77

88
from app.core.data.crud.annotation_document import crud_adoc
99
from app.core.data.crud.code import crud_code
10+
from app.core.data.crud.project_metadata import crud_project_meta
1011
from app.core.data.crud.source_document_job_status import crud_sdoc_job_status
1112
from app.core.data.crud.span_annotation import crud_span_anno
1213
from app.core.data.crud.span_group import crud_span_group
1314
from app.core.data.crud.user import SYSTEM_USER_ID
15+
from app.core.data.doc_type import DocType
1416
from app.core.data.dto.source_document_job_status import SourceDocumentJobStatusCreate
1517
from app.core.data.dto.span_annotation import SpanAnnotationCreateIntern
1618
from app.core.data.dto.span_group import SpanGroupCreateIntern
19+
from app.core.data.meta_type import MetaType
1720
from app.core.data.orm.annotation_document import AnnotationDocumentORM
1821
from app.core.data.orm.source_document_data import SourceDocumentDataORM
1922
from app.core.data.orm.source_document_job_status import (
2023
JobStatus,
2124
JobType,
2225
SourceDocumentJobStatusORM,
2326
)
27+
from app.core.data.orm.source_document_metadata import SourceDocumentMetadataORM
2428
from app.core.data.orm.span_annotation import SpanAnnotationORM
2529
from app.core.db.sql_service import SQLService
2630
from app.preprocessing.ray_model_service import RayModelService
@@ -68,12 +72,27 @@ def perform_quotation_detection(
6872
addr=self._get_code_id(db, "ADDRESSEE", project_id),
6973
cue=self._get_code_id(db, "CUE", project_id),
7074
)
75+
language_metadata = (
76+
crud_project_meta.read_by_project_and_key_and_metatype_and_doctype(
77+
db,
78+
project_id,
79+
"language",
80+
MetaType.STRING.value,
81+
DocType.text.value,
82+
)
83+
)
84+
if language_metadata is None:
85+
raise ValueError("error with project, no language metadata available")
7186

7287
total_processed = 0
7388
num_processed = -1
7489
while num_processed != 0:
7590
num_processed = self._process_batch(
76-
filter_criterion, project_id, codes, recompute
91+
filter_criterion,
92+
project_id,
93+
codes,
94+
language_metadata.id,
95+
recompute,
7796
)
7897
total_processed = +num_processed
7998
return total_processed
@@ -83,20 +102,36 @@ def _process_batch(
83102
filter_criterion: ColumnElement,
84103
project_id: int,
85104
code: _CodeQuoteId,
105+
language_metadata_id: int,
86106
recompute: bool = False,
87107
):
88108
with self.sqls.db_session() as db:
89109
query = (
90110
db.query(SourceDocumentDataORM)
111+
.join(
112+
SourceDocumentMetadataORM,
113+
SourceDocumentMetadataORM.source_document_id
114+
== SourceDocumentDataORM.id,
115+
)
91116
.outerjoin(
92117
SourceDocumentJobStatusORM,
93-
SourceDocumentJobStatusORM.id == SourceDocumentDataORM.id,
118+
and_(
119+
SourceDocumentJobStatusORM.id == SourceDocumentDataORM.id,
120+
SourceDocumentJobStatusORM.type
121+
== JobType.QUOTATION_ATTRIBUTION,
122+
),
94123
full=True,
95124
)
96125
.filter(filter_criterion)
126+
.filter(
127+
SourceDocumentMetadataORM.project_metadata_id
128+
== language_metadata_id,
129+
SourceDocumentMetadataORM.str_value == "de",
130+
)
97131
.limit(10)
98132
)
99133
sdoc_data = query.all()
134+
sdoc_data = [doc for doc in sdoc_data if doc is not None]
100135
num_docs = len(sdoc_data)
101136

102137
if num_docs == 0:

backend/src/app/preprocessing/ray_model_worker/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ RUN --mount=type=cache,target=/root/.cache pip install uv
1818
COPY requirements.txt /tmp/requirements.txt
1919

2020
# install and cache dependencies via uv (this drastically (!) reduces build time)
21-
RUN --mount=type=cache,target=/root/.cache uv pip install -r /tmp/requirements.txt --system
21+
RUN --mount=type=cache,target=/root/.cache uv pip install -r /tmp/requirements.txt --system && uv pip install --system --no-build-isolation flash-attn==2.7.4.post1
2222

2323
# copy source code into the image
2424
WORKDIR /dats_code_ray

frontend/src/views/tools/MlAutomation/MlAutomation.tsx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ function MlAutomation() {
4949
requestBody: {
5050
ml_job_type: MLJobType.QUOTATION_ATTRIBUTION,
5151
project_id: projectId,
52-
specific_ml_job_parameters: { recompute: false, ml_job_type: MLJobType.QUOTATION_ATTRIBUTION },
52+
specific_ml_job_parameters: { recompute: true, ml_job_type: MLJobType.QUOTATION_ATTRIBUTION },
5353
},
5454
});
5555
},
@@ -101,7 +101,7 @@ function MlAutomation() {
101101
requestBody: {
102102
ml_job_type: MLJobType.COREFERENCE_RESOLUTION,
103103
project_id: projectId,
104-
specific_ml_job_parameters: { recompute: false, ml_job_type: MLJobType.COREFERENCE_RESOLUTION },
104+
specific_ml_job_parameters: { recompute: true, ml_job_type: MLJobType.COREFERENCE_RESOLUTION },
105105
},
106106
});
107107
},

0 commit comments

Comments
 (0)