Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto Hashing ID for VectorDB Classes #4746

Merged
merged 6 commits into from
Dec 23, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 70 additions & 24 deletions autogen/agentchat/contrib/vectordb/mongodb.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from copy import deepcopy
from time import monotonic, sleep
from typing import Any, Callable, Dict, Iterable, List, Literal, Mapping, Set, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Literal, Mapping, Optional, Set, Tuple, Union

import numpy as np
from pymongo import MongoClient, UpdateOne, errors
from bson import ObjectId
from pymongo import InsertOne, MongoClient, UpdateOne, errors
from pymongo.collection import Collection
from pymongo.driver_info import DriverInfo
from pymongo.operations import SearchIndexModel
Expand All @@ -24,6 +25,10 @@ def with_id_rename(docs: Iterable) -> List[Dict[str, Any]]:
return [{**{k: v for k, v in d.items() if k != "_id"}, "id": d["_id"]} for d in docs]


class MongoDocument(Document):
id: Optional[ItemID]


class MongoDBAtlasVectorDB(VectorDB):
"""
A Collection object for MongoDB.
Expand Down Expand Up @@ -115,15 +120,18 @@ def _wait_for_document(self, collection: Collection, index_name: str, doc: Docum
start = monotonic()
while monotonic() - start < self._wait_until_document_ready:
query_result = _vector_search(
embedding_vector=np.array(self.embedding_function(doc["content"])).tolist(),
embedding_vector=doc.get("embedding", np.array(self.embedding_function(doc["content"])).tolist()),
n_results=1,
collection=collection,
index_name=index_name,
)
if query_result and query_result[0][0]["_id"] == doc["id"]:
if query_result and str(query_result[0][0]["_id"]) == str(doc["id"]):
return
sleep(_DELAY)

if query_result and float(query_result[0][1]) == 1.0:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may want to have documents with exactly the same content but different metadata, such as comments from different users. I'd suggest logging a warning message instead of raising an error. And if it is an error, it's not a TimeoutError. WDYT?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adjusted now to use logger.warning and to check the metadata.

raise TimeoutError(
"Documents may be ready, but the search has found an identical file with a different ID."
)
raise TimeoutError(f"Document {self.index_name} is not ready!")

def _get_embedding_size(self):
Expand Down Expand Up @@ -265,7 +273,7 @@ def create_vector_search_index(

def insert_docs(
self,
docs: List[Document],
docs: List[MongoDocument],
collection_name: str = None,
upsert: bool = False,
batch_size=DEFAULT_INSERT_BATCH_SIZE,
Expand All @@ -276,15 +284,15 @@ def insert_docs(
For large numbers of Documents, insertion is performed in batches.

Args:
docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
docs: List[MongoDocument] | A list of documents. Each document is a TypedDict `MongoDocument`, which may contain ID.
collection_name: str | The name of the collection. Default is None.
upsert: bool | Whether to update the document if it exists. Default is False.
batch_size: Number of documents to be inserted in each batch
"""
if not docs:
logger.info("No documents to insert.")
return

docs = deepcopy(docs)
collection = self.get_collection(collection_name)
if upsert:
self.update_docs(docs, collection.name, upsert=True)
Expand All @@ -293,36 +301,48 @@ def insert_docs(
if docs[0].get("content") is None:
raise ValueError("The document content is required.")
if docs[0].get("id") is None:
raise ValueError("The document id is required.")

logger.info(
"No id field in the document. The document will be inserted without an id. MongoDB will id this document."
)
input_ids = set()
result_ids = set()
id_batch = []
text_batch = []
metadata_batch = []
embedding_batch = []
size = 0
i = 0
for doc in docs:
id = doc["id"]
text = doc["content"]
metadata = doc.get("metadata", {})
id_batch.append(id)
embedding = doc.get("embedding", None) # None Explicitly Typed for purpose clarity
text_batch.append(text)
metadata_batch.append(metadata)
id_size = 1 if isinstance(id, int) else len(id)
size += len(text) + len(metadata) + id_size
embedding_batch.append(embedding)
id = doc.get("id", None) # None Explicitly Typed for purpose clarity
id_batch.append(id)
if id is not None:
size += 1 if isinstance(id, int) else len(id)
size += len(text) + len(metadata)
if (i + 1) % batch_size == 0 or size >= 47_000_000:
result_ids.update(self._insert_batch(collection, text_batch, metadata_batch, id_batch))
upload_ids = self._insert_batch(collection, text_batch, metadata_batch, id_batch, embedding_batch)
result_ids.update([upload_ids for upload_ids in upload_ids if not isinstance(upload_ids, ObjectId)])
last_id = upload_ids[-1]
input_ids.update(id_batch)
id_batch = []
text_batch = []
metadata_batch = []
embedding_batch = []
size = 0
i += 1
if text_batch:
result_ids.update(self._insert_batch(collection, text_batch, metadata_batch, id_batch)) # type: ignore
upload_ids = self._insert_batch(collection, text_batch, metadata_batch, id_batch, embedding_batch)
result_ids.update([upload_ids for upload_ids in upload_ids if not isinstance(upload_ids, ObjectId)])
last_id = upload_ids[-1]
input_ids.update(id_batch)

input_ids.remove(None)

if result_ids != input_ids:
logger.warning(
"Possible data corruption. "
Expand All @@ -331,17 +351,27 @@ def insert_docs(
in_diff=input_ids.difference(result_ids), out_diff=result_ids.difference(input_ids)
)
)

docs[-1]["id"] = last_id # Update the last document with the last id inserted

if self._wait_until_document_ready and docs:
self._wait_for_document(collection, self.index_name, docs[-1])

def _insert_batch(
self, collection: Collection, texts: List[str], metadatas: List[Mapping[str, Any]], ids: List[ItemID]
) -> Set[ItemID]:
self,
collection: Collection,
texts: List[str],
metadatas: List[Mapping[str, Any]],
ids: List[ItemID],
embeddings: List[List[float]],
) -> List[ItemID]:
"""Compute embeddings for and insert a batch of Documents into the Collection.

For performance reasons, we chose to call self.embedding_function just once,
with the hopefully small tradeoff of having to recreate Document dicts.

The embeddings are sense-checked for dimensionality.

Args:
collection: MongoDB Collection
texts: List of the main contents of each document
Expand All @@ -354,13 +384,29 @@ def _insert_batch(
n_texts = len(texts)
if n_texts == 0:
return []
# Embed and create the documents
embeddings = self.embedding_function(texts).tolist()

# Embed and create the missing document embeddings
if None in embeddings:
to_embed_keys = [i for i, embed in enumerate(embeddings) if embed is None]
to_embed = [texts[i] for i in to_embed_keys]
new_embeddings = self.embedding_function(to_embed)
try:
new_embeddings = new_embeddings.tolist() # attempts one to list method before the other
except AttributeError:
new_embeddings = new_embeddings.to_list()
for i, pos in enumerate(to_embed_keys):
embeddings[pos] = new_embeddings[i]

assert (
all(len(embed) == len(embeddings[0]) for embed in embeddings) if embeddings else True
), f"Embedding Vectors are not all equal in length. Sizes: {set([len(embed) for embed in embeddings])}"

assert (
len(embeddings) == n_texts
), f"The number of embeddings produced by self.embedding_function ({len(embeddings)} does not match the number of texts provided to it ({n_texts})."

to_insert = [
{"_id": i, "content": t, "metadata": m, "embedding": e}
{**({"_id": i} if i is not None else {}), "content": t, "metadata": m, "embedding": e}
for i, t, m, e in zip(ids, texts, metadatas, embeddings)
]
# insert the documents in MongoDB Atlas
Expand All @@ -375,11 +421,11 @@ def update_docs(self, docs: List[Document], collection_name: str = None, **kwarg
Uses deepcopy to avoid changing docs.

Args:
docs: List[Document] | A list of documents.
docs: List[Document] | A list of documents, with ID, to ensure the correct document is updated.
collection_name: str | The name of the collection. Default is None.
kwargs: Any | Use `upsert=True` to insert documents whose ids are not present in collection.
"""

docs = [doc for doc in docs if doc.get("id") is not None]
n_docs = len(docs)
logger.info(f"Preparing to embed and update {n_docs=}")
# Compute the embeddings
Expand All @@ -390,8 +436,8 @@ def update_docs(self, docs: List[Document], collection_name: str = None, **kwarg
doc = deepcopy(docs[i])
doc["embedding"] = embeddings[i]
doc["_id"] = doc.pop("id")

all_updates.append(UpdateOne({"_id": doc["_id"]}, {"$set": doc}, upsert=kwargs.get("upsert", False)))

# Perform update in bulk
collection = self.get_collection(collection_name)
result = collection.bulk_write(all_updates)
Expand Down
Loading