
Commit 906657b

ML job to recompute & reindex sentence embeddings (#552)
* ML job to recompute & reindex sentence embeddings
1 parent 0174963 commit 906657b

File tree

13 files changed: +246 -27 lines changed

backend/src/app/core/data/crud/source_document_job_status.py

Lines changed: 26 additions & 1 deletion
@@ -1,3 +1,9 @@
+from typing import List
+
+from fastapi.encoders import jsonable_encoder
+from sqlalchemy import and_, false, or_
+from sqlalchemy.orm import Session
+
 from app.core.data.crud.crud_base import CRUDBase
 from app.core.data.dto.source_document_job_status import (
     SourceDocumentJobStatusCreate,
@@ -13,7 +19,26 @@ class CRUDSourceDocumentJobStatus(
         SourceDocumentJobStatusUpdate,
     ]
 ):
-    pass
+    def create_multi(
+        self, db: Session, *, create_dtos: List[SourceDocumentJobStatusCreate]
+    ) -> List[SourceDocumentJobStatusORM]:
+        db_objs = [self.model(**jsonable_encoder(x)) for x in create_dtos]
+        q = db.query(self.model).where(
+            or_(
+                false(),
+                *[
+                    and_(
+                        SourceDocumentJobStatusORM.id == x.id,
+                        SourceDocumentJobStatusORM.type == x.type,
+                    )
+                    for x in create_dtos
+                ],
+            )
+        )
+        q.delete()
+        db.add_all(db_objs)
+        db.commit()
+        return db_objs
 
 
 crud_sdoc_job_status = CRUDSourceDocumentJobStatus(SourceDocumentJobStatusORM)
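Note on create_multi: it behaves like a bulk upsert keyed on (id, type). Matching status rows are deleted first, then the fresh ones are inserted, and the false() seed keeps or_() valid when create_dtos is empty. A minimal usage sketch (session handling and values are illustrative, mirroring the call made by the sentence-embedding job below):

    from datetime import datetime

    from app.core.data.crud.source_document_job_status import crud_sdoc_job_status
    from app.core.data.dto.source_document_job_status import SourceDocumentJobStatusCreate
    from app.core.data.orm.source_document_job_status import JobStatus, JobType

    def mark_sentence_embedding_finished(db, sdoc_ids):
        # Re-running is safe: existing (id, type) rows are removed before insert.
        crud_sdoc_job_status.create_multi(
            db,
            create_dtos=[
                SourceDocumentJobStatusCreate(
                    id=sdoc_id,
                    type=JobType.SENTENCE_EMBEDDING,
                    status=JobStatus.FINISHED,
                    timestamp=datetime.now(),
                )
                for sdoc_id in sdoc_ids
            ],
        )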

backend/src/app/core/data/dto/ml_job.py

Lines changed: 9 additions & 0 deletions
@@ -15,6 +15,7 @@ class MLJobType(StrEnum):
     DOC_TAG_RECOMMENDATION = "DOC_TAG_RECOMMENDATION"
     COREFERENCE_RESOLUTION = "COREFERENCE_RESOLUTION"
     DOCUMENT_EMBEDDING = "DOCUMENT_EMBEDDING"
+    SENTENCE_EMBEDDING = "SENTENCE_EMBEDDING"
 
 
 class QuotationAttributionParams(BaseModel):
@@ -53,6 +54,13 @@ class DocumentEmbeddingParams(BaseModel):
     )
 
 
+class SentenceEmbeddingParams(BaseModel):
+    ml_job_type: Literal[MLJobType.SENTENCE_EMBEDDING]
+    recompute: bool = Field(
+        default=False, description="Whether to recompute already processed documents"
+    )
+
+
 class MLJobParameters(BaseModel):
     ml_job_type: MLJobType = Field(description="The type of the MLJob")
     project_id: int = Field(description="The ID of the Project to analyse")
@@ -61,6 +69,7 @@
         DocTagRecommendationParams,
         CoreferenceResolutionParams,
         DocumentEmbeddingParams,
+        SentenceEmbeddingParams,
         None,
     ] = Field(
         description="Specific parameters for the MLJob w.r.t it's type",

backend/src/app/core/data/orm/source_document_job_status.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ class JobType(IntEnum):
     QUOTATION_ATTRIBUTION = 100
     COREFERENCE_RESOLUTION = 101
     DOCUMENT_EMBEDDING = 102
+    SENTENCE_EMBEDDING = 103
 
 
 class JobStatus(IntEnum):

backend/src/app/core/db/simsearch_service.py

Lines changed: 3 additions & 12 deletions
@@ -1,18 +1,9 @@
-from typing import (
-    Any,
-    Dict,
-    List,
-    Optional,
-    Union,
-)
+from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 from loguru import logger
 
-from app.core.data.dto.search import (
-    SimSearchImageHit,
-    SimSearchSentenceHit,
-)
+from app.core.data.dto.search import SimSearchImageHit, SimSearchSentenceHit
 from app.core.ml.embedding_service import EmbeddingService
 from app.core.vector.crud.image_embedding import crud_image_embedding
 from app.core.vector.crud.sentence_embedding import crud_sentence_embedding
@@ -44,7 +35,7 @@ def _encode_query(
             query_emb = (
                 self.emb.encode_document(" ".join(text_query))
                 if document_query
-                else self.emb.encode_sentences(sentences=text_query)
+                else self.emb.encode_sentences(sentences=text_query)[0]
             )
         elif image_query_id is not None:
             query_emb = self.emb.encode_image(sdoc_id=image_query_id)
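The new [0] index is needed because encode_sentences no longer squeezes single-sentence batches (see embedding_service.py below); it now always returns one embedding row per input sentence. A shape-only sketch (the dimension 512 is an assumption, not taken from this diff):

    import numpy as np

    # encode_sentences now consistently returns shape (num_sentences, dim)
    batch = np.zeros((1, 512))  # one query sentence, illustrative embedding dim
    query_emb = batch[0]        # shape (dim,), as _encode_query expects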

backend/src/app/core/ml/embedding_service.py

Lines changed: 85 additions & 5 deletions
@@ -18,7 +18,9 @@
 from app.core.data.repo.utils import image_to_base64, load_image
 from app.core.db.sql_service import SQLService
 from app.core.vector.crud.document_embedding import crud_document_embedding
+from app.core.vector.crud.sentence_embedding import crud_sentence_embedding
 from app.core.vector.dto.document_embedding import DocumentObjectIdentifier
+from app.core.vector.dto.sentence_embedding import SentenceObjectIdentifier
 from app.core.vector.weaviate_service import WeaviateService
 from app.preprocessing.ray_model_service import RayModelService
 from app.preprocessing.ray_model_worker.dto.clip import (
@@ -46,10 +48,7 @@ def encode_sentences(self, sentences: List[str]) -> np.ndarray:
         encoded_query = self.rms.clip_text_embedding(
             ClipTextEmbeddingInput(text=sentences)
         )
-        if len(encoded_query.embeddings) == 1:
-            return encoded_query.numpy().squeeze()
-        else:
-            return encoded_query.numpy()
+        return encoded_query.numpy()
 
     def encode_image(self, sdoc_id: int) -> np.ndarray:
         with self.sqls.db_session() as db:
@@ -69,6 +68,87 @@ def encode_image(self, sdoc_id: int) -> np.ndarray:
             )
         return encoded_query.numpy().squeeze()
 
+    def embed_sentences(
+        self, project_id: int, filter_criterion: ColumnElement, recompute=False
+    ) -> int:
+        total_processed = 0
+        num_processed = -1
+
+        with self.weaviate.weaviate_session() as client:
+            if recompute:
+                crud_sentence_embedding.remove_embeddings_by_project(client, project_id)
+
+            while num_processed != 0:
+                num_processed = self._process_sentences_batch(
+                    client,
+                    filter_criterion,
+                    project_id,
+                )
+                total_processed += num_processed
+        return total_processed
+
+    def _process_sentences_batch(
+        self,
+        client: WeaviateClient,
+        filter_criterion: ColumnElement,
+        project_id: int,
+        batch_size=16,
+    ):
+        with self.sqls.db_session() as db:
+            query = (
+                db.query(SourceDocumentDataORM)
+                .outerjoin(
+                    SourceDocumentJobStatusORM,
+                    and_(
+                        SourceDocumentJobStatusORM.id == SourceDocumentDataORM.id,
+                        SourceDocumentJobStatusORM.type == JobType.SENTENCE_EMBEDDING,
+                    ),
+                    full=True,
+                )
+                .filter(filter_criterion)
+                .limit(batch_size)
+            )
+            sdoc_data = query.all()
+            doc_sentences = [doc.sentences for doc in sdoc_data]
+            sdoc_ids = [doc.id for doc in sdoc_data]
+            num_docs = len(doc_sentences)
+
+            if num_docs == 0:
+                return num_docs
+
+            # Embed the sentences for a batch of documents
+            embeddings = self.encode_sentences(
+                [s for sents in doc_sentences for s in sents]
+            ).tolist()
+
+            ids = [
+                SentenceObjectIdentifier(sdoc_id=sdoc_id, sentence_id=i)
+                for sdoc_id, sents in zip(sdoc_ids, doc_sentences)
+                for i in range(len(sents))
+            ]
+
+            # Store the embeddings of a batch of documents
+            crud_sentence_embedding.add_embedding_batch(
+                client,
+                project_id,
+                ids=ids,
+                embeddings=embeddings,
+            )
+
+            crud_sdoc_job_status.create_multi(
+                db,
+                create_dtos=[
+                    SourceDocumentJobStatusCreate(
                        id=id,
+                        type=JobType.SENTENCE_EMBEDDING,
+                        status=JobStatus.FINISHED,
+                        timestamp=datetime.now(),
+                    )
+                    for id in sdoc_ids
+                ],
+            )
+            return num_docs
+
     def embed_documents(
         self, project_id: int, filter_criterion: ColumnElement, recompute=False
     ) -> int:
@@ -88,7 +168,7 @@ def embed_documents(
                     project_id,
                     force_override=(recompute and (total_processed == 0)),
                 )
-                total_processed = +num_processed
+                total_processed += num_processed
         return total_processed
 
     def _process_document_batch(
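embed_sentences mirrors embed_documents: it embeds documents in batches of 16 until the filter criterion matches nothing, and writes a FINISHED SENTENCE_EMBEDDING status row per document so that, with a criterion like the one sketched below, processed documents drop out of later batches; the return value counts documents, not sentences. A call sketch (the filter shown is an assumption; the real criterion comes from MLService._build_filter_criterion):

    from sqlalchemy import or_

    from app.core.data.orm.source_document_job_status import (
        JobStatus,
        SourceDocumentJobStatusORM,
    )
    from app.core.ml.embedding_service import EmbeddingService

    # Assumption: select documents that have no FINISHED sentence-embedding
    # status yet (the outer join in _process_sentences_batch makes these NULL).
    filter_criterion = or_(
        SourceDocumentJobStatusORM.status.is_(None),
        SourceDocumentJobStatusORM.status != JobStatus.FINISHED,
    )

    num_docs = EmbeddingService().embed_sentences(
        project_id=1,  # illustrative
        filter_criterion=filter_criterion,
        recompute=False,
    )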

backend/src/app/core/ml/ml_service.py

Lines changed: 14 additions & 0 deletions
@@ -12,6 +12,7 @@
     MLJobType,
     MLJobUpdate,
     QuotationAttributionParams,
+    SentenceEmbeddingParams,
 )
 from app.core.data.orm.source_document_job_status import (
     JobStatus,
@@ -122,6 +123,19 @@ def start_ml_job_sync(self, ml_job_id: str) -> MLJobRead:
                 EmbeddingService().embed_documents(
                     mlj.parameters.project_id, filter_criterion, recompute
                 )
+            case MLJobType.SENTENCE_EMBEDDING:
+                assert isinstance(
+                    mlj.parameters.specific_ml_job_parameters,
+                    SentenceEmbeddingParams,
+                ), "SentenceEmbeddingParams expected"
+                recompute = mlj.parameters.specific_ml_job_parameters.recompute
+                filter_criterion = self._build_filter_criterion(
+                    start_time, recompute
+                )
+                EmbeddingService().embed_sentences(
+                    mlj.parameters.project_id, filter_criterion, recompute
+                )
+
         mlj = self._update_ml_job(
             ml_job_id, MLJobUpdate(status=BackgroundJobStatus.FINISHED)
         )

backend/src/app/preprocessing/pipeline/steps/common/storage/index_text_document_for_simsearch.py

Lines changed: 0 additions & 1 deletion
@@ -19,7 +19,6 @@ def index_text_document_for_simsearch(cargo: PipelineCargo) -> PipelineCargo:
     if len(sentences) > 0:
         # embed the sentences
         embeddings = emb.encode_sentences(sentences=sentences).tolist()
-        embeddings = embeddings if len(sentences) > 1 else [embeddings]
 
         # store the embeddings
         logger.debug(

frontend/src/api/openapi/models/MLJobParameters_Input.ts

Lines changed: 8 additions & 1 deletion
@@ -7,6 +7,7 @@ import type { DocTagRecommendationParams } from "./DocTagRecommendationParams";
 import type { DocumentEmbeddingParams } from "./DocumentEmbeddingParams";
 import type { MLJobType } from "./MLJobType";
 import type { QuotationAttributionParams } from "./QuotationAttributionParams";
+import type { SentenceEmbeddingParams } from "./SentenceEmbeddingParams";
 export type MLJobParameters_Input = {
   /**
    * The type of the MLJob
@@ -20,6 +21,12 @@ export type MLJobParameters_Input = {
    * Specific parameters for the MLJob w.r.t it's type
    */
   specific_ml_job_parameters:
-    | (QuotationAttributionParams | DocTagRecommendationParams | CoreferenceResolutionParams | DocumentEmbeddingParams)
+    | (
+        | QuotationAttributionParams
+        | DocTagRecommendationParams
+        | CoreferenceResolutionParams
+        | DocumentEmbeddingParams
+        | SentenceEmbeddingParams
+      )
     | null;
 };

frontend/src/api/openapi/models/MLJobParameters_Output.ts

Lines changed: 8 additions & 1 deletion
@@ -7,6 +7,7 @@ import type { DocTagRecommendationParams } from "./DocTagRecommendationParams";
 import type { DocumentEmbeddingParams } from "./DocumentEmbeddingParams";
 import type { MLJobType } from "./MLJobType";
 import type { QuotationAttributionParams } from "./QuotationAttributionParams";
+import type { SentenceEmbeddingParams } from "./SentenceEmbeddingParams";
 export type MLJobParameters_Output = {
   /**
    * The type of the MLJob
@@ -20,6 +21,12 @@ export type MLJobParameters_Output = {
    * Specific parameters for the MLJob w.r.t it's type
    */
   specific_ml_job_parameters:
-    | (QuotationAttributionParams | DocTagRecommendationParams | CoreferenceResolutionParams | DocumentEmbeddingParams)
+    | (
+        | QuotationAttributionParams
+        | DocTagRecommendationParams
+        | CoreferenceResolutionParams
+        | DocumentEmbeddingParams
+        | SentenceEmbeddingParams
+      )
     | null;
 };

frontend/src/api/openapi/models/MLJobType.ts

Lines changed: 1 addition & 0 deletions
@@ -7,4 +7,5 @@ export enum MLJobType {
   DOC_TAG_RECOMMENDATION = "DOC_TAG_RECOMMENDATION",
   COREFERENCE_RESOLUTION = "COREFERENCE_RESOLUTION",
   DOCUMENT_EMBEDDING = "DOCUMENT_EMBEDDING",
+  SENTENCE_EMBEDDING = "SENTENCE_EMBEDDING",
 }
