diff --git a/pyproject.toml b/pyproject.toml
index 104b45a748..80207005ec 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -69,8 +69,10 @@ full = [
     # The qdrant-client >= 1.16.0 has conflicts with pymilvus, so we fix
     # the version to 1.15.1 here.
     "qdrant-client==1.15.1",
+    # MySQL vector store
     "mysql-connector-python",
-
+    # Lindorm vector store (OpenSearch-compatible)
+    "opensearch-py>=2.0.0",
 ]
 
 dev = [
diff --git a/src/agentscope/rag/__init__.py b/src/agentscope/rag/__init__.py
index 146b35aa61..70da06ad0a 100644
--- a/src/agentscope/rag/__init__.py
+++ b/src/agentscope/rag/__init__.py
@@ -17,6 +17,7 @@
     QdrantStore,
     MilvusLiteStore,
     AlibabaCloudMySQLStore,
+    LindormStore,
 )
 from ._knowledge_base import KnowledgeBase
 from ._simple_knowledge import SimpleKnowledge
@@ -34,6 +35,7 @@
     "QdrantStore",
     "MilvusLiteStore",
     "AlibabaCloudMySQLStore",
+    "LindormStore",
     "KnowledgeBase",
     "SimpleKnowledge",
 ]
diff --git a/src/agentscope/rag/_store/__init__.py b/src/agentscope/rag/_store/__init__.py
index c107eae5d9..217c813afb 100644
--- a/src/agentscope/rag/_store/__init__.py
+++ b/src/agentscope/rag/_store/__init__.py
@@ -7,10 +7,12 @@
 from ._qdrant_store import QdrantStore
 from ._milvuslite_store import MilvusLiteStore
 from ._alibabacloud_mysql_store import AlibabaCloudMySQLStore
+from ._lindorm_store import LindormStore
 
 __all__ = [
     "VDBStoreBase",
     "QdrantStore",
     "MilvusLiteStore",
     "AlibabaCloudMySQLStore",
+    "LindormStore",
 ]
diff --git a/src/agentscope/rag/_store/_lindorm_store.py b/src/agentscope/rag/_store/_lindorm_store.py
new file mode 100644
index 0000000000..c7bf1ca086
--- /dev/null
+++ b/src/agentscope/rag/_store/_lindorm_store.py
@@ -0,0 +1,317 @@
+# -*- coding: utf-8 -*-
+"""The Lindorm vector store implementation."""
+import json
+from typing import Any, Literal, TYPE_CHECKING
+
+from .._reader import Document
+from ._store_base import VDBStoreBase
+from .._document import DocMetadata
+from ..._utils._common import _map_text_to_uuid
+from ...types import Embedding
+
+if TYPE_CHECKING:
+    from opensearchpy import OpenSearch
+else:
+    OpenSearch = "opensearchpy.OpenSearch"
+
+
+class LindormStore(VDBStoreBase):
+    """The Lindorm vector store implementation, supporting the Aliyun
+    Lindorm vector engine with vector similarity search and custom routing.
+
+    .. note:: Lindorm exposes an OpenSearch-compatible API. Metadata is
+        stored in document fields, including ``doc_id``, ``chunk_id``,
+        and ``content``.
+
+    """
+
+    def __init__(
+        self,
+        hosts: list[str],
+        index_name: str,
+        dimensions: int,
+        http_auth: tuple[str, str],
+        distance_metric: Literal["l2", "cosine", "inner_product"] = "cosine",
+        enable_routing: bool = False,
+        use_ssl: bool = False,
+        verify_certs: bool = False,
+    ) -> None:
+        """Initialize the Lindorm vector store.
+
+        Args:
+            hosts (`list[str]`):
+                List of Lindorm hosts, e.g. ["http://lindorm-host:9200"].
+            index_name (`str`):
+                The name of the index to store embeddings.
+            dimensions (`int`):
+                The dimension of the embeddings.
+            http_auth (`tuple[str, str]`):
+                HTTP authentication (username, password) tuple. Required
+                for the Aliyun Lindorm cloud service.
+            distance_metric (`Literal["l2", "cosine", "inner_product"]`, \
+                defaults to "cosine"):
+                The distance metric for vector similarity.
+            enable_routing (`bool`, defaults to False):
+                Whether to enable custom routing for data isolation.
+            use_ssl (`bool`, defaults to False):
+                Whether to use SSL/TLS for the connection.
+            verify_certs (`bool`, defaults to False):
+                Whether to verify SSL certificates.
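+
+        Example:
+            A minimal construction sketch; the host, index name,
+            dimension, and credentials below are illustrative
+            placeholders, not real endpoints:
+
+            .. code-block:: python
+
+                store = LindormStore(
+                    hosts=["http://lindorm-host:9200"],
+                    index_name="demo_index",
+                    dimensions=1024,
+                    http_auth=("username", "password"),
+                )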
+ """ + + try: + from opensearchpy import OpenSearch + except ImportError as e: + raise ImportError( + "opensearch-py is not installed. Please install it with " + "`pip install opensearch-py`.", + ) from e + + self._client = OpenSearch( + hosts=hosts, + http_auth=http_auth, + use_ssl=use_ssl, + verify_certs=verify_certs, + ssl_show_warn=False, + ) + + self.index_name = index_name + self.dimensions = dimensions + self.distance_metric = distance_metric + self.enable_routing = enable_routing + + def _create_index_body(self) -> dict[str, Any]: + """Create the index body configuration for Lindorm. + + Returns: + `dict[str, Any]`: + The index configuration body including settings and mappings + for vector storage with Lindorm's lvector engine. + """ + knn_settings: dict[str, Any] = {} + if self.enable_routing: + knn_settings["knn_routing"] = True + + # Map distance metric to Lindorm's space_type + space_type_map = { + "l2": "l2", + "cosine": "cosinesimil", + "inner_product": "innerproduct", + } + lvector_space_type = space_type_map.get( + self.distance_metric, + self.distance_metric, + ) + + index_body = { + "settings": { + "index": { + "number_of_shards": 1, + "number_of_replicas": 0, + "knn": True, + **knn_settings, + }, + }, + "mappings": { + "_source": {"excludes": ["vector"]}, + "properties": { + "vector": { + "type": "knn_vector", + "dimension": self.dimensions, + "method": { + "engine": "lvector", + "name": "hnsw", + "space_type": lvector_space_type, + }, + }, + "doc_id": {"type": "keyword"}, + "chunk_id": {"type": "integer"}, + "content": {"type": "object", "enabled": False}, + "total_chunks": {"type": "integer"}, + }, + }, + } + + return index_body + + async def _validate_index(self) -> None: + """Validate the index exists, and create it if not. + + This method checks if the index exists in Lindorm. If the index + does not exist, it will be created with the appropriate vector + configuration. + + Raises: + Exception: If index creation fails due to connection issues + or invalid configuration. + """ + if not self._client.indices.exists(index=self.index_name): + index_body = self._create_index_body() + self._client.indices.create( + index=self.index_name, + body=index_body, + ) + + async def add(self, documents: list[Document], **kwargs: Any) -> None: + """Add embeddings to the Lindorm vector store. + + Args: + documents (`list[Document]`): + A list of documents to be added to the Lindorm store. + **kwargs (`Any`): + Additional arguments: + - routing (`str`): Custom routing key for data isolation. + """ + await self._validate_index() + + routing = kwargs.get("routing", None) + + for doc in documents: + unique_string = json.dumps( + { + "doc_id": doc.metadata.doc_id, + "chunk_id": doc.metadata.chunk_id, + }, + ensure_ascii=False, + ) + doc_id = _map_text_to_uuid(unique_string) + + body = { + "vector": doc.embedding, + "doc_id": doc.metadata.doc_id, + "chunk_id": doc.metadata.chunk_id, + "content": doc.metadata.content, + "total_chunks": doc.metadata.total_chunks, + } + + index_params: dict[str, Any] = { + "index": self.index_name, + "id": doc_id, + "body": body, + } + + if self.enable_routing and routing: + index_params["routing"] = routing + + self._client.index(**index_params) + + self._client.indices.refresh(index=self.index_name) + + async def search( + self, + query_embedding: Embedding, + limit: int, + score_threshold: float | None = None, + **kwargs: Any, + ) -> list[Document]: + """Search relevant documents from the Lindorm vector store. 
+
+        Args:
+            query_embedding (`Embedding`):
+                The embedding of the query text.
+            limit (`int`):
+                The number of relevant documents to retrieve.
+            score_threshold (`float | None`, optional):
+                The score threshold used to filter the results.
+            **kwargs (`Any`):
+                Additional arguments:
+
+                - routing (`str`): Custom routing key for targeted search.
+
+        Returns:
+            `list[Document]`:
+                The retrieved documents with their similarity scores.
+        """
+        routing = kwargs.get("routing", None)
+
+        query_body = {
+            "size": limit,
+            "query": {
+                "knn": {
+                    "vector": {
+                        "vector": query_embedding,
+                        "k": limit,
+                    },
+                },
+            },
+            "_source": True,
+        }
+
+        search_params: dict[str, Any] = {
+            "index": self.index_name,
+            "body": query_body,
+        }
+
+        if self.enable_routing and routing:
+            search_params["routing"] = routing
+
+        response = self._client.search(**search_params)
+
+        collected_res = []
+        for hit in response["hits"]["hits"]:
+            score = hit["_score"]
+
+            if score_threshold is not None and score < score_threshold:
+                continue
+
+            source = hit.get("_source", {})
+            if not source:
+                # Lindorm might return the fields directly, without _source
+                source = hit
+
+            doc_metadata = DocMetadata(
+                content=source.get("content", {}),
+                doc_id=source.get("doc_id", ""),
+                chunk_id=source.get("chunk_id", 0),
+                total_chunks=source.get("total_chunks", 0),
+            )
+
+            collected_res.append(
+                Document(
+                    # The index mapping excludes "vector" from _source,
+                    # so the embedding is typically None here.
+                    embedding=source.get("vector"),
+                    score=score,
+                    metadata=doc_metadata,
+                ),
+            )
+
+        return collected_res
+
+    async def delete(
+        self,
+        doc_ids: list[str],
+        routing: str | None = None,
+    ) -> None:
+        """Delete documents from the Lindorm vector store.
+
+        Args:
+            doc_ids (`list[str]`):
+                List of internal document UUIDs to delete. These values
+                must match the index document IDs generated during
+                :meth:`add` by combining ``doc.metadata.doc_id`` and
+                ``doc.metadata.chunk_id``; they are not the original
+                ``doc.metadata.doc_id`` values.
+            routing (`str | None`, optional):
+                Custom routing key for targeted deletion when routing is
+                enabled. Defaults to None.
+
+        Raises:
+            ValueError: If ``doc_ids`` is empty.
+        """
+        if not doc_ids:
+            raise ValueError("doc_ids must be provided for deletion.")
+
+        for doc_id in doc_ids:
+            delete_params: dict[str, Any] = {
+                "index": self.index_name,
+                "id": doc_id,
+            }
+
+            if self.enable_routing and routing:
+                delete_params["routing"] = routing
+
+            self._client.delete(**delete_params)
+
+        self._client.indices.refresh(index=self.index_name)
+
+    def get_client(self) -> OpenSearch:
+        """Get the underlying OpenSearch client for Lindorm.
+
+        Returns:
+            `OpenSearch`:
+                The underlying OpenSearch client.
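+
+        Example:
+            A small sketch of direct client access, e.g. to inspect the
+            cluster (``info()`` is a standard opensearch-py call):
+
+            .. code-block:: python
+
+                client = store.get_client()
+                print(client.info())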
+ """ + return self._client diff --git a/tests/rag_store_test.py b/tests/rag_store_test.py index 485dd23827..e93f701f42 100644 --- a/tests/rag_store_test.py +++ b/tests/rag_store_test.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """Test the RAG store implementations.""" +import json import os from unittest import IsolatedAsyncioTestCase from unittest.mock import MagicMock, patch @@ -11,7 +12,9 @@ DocMetadata, MilvusLiteStore, AlibabaCloudMySQLStore, + LindormStore, ) +from agentscope._utils._common import _map_text_to_uuid class RAGStoreTest(IsolatedAsyncioTestCase): @@ -261,6 +264,114 @@ async def test_alibabacloud_mysql_store(self) -> None: async def asyncTearDown(self) -> None: """Clean up after tests.""" - # Remove Milvus Lite database file if os.path.exists("./milvus_demo.db"): os.remove("./milvus_demo.db") + + @patch("opensearchpy.OpenSearch") + async def test_lindorm_store( + self, + mock_opensearch_class: MagicMock, + ) -> None: + """Test the LindormStore implementation.""" + mock_client = MagicMock() + mock_opensearch_class.return_value = mock_client + + mock_client.indices.exists.return_value = False + mock_client.indices.create.return_value = {"acknowledged": True} + mock_client.index.return_value = {"result": "created"} + mock_client.indices.refresh.return_value = { + "_shards": {"successful": 1}, + } + mock_client.search.return_value = { + "hits": { + "hits": [ + { + "_score": 0.95, + "_source": { + "vector": [0.1, 0.2, 0.3], + "doc_id": "doc1", + "chunk_id": 0, + "content": { + "type": "text", + "text": "This is a test document.", + }, + "total_chunks": 2, + }, + }, + ], + }, + } + + store = LindormStore( + hosts=["http://localhost:9200"], + index_name="test_index", + dimensions=3, + http_auth=("user", "pass"), + enable_routing=True, + ) + + await store.add( + [ + Document( + embedding=[0.1, 0.2, 0.3], + metadata=DocMetadata( + content=TextBlock( + type="text", + text="This is a test document.", + ), + doc_id="doc1", + chunk_id=0, + total_chunks=2, + ), + ), + ], + routing="user123", + ) + + mock_client.indices.create.assert_called_once() + self.assertTrue(mock_client.index.called) + + res = await store.search( + query_embedding=[0.15, 0.25, 0.35], + limit=3, + score_threshold=0.9, + routing="user123", + ) + + self.assertEqual(len(res), 1) + self.assertEqual(res[0].score, 0.95) + self.assertEqual( + res[0].metadata.content["text"], + "This is a test document.", + ) + + call_args = mock_client.search.call_args + query_body = call_args[1]["body"] + self.assertEqual(query_body["size"], 3) + self.assertIn("knn", query_body["query"]) + + # Test delete + mock_client.delete.return_value = {"result": "deleted"} + mock_client.indices.refresh.return_value = { + "_shards": {"successful": 1}, + } + + # Generate a doc_id similar to how add() does it + unique_string = json.dumps( + {"doc_id": "doc1", "chunk_id": 0}, + ensure_ascii=False, + ) + doc_id_to_delete = _map_text_to_uuid(unique_string) + + await store.delete(doc_ids=[doc_id_to_delete], routing="user123") + + self.assertTrue(mock_client.delete.called) + delete_call_args = mock_client.delete.call_args + self.assertEqual( + delete_call_args[1]["id"], + doc_id_to_delete, + ) + self.assertEqual( + delete_call_args[1]["routing"], + "user123", + )