Skip to content

Commit

Permalink
feat: Add document ID and chunk IDs to segments, fixing their order.
Browse files Browse the repository at this point in the history
  • Loading branch information
undo76 committed Nov 25, 2024
1 parent 003967b commit 87dc53a
Show file tree
Hide file tree
Showing 4 changed files with 249 additions and 72 deletions.
97 changes: 96 additions & 1 deletion src/raglite/_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

import datetime
import json
from dataclasses import dataclass
from functools import lru_cache
from hashlib import sha256
from pathlib import Path
from typing import Any
from typing import Any, Callable
from xml.sax.saxutils import escape

import numpy as np
from markdown_it import MarkdownIt
Expand Down Expand Up @@ -332,3 +334,96 @@ def create_database_engine(config: RAGLiteConfig | None = None) -> Engine:
)
session.commit()
return engine


@dataclass
class ContextSegment:
    """A contiguous segment of context from a single document.

    Holds a document's id together with the segment's chunks and the
    relevance score of each chunk.

    Attributes:
        document_id (str): The unique identifier for the document.
        chunks (list[Chunk]): List of chunks for this segment.
        chunk_scores (list[float]): List of scores for each chunk.

    Raises:
        ValueError: If document_id is empty or if chunks is empty.
    """

    document_id: str
    chunks: list[Chunk]
    chunk_scores: list[float]

    def __post_init__(self) -> None:
        """Validate the segment data after initialization."""
        if not isinstance(self.document_id, str) or not self.document_id.strip():
            raise ValueError("document_id must be a non-empty string")
        if not isinstance(self.chunks, list):
            raise ValueError("chunks must be a list")
        if not self.chunks:
            raise ValueError("chunks cannot be empty")
        if not all(isinstance(chunk, Chunk) for chunk in self.chunks):
            raise ValueError("all elements in chunks must be Chunk instances")

    def to_xml(self, indent: int = 4) -> str:
        """Convert the segment to an XML string representation.

        Args:
            indent (int): Number of spaces to use for indentation.
                NOTE(review): currently unused; kept for interface stability.

        Returns:
            str: XML representation of the segment.
        """
        chunks_content = "\n".join(str(chunk) for chunk in self.chunks)
        chunk_ids = ",".join(self.chunk_ids)
        # escape() does not escape double quotes by default, which would break
        # the quoted XML attributes below — map '"' explicitly for attributes.
        attr_entities = {'"': "&quot;"}
        return (
            f'<document id="{escape(self.document_id, attr_entities)}"'
            f' chunk_ids="{escape(chunk_ids, attr_entities)}">'
            f"\n{escape(chunks_content)}\n</document>"
        )

    def score(self, scoring_function: Callable[[list[float]], float] = sum) -> float:
        """Return an aggregated score of the segment, given a scoring function.

        Args:
            scoring_function: Reduces the list of chunk scores to a single
                float. Defaults to sum.

        Returns:
            float: The aggregated score of this segment's chunks.
        """
        return scoring_function(self.chunk_scores)

    @property
    def chunk_ids(self) -> list[str]:
        """Return a list of chunk IDs."""
        return [chunk.id for chunk in self.chunks]

    def __str__(self) -> str:
        """Return a string representation reconstructing the document with headings.

        Shows each unique header exactly once, when it first appears.
        For example:
        - First chunk with "# A ## B" shows both headers
        - Next chunk with "# A ## B" shows no headers as they're the same
        - Next chunk with "# A ## C" only shows "## C" as it's the only new header

        Returns:
            str: A string containing content with each heading shown once.
        """
        if not self.chunks:  # Defensive: __post_init__ already forbids this.
            return ""
        result: list[str] = []
        seen_headers: set[str] = set()  # Headers already emitted.
        for chunk in self.chunks:
            # All headers present in this chunk, one per non-blank line.
            headers = [h.strip() for h in chunk.headings.split("\n") if h.strip()]
            # Emit only headers not seen in any earlier chunk.
            new_headers = [h for h in headers if h not in seen_headers]
            if new_headers:
                result.extend(new_headers)
                result.append("")  # Blank line separating headers from body.
                seen_headers.update(new_headers)
            # Append the chunk body if it's not empty.
            if chunk.body.strip():
                result.append(chunk.body.strip())
        return "\n".join(result).strip()
97 changes: 71 additions & 26 deletions src/raglite/_rag.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Retrieval-augmented generation."""

from collections.abc import AsyncIterator, Iterator
from typing import Literal

from litellm import acompletion, completion

from raglite._config import RAGLiteConfig
from raglite._database import Chunk
from raglite._database import Chunk, ContextSegment
from raglite._litellm import get_context_size
from raglite._search import hybrid_search, rerank_chunks, retrieve_segments
from raglite._typing import SearchMethod
Expand Down Expand Up @@ -46,15 +47,15 @@ def _max_contexts(
return max_contexts


def _contexts( # noqa: PLR0913
def context_segments( # noqa: PLR0913
prompt: str,
*,
max_contexts: int = 5,
context_neighbors: tuple[int, ...] | None = (-1, 1),
search: SearchMethod | list[str] | list[Chunk] = hybrid_search,
messages: list[dict[str, str]] | None = None,
config: RAGLiteConfig | None = None,
) -> list[str]:
) -> list[ContextSegment]:
"""Retrieve contexts for RAG."""
# Determine the maximum number of contexts.
max_contexts = _max_contexts(
Expand All @@ -71,14 +72,18 @@ def _contexts( # noqa: PLR0913
# If the user has configured a reranker, we retrieve extra contexts to rerank.
extra_contexts = 3 * max_contexts if config.reranker else 0
# Retrieve relevant contexts.
chunk_ids, _ = search(prompt, num_results=max_contexts + extra_contexts, config=config)
chunk_ids, _ = search(
prompt, num_results=max_contexts + extra_contexts, config=config
)
# Rerank the relevant contexts.
chunks = rerank_chunks(query=prompt, chunk_ids=chunk_ids, config=config)
else:
# The user has passed a list of chunk_ids or chunks directly.
chunks = search
# Extend the top contexts with their neighbors and group chunks into contiguous segments.
segments = retrieve_segments(chunks[:max_contexts], neighbors=context_neighbors, config=config)
segments = retrieve_segments(
chunks[:max_contexts], neighbors=context_neighbors, config=config
)
return segments


Expand All @@ -95,29 +100,26 @@ def rag( # noqa: PLR0913
"""Retrieval-augmented generation."""
# Get the contexts for RAG as contiguous segments of chunks.
config = config or RAGLiteConfig()
segments = _contexts(
segments = context_segments(
prompt,
max_contexts=max_contexts,
context_neighbors=context_neighbors,
search=search,
config=config,
)
system_prompt = f"{system_prompt}\n\n" + "\n\n".join(
f'<context index="{i}">\n{segment.strip()}\n</context>'
for i, segment in enumerate(segments)
)
# Stream the LLM response.
stream = completion(
model=config.llm,
messages=[
*(messages or []),
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt},
],
messages=_compose_messages(
prompt=prompt,
system_prompt=system_prompt,
messages=messages,
segments=segments,
),
stream=True,
)
for output in stream:
token: str = output["choices"][0]["delta"].get("content") or ""
token: str = output["choices"][0]["delta"].get("content") or "" # type: ignore
yield token


Expand All @@ -134,27 +136,70 @@ async def async_rag( # noqa: PLR0913
"""Retrieval-augmented generation."""
# Get the contexts for RAG as contiguous segments of chunks.
config = config or RAGLiteConfig()
segments = _contexts(
segments = context_segments(
prompt,
max_contexts=max_contexts,
context_neighbors=context_neighbors,
search=search,
config=config,
)
system_prompt = f"{system_prompt}\n\n" + "\n\n".join(
f'<context index="{i}">\n{segment.strip()}\n</context>'
for i, segment in enumerate(segments)

messages = _compose_messages(
prompt=prompt,
system_prompt=system_prompt,
messages=messages,
segments=segments,
)
print(messages)
# Stream the LLM response.
async_stream = await acompletion(
model=config.llm,
messages=[
*(messages or []),
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt},
],
messages=messages,
stream=True,
)
async for output in async_stream:
token: str = output["choices"][0]["delta"].get("content") or ""
token: str = output["choices"][0]["delta"].get("content") or "" # type: ignore
yield token


def _compose_messages(
prompt: str,
system_prompt: str,
messages: list[dict[str, str]],
segments: list[ContextSegment],
context_placement: Literal[
"system_prompt", "user_prompt", "separate_system_prompt"
] = "user_prompt",
) -> list[dict[str, str]]:
"""Compose the messages for the LLM, placing the context in the desired position."""

# Using the format recommended by Anthropic for documents in RAG
# (https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips#essential-tips-for-long-context-prompts
context_content = f"\n\n<documents>\n{'\n\n'.join(seg.to_xml() for seg in segments)}\n</documents>"
if not segments:
return [
{"role": "system", "content": system_prompt},
*(messages or []),
{"role": "user", "content": prompt},
]
if context_placement == "system_prompt":
return [
{"role": "system", "content": system_prompt + "\n\n" + context_content},
*(messages or []),
{"role": "user", "content": prompt},
]
if context_placement == "user_prompt":
return [
{"role": "system", "content": system_prompt},
*(messages or []),
{"role": "user", "content": prompt + "\n\n" + context_content},
]
if context_placement == "separate_system_prompt":
return [
{"role": "system", "content": system_prompt},
*(messages or []),
{"role": "system", "content": context_content},
{"role": "user", "content": prompt},
]
else:
raise ValueError("Invalid context placement.")
Loading

0 comments on commit 87dc53a

Please sign in to comment.