diff --git a/byaldi/colpali.py b/byaldi/colpali.py index cc11dcb..c233821 100644 --- a/byaldi/colpali.py +++ b/byaldi/colpali.py @@ -254,10 +254,29 @@ def _export_index(self): embeddings_path.mkdir(exist_ok=True) num_embeddings = len(self.indexed_embeddings) chunk_size = 500 + + # Save files that are missing and track last modified file + last_modif_i = None + just_created = [] for i in range(0, num_embeddings, chunk_size): - chunk = self.indexed_embeddings[i : i + chunk_size] - torch.save(chunk, embeddings_path / f"embeddings_{i}.pt") - + filepath = embeddings_path / f"embeddings_{i}.pt" + # save missing file + if not os.path.exists(filepath): + chunk = self.indexed_embeddings[i : i + chunk_size] + print(f"Saving new file {filepath}") + torch.save(chunk, filepath) + just_created.append(i) + # track last created file + elif i not in just_created: + last_modif_i = i + + # Save last modified file + if last_modif_i is not None: + last_chunk = self.indexed_embeddings[last_modif_i : last_modif_i + chunk_size] + last_filepath = embeddings_path / f"embeddings_{last_modif_i}.pt" + print(f"Saving last modified file {last_filepath}") + torch.save(last_chunk, last_filepath) + # Save index config index_config = { "model_name": self.model_name,