Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added session ID as a contextVar #327

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 148 additions & 0 deletions src/ell/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
from sqlmodel import Session

from ell.configurator import config
from ell.ctxt import get_session_id
from typing import Dict, List, Optional, Set, Any
from ell.types import SerializedLMP, Invocation, InvocationContents


def write_lmp(serialized_lmp: SerializedLMP, uses: Dict[str, Any]) -> Optional[SerializedLMP]:
"""
Write a serialized LMP to the store.

:param serialized_lmp: The SerializedLMP object to write.
:param uses: A dictionary of LMPs that this LMP uses.
:return: The written LMP or None if it already exists.
"""
return config.store.write_lmp(serialized_lmp, uses)

def write_invocation(invocation: Invocation, consumes: Set[str]) -> Optional[Any]:
"""
Write an invocation to the store.

:param invocation: The Invocation object to write.
:param consumes: A set of invocation IDs that this invocation consumes.
:return: None
"""
return config.store.write_invocation(invocation, consumes)

def get_invocations_by_session_id(session_id: str = "") -> List[Invocation]:
"""
Retrieve invocations by session ID.

:param session_id: The session ID to filter by.
:return: A list of Invocation objects.
"""
session_id = session_id or get_session_id()
return config.store.get_invocations_by_sessionid(session_id)

def get_cached_invocations(lmp_id: str, state_cache_key: str) -> List[Invocation]:
"""
Retrieve cached invocations for a given LMP and state cache key.

:param lmp_id: The ID of the LMP.
:param state_cache_key: The state cache key.
:return: A list of Invocation objects.
"""
return config.store.get_cached_invocations(lmp_id, state_cache_key)

def get_cached_invocations_contents(lmp_id: str, state_cache_key: str) -> List[InvocationContents]:
"""
Retrieve contents of cached invocations for a given LMP and state cache key.

:param lmp_id: The ID of the LMP.
:param state_cache_key: The state cache key.
:return: A list of InvocationContents objects.
"""
return config.store.get_cached_invocations_contents(lmp_id, state_cache_key)

def get_lmp(lmp_id: str) -> Optional[SerializedLMP]:
"""
Retrieve an LMP by its ID.

:param lmp_id: The ID of the LMP to retrieve.
:return: A SerializedLMP object or None if not found.
"""
return config.store.get_lmp(lmp_id)

def get_versions_by_fqn(fqn: str) -> List[SerializedLMP]:
"""
Retrieve all versions of an LMP by its fully qualified name.

:param fqn: The fully qualified name of the LMP.
:return: A list of SerializedLMP objects.
"""
return config.store.get_versions_by_fqn(fqn)

def get_latest_lmps(skip: int = 0, limit: int = 10) -> List[Dict[str, Any]]:
"""
Retrieve the latest LMPs.

:param skip: Number of records to skip.
:param limit: Maximum number of records to return.
:return: A list of SerializedLMP objects.
"""
with Session(config.store.engine) as session:
return config.store.get_latest_lmps(session, skip, limit)

def get_invocations_by_lmp_name(lmp_name: str, skip: int = 0, limit: int = 10) -> List[Dict[str, Any]]:
"""
Retrieve invocations for a given LMP name, sorted by creation time in descending order.

:param lmp_name: The name of the LMP.
:param skip: Number of records to skip (for pagination).
:param limit: Maximum number of records to return.
:return: A list of Invocation objects sorted by creation time in descending order.
"""
lmp_filters = {"name": lmp_name}
filters = None # The sorting is handled by default in the get_invocations method

with Session(config.store.engine) as session:
return config.store.get_invocations(session, lmp_filters, skip=skip, limit=limit, filters=filters)

def get_lmps(skip: int = 0, limit: int = 10, **filters: Any) -> List[Dict[str, Any]]:
"""
Retrieve LMPs based on filters.

:param skip: Number of records to skip.
:param limit: Maximum number of records to return.
:param filters: Additional filters to apply.
:return: A list of SerializedLMP objects.
"""
with Session(config.store.engine) as session:
return config.store.get_lmps(session, skip, limit, **filters)

def get_invocations(lmp_filters: Dict[str, Any], skip: int = 0, limit: int = 10, filters: Optional[Dict[str, Any]] = None, hierarchical: bool = False) -> List[Dict[str, Any]]:
"""
Retrieve invocations based on filters.

:param lmp_filters: Filters to apply to the LMP.
:param skip: Number of records to skip.
:param limit: Maximum number of records to return.
:param filters: Additional filters to apply to the invocation.
:param hierarchical: Whether to return hierarchical results.
:return: A list of Invocation objects.
"""
with Session(config.store.engine) as session:
return config.store.get_invocations(session, lmp_filters, skip, limit, filters, hierarchical)

def get_traces() -> List[Dict[str, Any]]:
"""
Retrieve all traces.

:return: A list of trace dictionaries.
"""
with Session(config.store.engine) as session:
return config.store.get_traces(session)

def get_invocations_aggregate(lmp_filters: Optional[Dict[str, Any]] = None, filters: Optional[Dict[str, Any]] = None, days: int = 30) -> Dict[str, Any]:
"""
Retrieve aggregate data for invocations.

:param lmp_filters: Filters to apply to the LMP.
:param filters: Additional filters to apply to the invocation.
:param days: Number of days to include in the aggregation.
:return: A dictionary containing aggregate data and graph data.
"""
with Session(config.store.engine) as session:
return config.store.get_invocations_aggregate(session, lmp_filters, filters, days);
17 changes: 17 additions & 0 deletions src/ell/ctxt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from contextvars import ContextVar
from uuid import uuid4

session_id_context: ContextVar[str] = ContextVar('session_id', default='')

def get_session_id() -> str:
"""Get current session ID or create new one"""
session_id = session_id_context.get()
if not session_id:
session_id = str(uuid4())
session_id_context.set(session_id)

return session_id

def set_session_id(session_id: str) -> None:
"""Set session ID in current context"""
session_id_context.set(session_id)
7 changes: 7 additions & 0 deletions src/ell/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ def write_invocation(self, invocation: Invocation, consumes: Set[str]) -> Optio
"""
pass

@abstractmethod
def get_invocations_by_session_id(self, session_id: str) -> List[Invocation]:
"""
Get all invocations for a given session ID.
"""
pass

@abstractmethod
def get_cached_invocations(self, lmp_id :str, state_cache_key :str) -> List[Invocation]:
"""
Expand Down
4 changes: 4 additions & 0 deletions src/ell/stores/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ def write_invocation(self, invocation: Invocation, consumes: Set[str]) -> Option
session.commit()
return None

def get_invocations_by_session_id(self, session_id: str) -> List[Invocation]:
with Session(self.engine) as session:
return self.get_invocations(session, lmp_filters={}, filters={"session_id": session_id})

def get_cached_invocations(self, lmp_id :str, state_cache_key :str) -> List[Invocation]:
with Session(self.engine) as session:
return self.get_invocations(session, lmp_filters={"lmp_id": lmp_id}, filters={"state_cache_key": state_cache_key})
Expand Down
1 change: 1 addition & 0 deletions src/ell/types/studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ class InvocationBase(SQLModel):
completion_tokens: Optional[int] = Field(default=None)
state_cache_key: Optional[str] = Field(default=None)
created_at: datetime = UTCTimestampField(default=func.now(), nullable=False)
session_id: Optional[str] = Field(default=None)
used_by_id: Optional[str] = Field(default=None, foreign_key="invocation.id", index=True)
# global_vars and free_vars removed from here

Expand Down