Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions logfire/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
instrument_redis = DEFAULT_LOGFIRE_INSTANCE.instrument_redis
instrument_pymongo = DEFAULT_LOGFIRE_INSTANCE.instrument_pymongo
instrument_mysql = DEFAULT_LOGFIRE_INSTANCE.instrument_mysql
instrument_surrealdb = DEFAULT_LOGFIRE_INSTANCE.instrument_surrealdb
instrument_system_metrics = DEFAULT_LOGFIRE_INSTANCE.instrument_system_metrics
instrument_mcp = DEFAULT_LOGFIRE_INSTANCE.instrument_mcp
suppress_scopes = DEFAULT_LOGFIRE_INSTANCE.suppress_scopes
Expand Down Expand Up @@ -148,6 +149,7 @@ def loguru_handler() -> Any:
'instrument_redis',
'instrument_pymongo',
'instrument_mysql',
'instrument_surrealdb',
'instrument_system_metrics',
'instrument_mcp',
'AutoTraceModule',
Expand Down
111 changes: 111 additions & 0 deletions logfire/_internal/integrations/surrealdb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from __future__ import annotations

import functools
import inspect
import uuid
from typing import Any, Union, get_args, get_origin

from surrealdb.connections.async_template import AsyncTemplate
from surrealdb.connections.sync_template import SyncTemplate
from surrealdb.data.types.record_id import RecordIdType
from surrealdb.data.types.table import Table
from surrealdb.types import Value

from logfire._internal.main import Logfire


def is_complex_type(tp: type | type[Value]) -> bool:
origin = get_origin(tp)
if origin in {list, dict, set, tuple}:
return True
if tp in {Value}:
return True
if tp in (str, bool, int, float, type(None), uuid.UUID, Table, RecordIdType):
return False
if origin is Union: # pragma: no branch
args = get_args(tp)
return any(is_complex_type(arg) for arg in args)
return True # pragma: no cover


def get_all_subclasses(cls: type) -> set[type]:
subclasses: set[type] = set()
for subclass in cls.__subclasses__():
subclasses.add(subclass)
subclasses.update(get_all_subclasses(subclass))
return subclasses


def get_all_surrealdb_classes() -> set[type]:
return get_all_subclasses(SyncTemplate) | get_all_subclasses(AsyncTemplate)


def instrument_surrealdb(obj: Any, logfire_instance: Logfire):
logfire_instance = logfire_instance.with_settings(custom_scope_suffix='surrealdb')
if obj is None:
for cls in get_all_surrealdb_classes():
instrument_surrealdb(cls, logfire_instance)
return

for name, template_method in inspect.getmembers(AsyncTemplate):
if not (
inspect.isfunction(template_method)
and not name.startswith('_')
and name != 'connect' # weird case that differs between classes
and AsyncTemplate.__dict__.get(name) == template_method
):
continue
patch_method(obj, name, logfire_instance)


def patch_method(obj: Any, method_name: str, logfire_instance: Logfire):
original_method = getattr(obj, method_name, None)
if not original_method or hasattr(original_method, '_logfire_template'):
return # already patched

sig = inspect.signature(original_method)
template_params: list[str] = []
scrubber = logfire_instance.config.scrubber
for param_name, param in sig.parameters.items():
if param_name == 'self':
continue
assert param.annotation is not inspect.Parameter.empty
_, scrubbed = scrubber.scrub_value(path=(param_name,), value=None)
if not is_complex_type(param.annotation) and not scrubbed:
template_params.append(param_name)
template = span_name = f'surrealdb {method_name}'
if len(template_params) == 1:
template += f' {{{template_params[0]}}}'
elif len(template_params) > 1:
template += ' ' + ', '.join(f'{p} = {{{p}}}' for p in template_params)

def get_params(*args: Any, **kwargs: Any) -> dict[str, Any]:
bound = sig.bind(*args, **kwargs)
params = bound.arguments
params.pop('self', None)
return params

if inspect.isgeneratorfunction(original_method) or inspect.isasyncgenfunction(original_method):

@functools.wraps(original_method)
def wrapped_method(*args: Any, **kwargs: Any) -> Any:
logfire_instance.info(template, **get_params(*args, **kwargs))
return original_method(*args, **kwargs)

elif inspect.iscoroutinefunction(original_method):

@functools.wraps(original_method)
async def wrapped_method(*args: Any, **kwargs: Any) -> Any: # pyright: ignore[reportRedeclaration]
with logfire_instance.span(template, **get_params(*args, **kwargs), _span_name=span_name):
return await original_method(*args, **kwargs)

else:

@functools.wraps(original_method)
def wrapped_method(*args: Any, **kwargs: Any) -> Any:
with logfire_instance.span(template, **get_params(*args, **kwargs), _span_name=span_name):
return original_method(*args, **kwargs)

wrapped_method._logfire_template = template # type: ignore

setattr(obj, method_name, wrapped_method)
6 changes: 6 additions & 0 deletions logfire/_internal/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,12 @@ def install_auto_tracing(
def _warn_if_not_initialized_for_instrumentation(self):
self.config.warn_if_not_initialized('Instrumentation will have no effect')

def instrument_surrealdb(self, obj: Any = None) -> None:
from .integrations.surrealdb import instrument_surrealdb

self._warn_if_not_initialized_for_instrumentation()
instrument_surrealdb(obj, self)

def instrument_mcp(self, *, propagate_otel_context: bool = True) -> None:
"""Instrument the [MCP Python SDK](https://github.com/modelcontextprotocol/python-sdk).

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ dev = [
"openinference-instrumentation-litellm >= 0",
"litellm != 1.80.9",
"pip >= 0",
"surrealdb >= 0",
]
docs = [
"black>=23.12.0",
Expand Down
Loading
Loading