Auto Hashing ID for VectorDB Classes #4746

Merged · merged 6 commits · Dec 23, 2024

Changes from 3 commits
94 changes: 70 additions & 24 deletions autogen/agentchat/contrib/vectordb/mongodb.py
@@ -1,9 +1,10 @@
 from copy import deepcopy
 from time import monotonic, sleep
-from typing import Any, Callable, Dict, Iterable, List, Literal, Mapping, Set, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Literal, Mapping, Optional, Set, Tuple, Union

 import numpy as np
-from pymongo import MongoClient, UpdateOne, errors
+from bson import ObjectId
+from pymongo import InsertOne, MongoClient, UpdateOne, errors
 from pymongo.collection import Collection
 from pymongo.driver_info import DriverInfo
 from pymongo.operations import SearchIndexModel
@@ -24,6 +25,10 @@ def with_id_rename(docs: Iterable) -> List[Dict[str, Any]]:
     return [{**{k: v for k, v in d.items() if k != "_id"}, "id": d["_id"]} for d in docs]


+class MongoDocument(Document):
+    id: Optional[ItemID]
+
+
 class MongoDBAtlasVectorDB(VectorDB):
     """
     A Collection object for MongoDB.
@@ -115,15 +120,18 @@ def _wait_for_document(self, collection: Collection, index_name: str, doc: Docum
         start = monotonic()
         while monotonic() - start < self._wait_until_document_ready:
             query_result = _vector_search(
-                embedding_vector=np.array(self.embedding_function(doc["content"])).tolist(),
+                embedding_vector=doc.get("embedding", np.array(self.embedding_function(doc["content"])).tolist()),
                 n_results=1,
                 collection=collection,
                 index_name=index_name,
             )
-            if query_result and query_result[0][0]["_id"] == doc["id"]:
+            if query_result and str(query_result[0][0]["_id"]) == str(doc["id"]):
                 return
             sleep(_DELAY)

+        if query_result and float(query_result[0][1]) == 1.0:
Collaborator:
We may want to have documents with exactly the same content but different metadata, like comments from different users. I'd suggest logging a warning message instead of raising an error. And if it is an error, it isn't a TimeoutError. WDYT?

Author:
Adjusted now with a logger.warning, and a check for metadata.
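A minimal sketch of that adjustment (hypothetical; the later commits in this PR may differ):

if query_result and float(query_result[0][1]) == 1.0:
    # Hypothetical adjustment: identical content with different metadata is
    # legitimate (e.g. comments from different users), so only treat an exact
    # content-and-metadata match as the inserted document, and warn instead of raising.
    if query_result[0][0].get("metadata") == doc.get("metadata"):
        logger.warning(
            "Documents may be ready: found identical content and metadata under a different ID."
        )
        return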

+            raise TimeoutError(
+                "Documents may be ready, but the search has found an identical file with a different ID."
+            )
         raise TimeoutError(f"Document {self.index_name} is not ready!")

def _get_embedding_size(self):
@@ -265,7 +273,7 @@ def create_vector_search_index(

     def insert_docs(
         self,
-        docs: List[Document],
+        docs: List[MongoDocument],
         collection_name: str = None,
         upsert: bool = False,
         batch_size=DEFAULT_INSERT_BATCH_SIZE,
@@ -276,15 +284,15 @@
         For large numbers of Documents, insertion is performed in batches.

         Args:
-            docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
+            docs: List[MongoDocument] | A list of documents. Each document is a TypedDict `MongoDocument`, which may contain an ID.
             collection_name: str | The name of the collection. Default is None.
             upsert: bool | Whether to update the document if it exists. Default is False.
             batch_size: Number of documents to be inserted in each batch
         """
         if not docs:
             logger.info("No documents to insert.")
             return
-
+        docs = deepcopy(docs)
         collection = self.get_collection(collection_name)
         if upsert:
             self.update_docs(docs, collection.name, upsert=True)
@@ -293,36 +301,48 @@
if docs[0].get("content") is None:
raise ValueError("The document content is required.")
if docs[0].get("id") is None:
raise ValueError("The document id is required.")

logger.info(
"No id field in the document. The document will be inserted without an id. MongoDB will id this document."
)
         input_ids = set()
         result_ids = set()
         id_batch = []
         text_batch = []
         metadata_batch = []
+        embedding_batch = []
         size = 0
         i = 0
         for doc in docs:
-            id = doc["id"]
             text = doc["content"]
             metadata = doc.get("metadata", {})
-            id_batch.append(id)
+            embedding = doc.get("embedding", None)  # None Explicitly Typed for purpose clarity
             text_batch.append(text)
             metadata_batch.append(metadata)
-            id_size = 1 if isinstance(id, int) else len(id)
-            size += len(text) + len(metadata) + id_size
+            embedding_batch.append(embedding)
+            id = doc.get("id", None)  # None Explicitly Typed for purpose clarity
+            id_batch.append(id)
+            if id is not None:
+                size += 1 if isinstance(id, int) else len(id)
+            size += len(text) + len(metadata)
             if (i + 1) % batch_size == 0 or size >= 47_000_000:
-                result_ids.update(self._insert_batch(collection, text_batch, metadata_batch, id_batch))
+                upload_ids = self._insert_batch(collection, text_batch, metadata_batch, id_batch, embedding_batch)
+                result_ids.update([upload_ids for upload_ids in upload_ids if not isinstance(upload_ids, ObjectId)])
+                last_id = upload_ids[-1]
                 input_ids.update(id_batch)
                 id_batch = []
                 text_batch = []
                 metadata_batch = []
+                embedding_batch = []
                 size = 0
             i += 1
         if text_batch:
-            result_ids.update(self._insert_batch(collection, text_batch, metadata_batch, id_batch))  # type: ignore
+            upload_ids = self._insert_batch(collection, text_batch, metadata_batch, id_batch, embedding_batch)
+            result_ids.update([upload_ids for upload_ids in upload_ids if not isinstance(upload_ids, ObjectId)])
+            last_id = upload_ids[-1]
             input_ids.update(id_batch)

+        input_ids.remove(None)
+
         if result_ids != input_ids:
             logger.warning(
                 "Possible data corruption. "
@@ -331,17 +351,27 @@
                     in_diff=input_ids.difference(result_ids), out_diff=result_ids.difference(input_ids)
                 )
             )

+        docs[-1]["id"] = last_id  # Update the last document with the last id inserted
+
         if self._wait_until_document_ready and docs:
             self._wait_for_document(collection, self.index_name, docs[-1])

     def _insert_batch(
-        self, collection: Collection, texts: List[str], metadatas: List[Mapping[str, Any]], ids: List[ItemID]
-    ) -> Set[ItemID]:
+        self,
+        collection: Collection,
+        texts: List[str],
+        metadatas: List[Mapping[str, Any]],
+        ids: List[ItemID],
+        embeddings: List[List[float]],
+    ) -> List[ItemID]:
"""Compute embeddings for and insert a batch of Documents into the Collection.

For performance reasons, we chose to call self.embedding_function just once,
with the hopefully small tradeoff of having recreating Document dicts.

The embeddings are sense-checked for dimensionality.

Args:
collection: MongoDB Collection
texts: List of the main contents of each document
@@ -354,13 +384,29 @@ def _insert_batch(
         n_texts = len(texts)
         if n_texts == 0:
             return []
-        # Embed and create the documents
-        embeddings = self.embedding_function(texts).tolist()
+
+        # Embed and create the missing document embeddings
+        if None in embeddings:
+            to_embed_keys = [i for i, embed in enumerate(embeddings) if embed is None]
+            to_embed = [texts[i] for i in to_embed_keys]
+            new_embeddings = self.embedding_function(to_embed)
+            try:
+                new_embeddings = new_embeddings.tolist()  # attempts one to list method before the other
+            except AttributeError:
+                new_embeddings = new_embeddings.to_list()
+            for i, pos in enumerate(to_embed_keys):
+                embeddings[pos] = new_embeddings[i]
+
+        assert (
+            all(len(embed) == len(embeddings[0]) for embed in embeddings) if embeddings else True
+        ), f"Embedding Vectors are not all equal in length. Sizes: {set([len(embed) for embed in embeddings])}"
+
         assert (
             len(embeddings) == n_texts
         ), f"The number of embeddings produced by self.embedding_function ({len(embeddings)}) does not match the number of texts provided to it ({n_texts})."

         to_insert = [
-            {"_id": i, "content": t, "metadata": m, "embedding": e}
+            {**({"_id": i} if i is not None else {}), "content": t, "metadata": m, "embedding": e}
             for i, t, m, e in zip(ids, texts, metadatas, embeddings)
         ]
         # insert the documents in MongoDB Atlas
@@ -375,11 +421,11 @@ def update_docs(self, docs: List[Document], collection_name: str = None, **kwarg
         Uses deepcopy to avoid changing docs.

         Args:
-            docs: List[Document] | A list of documents.
+            docs: List[Document] | A list of documents, with ID, to ensure the correct document is updated.
             collection_name: str | The name of the collection. Default is None.
             kwargs: Any | Use `upsert=True` to insert documents whose ids are not present in collection.
         """
-
+        docs = [doc for doc in docs if doc.get("id") is not None]
         n_docs = len(docs)
         logger.info(f"Preparing to embed and update {n_docs=}")
         # Compute the embeddings
@@ -390,8 +436,8 @@ def update_docs(self, docs: List[Document], collection_name: str = None, **kwarg
             doc = deepcopy(docs[i])
             doc["embedding"] = embeddings[i]
             doc["_id"] = doc.pop("id")
-
             all_updates.append(UpdateOne({"_id": doc["_id"]}, {"$set": doc}, upsert=kwargs.get("upsert", False)))
+
         # Perform update in bulk
         collection = self.get_collection(collection_name)
         result = collection.bulk_write(all_updates)
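Taken together, the mongodb.py changes mean callers may omit document ids and let MongoDB assign ObjectIds at insert time. A minimal usage sketch, assuming an Atlas deployment; the connection string and collection name below are placeholders, not part of this PR:

from autogen.agentchat.contrib.vectordb.mongodb import MongoDBAtlasVectorDB, MongoDocument

# Placeholder URI; point this at a real Atlas cluster.
db = MongoDBAtlasVectorDB(connection_string="mongodb+srv://<user>:<pass>@<cluster>.mongodb.net")

docs = [
    MongoDocument(content="No id supplied; MongoDB assigns an ObjectId.", metadata={"a": 1}),
    MongoDocument(id="explicit-1", content="A caller-supplied id is kept as-is.", metadata={"b": 2}),
]
db.insert_docs(docs, collection_name="demo")
# Documents without an id are inserted without "_id", so MongoDB generates one;
# ObjectId-assigned ids are then excluded from the corruption check against input ids.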
33 changes: 25 additions & 8 deletions test/agentchat/contrib/vectordb/test_mongodb.py
@@ -7,6 +7,7 @@
 import pytest

 from autogen.agentchat.contrib.vectordb.base import Document
+from autogen.agentchat.contrib.vectordb.mongodb import MongoDocument

 try:
     import pymongo
@@ -96,14 +97,22 @@ def db():
     _empty_collections_and_delete_indexes(database)


+def generate_embeddings(n=384):
+    return [random.random() for _ in range(n)]
+
+
 @pytest.fixture
 def example_documents() -> List[Document]:
-    """Note mix of integers and strings as ids"""
+    """Note mix of integers and strings as ids, MongoDocuments added for testing"""
     return [
         Document(id=1, content="Dogs are tough.", metadata={"a": 1}),
         Document(id=2, content="Cats have fluff.", metadata={"b": 1}),
         Document(id="1", content="What is a sandwich?", metadata={"c": 1}),
         Document(id="2", content="A sandwich makes a great lunch.", metadata={"d": 1, "e": 2}),
+        MongoDocument(content="Stars are big.", metadata={"a": 1}),
+        MongoDocument(content="Atoms are small.", metadata={"b": 1}, embedding=generate_embeddings()),
+        MongoDocument(id="123", content="I hate grass", metadata={"c": 1}),
+        MongoDocument(id="321", content="I love sand", metadata={"d": 1, "e": 2}, embedding=generate_embeddings()),
     ]


@@ -207,25 +216,32 @@ def test_insert_docs(db, collection_name, example_documents):
     # Check that documents have correct fields, including "_id" and "embedding" but not "id"
     assert all([set(doc.keys()) == {"_id", "content", "metadata", "embedding"} for doc in found])
     # Check ids
-    assert {doc["_id"] for doc in found} == {1, "1", 2, "2"}
+    assert {doc["_id"] for doc in found} == {1, "1", 2, "2", found[4]["_id"], found[5]["_id"], "123", "321"}
     # Check embedding lengths
     assert len(found[0]["embedding"]) == 384

+    db.delete_collection(collection_name)
+    collection = db.create_collection(collection_name)
+    example_documents[0].embedding = [random.random() for _ in range(10)]
+    # Ensuring different size embeddings are not inserted
+    with pytest.raises(AssertionError, match=r"Embedding Vectors are not all equal in length. Sizes:"):
+        db.insert_docs(example_documents, collection_name=collection_name, upsert=False)


 def test_update_docs(db_with_indexed_clxn, example_documents):
     db, collection = db_with_indexed_clxn
     # Use update_docs to insert new documents
     db.update_docs(example_documents, collection.name, upsert=True)
     # Test that no changes were made to example_documents
     assert set(example_documents[0].keys()) == {"id", "content", "metadata"}
-    assert collection.count_documents({}) == len(example_documents)
+    assert collection.count_documents({}) == len([doc for doc in example_documents if doc.get("id") is not None])
     found = list(collection.find({}))
     # Check that documents have correct fields, including "_id" and "embedding" but not "id"
     assert all([set(doc.keys()) == {"_id", "content", "metadata", "embedding"} for doc in found])
     assert all([isinstance(doc["embedding"][0], float) for doc in found])
     assert all([len(doc["embedding"]) == db.dimensions for doc in found])
     # Check ids
-    assert {doc["_id"] for doc in found} == {1, "1", 2, "2"}
+    assert {doc["_id"] for doc in found} == {1, "1", 2, "2", "123", "321"}

     # Update an *existing* Document
     updated_doc = Document(id=1, content="Cats are tough.", metadata={"a": 10})
@@ -254,7 +270,8 @@ def test_delete_docs(db_with_indexed_clxn, example_documents):
     # Delete the 1s
     db.delete_docs(ids=[1, "1"], collection_name=clxn.name)
     # Confirm just the 2s remain
-    assert {2, "2"} == {doc["_id"] for doc in clxn.find({})}
+    result_set = {doc["_id"] for doc in clxn.find({})}
+    assert 2 in result_set and "2" in result_set


def test_get_docs_by_ids(db_with_indexed_clxn, example_documents):
@@ -359,8 +376,8 @@ def results_ready():

     assert len(results) == len(queries)
     assert all([len(res) == n_results for res in results])
-    assert {doc[0]["id"] for doc in results[0]} == {1, 2}
-    assert {doc[0]["id"] for doc in results[1]} == {"1", "2"}
+    assert {1, 2} <= {doc[0]["id"] for doc in results[0]}
+    assert {"1", "2"} <= {doc[0]["id"] for doc in results[1]}


def test_retrieve_docs_with_threshold(db_with_indexed_clxn, example_documents):
@@ -397,6 +414,6 @@ def test_wait_until_document_ready(collection_name, example_documents):
             wait_until_document_ready=TIMEOUT,
         )
         vectorstore.insert_docs(example_documents)
-        assert vectorstore.retrieve_docs(queries=["Cats"], n_results=4)
+        assert vectorstore.retrieve_docs(queries=["Cats"], n_results=8)
     finally:
         _empty_collections_and_delete_indexes(database, [collection_name])
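The str() comparison added in _wait_for_document, and the looser id assertions in these tests, follow from how bson.ObjectId compares: an ObjectId is never equal to its own hex string, so auto-assigned ids must be normalized before comparison. A self-contained illustration:

from bson import ObjectId

oid = ObjectId()
assert oid != str(oid)            # an ObjectId never equals its hex-string form
assert ObjectId(str(oid)) == oid  # but the hex string round-trips to an equal ObjectId
assert str(oid) == str(oid)       # hence comparing str() of both sides is the robust check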