diff --git a/.gitignore b/.gitignore
index fd4bd83..d42d3d6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,5 @@ book.txt
 lightrag-dev/
 .idea/
 dist/
+/.lightRagEnv
+volumes/
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100644
index 0000000..2a929d8
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,85 @@
+# Milvus standalone needs etcd (metadata store) and MinIO (object storage);
+# Neo4j runs with the Graph Data Science plugin used for node2vec embeddings.
+version: '3.5'
+
+services:
+  etcd:
+    container_name: milvus-etcd
+    image: quay.io/coreos/etcd:v3.5.5
+    environment:
+      - ETCD_AUTO_COMPACTION_MODE=revision
+      - ETCD_AUTO_COMPACTION_RETENTION=1000
+      - ETCD_QUOTA_BACKEND_BYTES=4294967296
+      - ETCD_SNAPSHOT_COUNT=50000
+    volumes:
+      - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/etcd:/etcd
+    command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
+    healthcheck:
+      test: ["CMD", "etcdctl", "endpoint", "health"]
+      interval: 30s
+      timeout: 20s
+      retries: 3
+
+  minio:
+    container_name: milvus-minio
+    image: minio/minio:RELEASE.2023-03-20T20-16-18Z
+    environment:
+      MINIO_ACCESS_KEY: minioadmin
+      MINIO_SECRET_KEY: minioadmin
+    ports:
+      - "9001:9001"
+      - "9000:9000"
+    volumes:
+      - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/minio:/minio_data
+    command: minio server /minio_data --console-address ":9001"
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
+      interval: 30s
+      timeout: 20s
+      retries: 3
+
+  standalone:
+    container_name: milvus-standalone
+    image: milvusdb/milvus:v2.4.13-hotfix
+    command: ["milvus", "run", "standalone"]
+    security_opt:
+      - seccomp:unconfined
+    environment:
+      ETCD_ENDPOINTS: etcd:2379
+      MINIO_ADDRESS: minio:9000
+    volumes:
+      - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/milvus:/var/lib/milvus
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"]
+      interval: 30s
+      start_period: 90s
+      timeout: 20s
+      retries: 3
+    ports:
+      - "19530:19530"
+      - "9091:9091"
+    depends_on:
+      - "etcd"
+      - "minio"
+
+  neo4j:
+    image: neo4j:latest
+    volumes:
+      - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/neo4j/logs:/logs
+      - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/neo4j/config:/config
+      - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/neo4j/data:/data
+      - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/neo4j/plugins:/plugins
+    environment:
+      - NEO4J_AUTH=neo4j/admin12345
+      - NEO4JLABS_PLUGINS=["graph-data-science"]
+      - NEO4J_ACCEPT_LICENSE_AGREEMENT=yes
+      - NEO4J_dbms_security_procedures_unrestricted=gds.*
+      - NEO4J_dbms_security_procedures_allowlist=gds.*
+    ports:
+      - "7474:7474"
+      - "7687:7687"
+    restart: always
+
+networks:
+  default:
+    name: milvus
diff --git a/examples/lightrag_openai_demo_storage.py b/examples/lightrag_openai_demo_storage.py
new file mode 100644
index 0000000..c7504c1
--- /dev/null
+++ b/examples/lightrag_openai_demo_storage.py
@@ -0,0 +1,40 @@
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import gpt_4o_mini_complete
+
+# LightRAG now creates the working directory itself if it does not exist.
+rag = LightRAG(
+    working_dir="./dickens",
+    llm_model_func=gpt_4o_mini_complete,
+    # Credentials must match the neo4j service in docker-compose.yml.
+    neo4j_config={
+        "uri": "neo4j://localhost:7687",
+        "username": "neo4j",
+        "password": "admin12345",
+    },
+)
+
+with open("./book.txt", "r", encoding="utf-8") as f:
+    rag.insert(f.read())
+
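+# Per the LightRAG docs, the four modes differ in how retrieval uses the
+# graph: naive is plain chunk retrieval, local starts from specific entities,
+# global starts from higher-level relationships, and hybrid combines the two.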
+# Perform naive search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)
+
+# Perform local search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)
+
+# Perform global search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)
+
+# Perform hybrid search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 3004f5e..40c1337 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -4,7 +4,7 @@
 from datetime import datetime
 from functools import partial
 from typing import Type, cast
-
+from .storage import Neo4jKVStorage, Neo4jGraphStorage
 from .llm import (
     gpt_4o_mini_complete,
     openai_embedding,
@@ -44,9 +44,8 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     try:
         loop = asyncio.get_running_loop()
     except RuntimeError:
-        logger.info("Creating a new event loop in a sub-thread.")
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
+        logger.info("Creating a new event loop in the main thread.")
+        loop = asyncio.get_event_loop()
 
     return loop
 
@@ -97,23 +96,34 @@ class LightRAG:
     vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
     graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
     enable_llm_cache: bool = True
-
+
+    # Neo4j configuration
+    neo4j_config: dict = field(default_factory=dict)
+    milvus_config: dict = field(default_factory=dict)
+
     # extension
     addon_params: dict = field(default_factory=dict)
     convert_response_to_json_func: callable = convert_response_to_json
 
     def __post_init__(self):
+        if not os.path.exists(self.working_dir):
+            os.makedirs(self.working_dir)
+            logger.info(f"Creating working directory {self.working_dir}")
         log_file = os.path.join(self.working_dir, "lightrag.log")
         set_logger(log_file)
         logger.info(f"Logger initialized for working directory: {self.working_dir}")
 
         _print_config = ",\n  ".join([f"{k} = {v}" for k, v in asdict(self).items()])
         logger.debug(f"LightRAG init with param:\n  {_print_config}\n")
-
-        if not os.path.exists(self.working_dir):
-            logger.info(f"Creating working directory {self.working_dir}")
-            os.makedirs(self.working_dir)
-
+
+        if self.neo4j_config.get("uri") and self.neo4j_config.get("username") and self.neo4j_config.get("password"):
+            self.key_string_value_json_storage_cls = Neo4jKVStorage
+            self.graph_storage_cls = Neo4jGraphStorage
+            logger.info("Using Neo4j for KV and graph storage")
+        else:
+            self.key_string_value_json_storage_cls = JsonKVStorage
+            logger.info("Using JsonKVStorage")
+
         self.full_docs = self.key_string_value_json_storage_cls(
             namespace="full_docs", global_config=asdict(self)
         )
diff --git a/lightrag/storage.py b/lightrag/storage.py
index 1f22fc5..cbb301e 100644
--- a/lightrag/storage.py
+++ b/lightrag/storage.py
@@ -1,12 +1,12 @@
 import asyncio
 import html
 import os
-from dataclasses import dataclass
-from typing import Any, Union, cast
+from dataclasses import dataclass, field
+from typing import Any, Union, cast, Tuple, List, Dict
 import networkx as nx
 import numpy as np
 from nano_vectordb import NanoVectorDB
-
+from neo4j import AsyncGraphDatabase
 from .utils import load_json, logger, write_json
 from .base import (
     BaseGraphStorage,
@@ -14,7 +14,6 @@
     BaseVectorStorage,
 )
 
-
 @dataclass
 class JsonKVStorage(BaseKVStorage):
     def __post_init__(self):
@@ -243,3 +242,241 @@ async def _node2vec_embed(self):
         nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
         return embeddings, nodes_ids
 
+@dataclass
+class Neo4jKVStorage(BaseKVStorage):
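+    # KV pairs are stored as nodes labelled with the namespace, so the
+    # namespace must be a valid Neo4j label (e.g. "full_docs").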
+    def __post_init__(self):
+        logger.debug(f"Global config: {self.global_config}")
+        neo4j_config: dict = self.global_config.get("neo4j_config", {})
+        self.uri = neo4j_config.get("uri")
+        self.username = neo4j_config.get("username")
+        self.password = neo4j_config.get("password")
+
+        if not self.namespace:
+            self.namespace = neo4j_config.get("namespace", "default_namespace")
+        self.driver = AsyncGraphDatabase.driver(self.uri, auth=(self.username, self.password))
+
+    async def close(self):
+        await self.driver.close()
+
+    async def all_keys(self) -> list[str]:
+        query = f"MATCH (n:{self.namespace}) RETURN n.id AS id"
+        async with self.driver.session() as session:
+            result = await session.run(query)
+            keys = []
+            async for record in result:
+                keys.append(record["id"])
+            return keys
+
+    async def get_by_id(self, id: str):
+        query = f"MATCH (n:{self.namespace} {{id: $id}}) RETURN n"
+        async with self.driver.session() as session:
+            result = await session.run(query, id=id)
+            record = await result.single()
+            return dict(record["n"]) if record else None
+
+    async def get_by_ids(self, ids: list[str], fields: list[str] = None):
+        # Project only the requested fields when given; otherwise return all node properties.
+        if fields:
+            field_str = ", ".join(f"n.{field} AS {field}" for field in fields)
+        else:
+            field_str = "properties(n) AS props"
+        query = f"MATCH (n:{self.namespace}) WHERE n.id IN $ids RETURN {field_str}"
+        async with self.driver.session() as session:
+            result = await session.run(query, ids=ids)
+            records = []
+            async for record in result:
+                records.append(dict(record) if fields else record["props"])
+            return records
+
+    async def filter_keys(self, data: list[str]) -> set[str]:
+        query = f"MATCH (n:{self.namespace}) WHERE n.id IN $data RETURN n.id AS id"
+        async with self.driver.session() as session:
+            result = await session.run(query, data=data)
+            existing_keys = set()
+            async for record in result:
+                existing_keys.add(record["id"])
+            return set(data) - existing_keys
+
+    async def upsert(self, data: dict[str, dict]):
+        query = f"""
+        UNWIND $data AS row
+        MERGE (n:{self.namespace} {{id: row.id}})
+        SET n += row.properties
+        RETURN n.id AS id
+        """
+        async with self.driver.session() as session:
+            await session.run(query, data=[{"id": k, "properties": v} for k, v in data.items()])
+        return data
+
+    async def drop(self):
+        query = f"MATCH (n:{self.namespace}) DETACH DELETE n"
+        async with self.driver.session() as session:
+            await session.run(query)
+
+@dataclass
+class Neo4jGraphStorage(BaseGraphStorage):
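+    # Graph nodes are labelled with the namespace and every relationship uses
+    # the single type RELATION; per-edge semantics live in its properties.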
+    def __post_init__(self):
+        neo4j_config: dict = self.global_config.get("neo4j_config", {})
+        self.uri = neo4j_config.get("uri")
+        self.username = neo4j_config.get("username")
+        self.password = neo4j_config.get("password")
+
+        if not self.namespace:
+            self.namespace = neo4j_config.get("namespace", "default_namespace")
+
+        self.driver = AsyncGraphDatabase.driver(self.uri, auth=(self.username, self.password))
+        logger.info(f"Connected to Neo4j at {self.uri} with namespace '{self.namespace}'")
+
+        # Initialize node embedding algorithms
+        self._node_embed_algorithms = {
+            "node2vec": self._node2vec_embed,
+        }
+
+    async def close(self):
+        await self.driver.close()
+
+    async def index_done_callback(self):
+        # Neo4j persists writes itself, so unlike the graphml-backed storage
+        # there is nothing to flush here.
+        logger.info("Indexing done.")
+
+    async def has_node(self, node_id: str) -> bool:
+        query = f"MATCH (n:{self.namespace} {{id: $id}}) RETURN n LIMIT 1"
+        async with self.driver.session() as session:
+            result = await session.run(query, id=node_id)
+            record = await result.single()
+            return record is not None
+
+    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
+        query = f"""
+        MATCH (n:{self.namespace} {{id: $source_id}})-[r]->(m:{self.namespace} {{id: $target_id}})
+        RETURN r LIMIT 1
+        """
+        async with self.driver.session() as session:
+            result = await session.run(query, source_id=source_node_id, target_id=target_node_id)
+            record = await result.single()
+            return record is not None
+
+    async def get_node(self, node_id: str) -> Union[dict, None]:
+        query = f"MATCH (n:{self.namespace} {{id: $id}}) RETURN properties(n) AS props"
+        async with self.driver.session() as session:
+            result = await session.run(query, id=node_id)
+            record = await result.single()
+            return record["props"] if record else None
+
+    async def node_degree(self, node_id: str) -> int:
+        query = f"MATCH (n:{self.namespace} {{id: $id}})-[r]-() RETURN count(r) AS degree"
+        async with self.driver.session() as session:
+            result = await session.run(query, id=node_id)
+            record = await result.single()
+            return record["degree"] if record else 0
+
+    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
+        src_degree = await self.node_degree(src_id)
+        tgt_degree = await self.node_degree(tgt_id)
+        return src_degree + tgt_degree
+
+    async def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
+        query = f"""
+        MATCH (n:{self.namespace} {{id: $source_id}})-[r]->(m:{self.namespace} {{id: $target_id}})
+        RETURN properties(r) AS props
+        """
+        async with self.driver.session() as session:
+            result = await session.run(query, source_id=source_node_id, target_id=target_node_id)
+            record = await result.single()
+            return record["props"] if record else None
+
+    async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
+        query = f"""
+        MATCH (n:{self.namespace} {{id: $source_id}})-[r]->(m:{self.namespace})
+        RETURN n.id AS source_id, m.id AS target_id
+        """
+        async with self.driver.session() as session:
+            result = await session.run(query, source_id=source_node_id)
+            edges = []
+            async for record in result:
+                edges.append((record["source_id"], record["target_id"]))
+            return edges
+
+    async def upsert_node(self, node_id: str, node_data: dict[str, Any]):
+        query = f"""
+        MERGE (n:{self.namespace} {{id: $id}})
+        SET n += $properties
+        """
+        async with self.driver.session() as session:
+            await session.run(query, id=node_id, properties=node_data)
+        logger.info(f"Upserted node with ID '{node_id}' and label '{self.namespace}'")
+
+    async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, Any]):
+        query = f"""
+        MERGE (source:{self.namespace} {{id: $source_id}})
+        MERGE (target:{self.namespace} {{id: $target_id}})
+        MERGE (source)-[r:RELATION]->(target)
+        SET r += $properties
+        """
+        async with self.driver.session() as session:
+            await session.run(query, source_id=source_node_id, target_id=target_node_id, properties=edge_data)
+        logger.info(f"Upserted edge from '{source_node_id}' to '{target_node_id}'")
+
+    async def embed_nodes(self, algorithm: str) -> Tuple[np.ndarray, List[str]]:
+        if algorithm not in self._node_embed_algorithms:
+            raise ValueError(f"Node embedding algorithm {algorithm} not supported")
+        return await self._node_embed_algorithms[algorithm]()
+
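+    # node2vec via GDS takes three steps: drop any stale in-memory projection,
+    # project this namespace's nodes with undirected RELATION edges, then
+    # stream the embeddings back with gds.node2vec.stream.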
+    async def _node2vec_embed(self):
+        graph_name = f"{self.namespace}_graph"
+        async with self.driver.session() as session:
+            # Drop any existing in-memory graph (false = do not fail if missing)
+            await session.run("CALL gds.graph.drop($graph_name, false) YIELD graphName", graph_name=graph_name)
+
+            # Create an in-memory graph projection
+            await session.run("""
+                CALL gds.graph.project(
+                    $graph_name,
+                    $node_label,
+                    {
+                        RELATION: {
+                            orientation: 'UNDIRECTED'
+                        }
+                    }
+                ) YIELD graphName, nodeCount, relationshipCount
+            """, graph_name=graph_name, node_label=self.namespace)
+
+            # Run node2vec and stream the embeddings back
+            result = await session.run("""
+                CALL gds.node2vec.stream($graph_name, {
+                    embeddingDimension: $dimensions,
+                    walkLength: $walk_length,
+                    walksPerNode: $num_walks,
+                    windowSize: $window_size,
+                    iterations: $iterations
+                })
+                YIELD nodeId, embedding
+                RETURN gds.util.asNode(nodeId).id AS id, embedding
+            """,
+                graph_name=graph_name,
+                dimensions=self.global_config["node2vec_params"]["dimensions"],
+                walk_length=self.global_config["node2vec_params"]["walk_length"],
+                num_walks=self.global_config["node2vec_params"]["num_walks"],
+                window_size=self.global_config["node2vec_params"]["window_size"],
+                iterations=self.global_config["node2vec_params"]["iterations"],
+            )
+            embeddings = []
+            node_ids = []
+            async for record in result:
+                node_ids.append(record["id"])
+                embeddings.append(record["embedding"])
+            return np.array(embeddings), node_ids
+
+    async def drop(self):
+        # Delete every node and relationship carrying the namespace label
+        query = f"MATCH (n:{self.namespace}) DETACH DELETE n"
+        async with self.driver.session() as session:
+            await session.run(query)
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 98f32b0..e5439cf 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,3 +13,6 @@ tiktoken
 torch
 transformers
 xxhash
+aiohttp
+neo4j
+pymilvus