Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions byaldi/colpali.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,29 @@ def _export_index(self):
embeddings_path.mkdir(exist_ok=True)
num_embeddings = len(self.indexed_embeddings)
chunk_size = 500

# Save files that are missing and track last modified file
last_modif_i = None
just_created = []
for i in range(0, num_embeddings, chunk_size):
chunk = self.indexed_embeddings[i : i + chunk_size]
torch.save(chunk, embeddings_path / f"embeddings_{i}.pt")

filepath = embeddings_path / f"embeddings_{i}.pt"
# save missing file
if not os.path.exists(filepath):
chunk = self.indexed_embeddings[i : i + chunk_size]
print(f"Saving new file {filepath}")
torch.save(chunk, filepath)
just_created.append(i)
# track the last chunk file that already existed on disk
elif i not in just_created:
last_modif_i = i

# Save last modified file
if last_modif_i is not None:
last_chunk = self.indexed_embeddings[last_modif_i : last_modif_i + chunk_size]
last_filepath = embeddings_path / f"embeddings_{last_modif_i}.pt"
print(f"Saving last modified file {last_filepath}")
torch.save(last_chunk, last_filepath)

# Save index config
index_config = {
"model_name": self.model_name,
Expand Down