 from app.core.data.repo.utils import image_to_base64, load_image
 from app.core.db.sql_service import SQLService
 from app.core.vector.crud.document_embedding import crud_document_embedding
+from app.core.vector.crud.sentence_embedding import crud_sentence_embedding
 from app.core.vector.dto.document_embedding import DocumentObjectIdentifier
+from app.core.vector.dto.sentence_embedding import SentenceObjectIdentifier
 from app.core.vector.weaviate_service import WeaviateService
 from app.preprocessing.ray_model_service import RayModelService
 from app.preprocessing.ray_model_worker.dto.clip import (
@@ -46,10 +48,7 @@ def encode_sentences(self, sentences: List[str]) -> np.ndarray:
         encoded_query = self.rms.clip_text_embedding(
             ClipTextEmbeddingInput(text=sentences)
         )
-        if len(encoded_query.embeddings) == 1:
-            return encoded_query.numpy().squeeze()
-        else:
-            return encoded_query.numpy()
+        return encoded_query.numpy()

     def encode_image(self, sdoc_id: int) -> np.ndarray:
         with self.sqls.db_session() as db:
@@ -69,6 +68,87 @@ def encode_image(self, sdoc_id: int) -> np.ndarray:
             )
         return encoded_query.numpy().squeeze()

+    def embed_sentences(
+        self, project_id: int, filter_criterion: ColumnElement, recompute=False
+    ) -> int:
+        total_processed = 0
+        num_processed = -1
+
+        with self.weaviate.weaviate_session() as client:
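+            # A recompute first removes the project's existing sentence embeddings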
+            if recompute:
+                crud_sentence_embedding.remove_embeddings_by_project(client, project_id)
+
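+            # Process documents batch by batch until no unprocessed ones remain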
+            while num_processed != 0:
+                num_processed = self._process_sentences_batch(
+                    client,
+                    filter_criterion,
+                    project_id,
+                )
+                total_processed += num_processed
+        return total_processed
+
+    def _process_sentences_batch(
+        self,
+        client: WeaviateClient,
+        filter_criterion: ColumnElement,
+        project_id: int,
+        batch_size=16,
+    ):
+        with self.sqls.db_session() as db:
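+            # Fetch up to batch_size documents that match the filter criterion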
+            query = (
+                db.query(SourceDocumentDataORM)
+                .outerjoin(
+                    SourceDocumentJobStatusORM,
+                    and_(
+                        SourceDocumentJobStatusORM.id == SourceDocumentDataORM.id,
+                        SourceDocumentJobStatusORM.type == JobType.SENTENCE_EMBEDDING,
+                    ),
+                    full=True,
+                )
+                .filter(filter_criterion)
+                .limit(batch_size)
+            )
+            sdoc_data = query.all()
+            doc_sentences = [doc.sentences for doc in sdoc_data]
+            sdoc_ids = [doc.id for doc in sdoc_data]
+            num_docs = len(doc_sentences)
+
+            if num_docs == 0:
+                return num_docs
+
+            # Embed the sentences for a batch of documents
+            embeddings = self.encode_sentences(
+                [s for sents in doc_sentences for s in sents]
+            ).tolist()
+
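+            # One identifier per sentence: (source document id, sentence index)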
+            ids = [
+                SentenceObjectIdentifier(sdoc_id=sdoc_id, sentence_id=i)
+                for sdoc_id, sents in zip(sdoc_ids, doc_sentences)
+                for i in range(len(sents))
+            ]
+
+            # Store the embeddings of a batch of documents
+            crud_sentence_embedding.add_embedding_batch(
+                client,
+                project_id,
+                ids=ids,
+                embeddings=embeddings,
+            )
+
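+            # Mark the sentence-embedding job as FINISHED for each processed document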
+            crud_sdoc_job_status.create_multi(
+                db,
+                create_dtos=[
+                    SourceDocumentJobStatusCreate(
+                        id=id,
+                        type=JobType.SENTENCE_EMBEDDING,
+                        status=JobStatus.FINISHED,
+                        timestamp=datetime.now(),
+                    )
+                    for id in sdoc_ids
+                ],
+            )
+            return num_docs
+
     def embed_documents(
         self, project_id: int, filter_criterion: ColumnElement, recompute=False
     ) -> int:
@@ -88,7 +168,7 @@ def embed_documents(
                 project_id,
                 force_override=(recompute and (total_processed == 0)),
             )
-            total_processed = +num_processed
+            total_processed += num_processed
         return total_processed

     def _process_document_batch(