Skip to content

Commit

Permalink
chore(llmobs): refactor to use span events
Browse files Browse the repository at this point in the history
The LLMObs service formerly depended on the TraceProcessor interface in the
tracer. This was problematic due to sharing a dependency with the public API.
As such, users could configure a trace filter (under the hood is a trace
processor) and overwrite the LLMObs TraceProcessor.

Instead, the tracer can emit span start and finish events which the LLMObs
service listens to and acts on, as proposed here.

One caveat: the LLMObs service no longer has a way to drop APM traces when
running in agentless mode (a mode only LLMObs supports). Instead, we encourage
users to explicitly disable APM, which has the added benefit of clarity, since
this behavior was previously implicit.
  • Loading branch information
Kyle-Verhoog committed Dec 20, 2024
1 parent b253aa3 commit 959432a
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 285 deletions.
5 changes: 4 additions & 1 deletion ddtrace/_trace/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from ddtrace.internal.atexit import register_on_exit_signal
from ddtrace.internal.constants import SAMPLING_DECISION_TRACE_TAG_KEY
from ddtrace.internal.constants import SPAN_API_DATADOG
from ddtrace.internal.core import dispatch
from ddtrace.internal.dogstatsd import get_dogstatsd_client
from ddtrace.internal.logger import get_logger
from ddtrace.internal.peer_service.processor import PeerServiceProcessor
Expand Down Expand Up @@ -866,7 +867,7 @@ def _start_span(
for p in chain(self._span_processors, SpanProcessor.__processors__, self._deferred_processors):
p.on_span_start(span)
self._hooks.emit(self.__class__.start_span, span)

dispatch("trace.span_start", (span,))
return span

start_span = _start_span
Expand All @@ -883,6 +884,8 @@ def _on_span_finish(self, span: Span) -> None:
for p in chain(self._span_processors, SpanProcessor.__processors__, self._deferred_processors):
p.on_span_finish(span)

dispatch("trace.span_finish", (span,))

if log.isEnabledFor(logging.DEBUG):
log.debug("finishing span %s (enabled:%s)", span._pprint(), self.enabled)

Expand Down
162 changes: 141 additions & 21 deletions ddtrace/llmobs/_llmobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import time
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import ddtrace
Expand All @@ -13,6 +15,7 @@
from ddtrace._trace.context import Context
from ddtrace.ext import SpanTypes
from ddtrace.internal import atexit
from ddtrace.internal import core
from ddtrace.internal import forksafe
from ddtrace.internal._rand import rand64bits
from ddtrace.internal.compat import ensure_text
Expand Down Expand Up @@ -45,11 +48,11 @@
from ddtrace.llmobs._constants import SPAN_START_WHILE_DISABLED_WARNING
from ddtrace.llmobs._constants import TAGS
from ddtrace.llmobs._evaluators.runner import EvaluatorRunner
from ddtrace.llmobs._trace_processor import LLMObsTraceProcessor
from ddtrace.llmobs._utils import AnnotationContext
from ddtrace.llmobs._utils import _get_llmobs_parent_id
from ddtrace.llmobs._utils import _get_ml_app
from ddtrace.llmobs._utils import _get_session_id
from ddtrace.llmobs._utils import _get_span_name
from ddtrace.llmobs._utils import _inject_llmobs_parent_id
from ddtrace.llmobs._utils import safe_json
from ddtrace.llmobs._utils import validate_prompt
Expand All @@ -60,6 +63,11 @@
from ddtrace.llmobs.utils import Messages
from ddtrace.propagation.http import HTTPPropagator

from ..constants import ERROR_MSG
from ..constants import ERROR_STACK
from ..constants import ERROR_TYPE
from . import _constants as constants


log = get_logger(__name__)

Expand All @@ -81,34 +89,157 @@ class LLMObs(Service):
def __init__(self, tracer=None):
super(LLMObs, self).__init__()
self.tracer = tracer or ddtrace.tracer
self._llmobs_span_writer = None

self._llmobs_span_writer = LLMObsSpanWriter(
is_agentless=config._llmobs_agentless_enabled,
interval=float(os.getenv("_DD_LLMOBS_WRITER_INTERVAL", 1.0)),
timeout=float(os.getenv("_DD_LLMOBS_WRITER_TIMEOUT", 5.0)),
)

self._llmobs_eval_metric_writer = LLMObsEvalMetricWriter(
site=config._dd_site,
api_key=config._dd_api_key,
interval=float(os.getenv("_DD_LLMOBS_WRITER_INTERVAL", 1.0)),
timeout=float(os.getenv("_DD_LLMOBS_WRITER_TIMEOUT", 5.0)),
)

self._evaluator_runner = EvaluatorRunner(
interval=float(os.getenv("_DD_LLMOBS_EVALUATOR_INTERVAL", 1.0)),
llmobs_service=self,
)

self._trace_processor = LLMObsTraceProcessor(self._llmobs_span_writer, self._evaluator_runner)
forksafe.register(self._child_after_fork)

self._annotations = []
self._annotation_context_lock = forksafe.RLock()
self.tracer.on_start_span(self._do_annotations)

def _do_annotations(self, span):
# Register hooks for span events
core.on("trace.span_start", self._do_annotations)
core.on("trace.span_finish", self._on_span_finish)

def _on_span_finish(self, span):
    """Handle a ``trace.span_finish`` event by forwarding finished LLM spans to LLMObs."""
    # Ignore all spans unless the service is enabled and the span is an LLM span.
    if not self.enabled:
        return
    if span.span_type != SpanTypes.LLM:
        return
    self._submit_llmobs_span(span)

def _submit_llmobs_span(self, span: Span) -> None:
    """Generate and submit an LLMObs span event to be sent to LLMObs.

    On success, the event is additionally forwarded to the evaluator runner,
    but only when the span is an LLM-kind span that did not originate from the
    Ragas evaluation integration.
    """
    is_llm_span = span._get_ctx_item(SPAN_KIND) == "llm"
    try:
        span_event, is_ragas_integration_span = self._llmobs_span_event(span)
        self._llmobs_span_writer.enqueue(span_event)
    except (KeyError, TypeError):
        # Best-effort: a malformed span must not break the finish path.
        log.error(
            "Error generating LLMObs span event for span %s, likely due to malformed span", span, exc_info=True
        )
        return
    # NOTE: previously these guards lived in a ``finally`` block ending in
    # ``return``, which silently suppressed any in-flight exception not caught
    # above (e.g. an unexpected error from ``enqueue``). Plain post-try guards
    # keep the intended control flow without swallowing unexpected errors.
    if not is_llm_span or is_ragas_integration_span:
        return
    if self._evaluator_runner:
        self._evaluator_runner.enqueue(span_event, span)

@classmethod
def _llmobs_span_event(cls, span: Span) -> Tuple[Dict[str, Any], bool]:
    """Span event object structure.

    Builds the LLMObs span event payload from the context items stored on
    ``span`` and returns a tuple of (event dict, whether the span originates
    from the Ragas evaluation integration).

    Raises:
        KeyError: if the span has no span kind in its context (i.e. it was
            not created/annotated through LLMObs).
    """
    span_kind = span._get_ctx_item(SPAN_KIND)
    if not span_kind:
        raise KeyError("Span kind not found in span context")
    meta: Dict[str, Any] = {"span.kind": span_kind, "input": {}, "output": {}}
    # Model info is only meaningful for llm/embedding spans.
    if span_kind in ("llm", "embedding") and span._get_ctx_item(MODEL_NAME) is not None:
        meta["model_name"] = span._get_ctx_item(MODEL_NAME)
        meta["model_provider"] = (span._get_ctx_item(MODEL_PROVIDER) or "custom").lower()
    meta["metadata"] = span._get_ctx_item(METADATA) or {}
    if span._get_ctx_item(INPUT_PARAMETERS):
        meta["input"]["parameters"] = span._get_ctx_item(INPUT_PARAMETERS)
    # Input/output payload shape depends on the span kind: messages for llm
    # spans, documents for embedding/retrieval spans, raw values otherwise.
    if span_kind == "llm" and span._get_ctx_item(INPUT_MESSAGES) is not None:
        meta["input"]["messages"] = span._get_ctx_item(INPUT_MESSAGES)
    if span._get_ctx_item(INPUT_VALUE) is not None:
        meta["input"]["value"] = safe_json(span._get_ctx_item(INPUT_VALUE))
    if span_kind == "llm" and span._get_ctx_item(OUTPUT_MESSAGES) is not None:
        meta["output"]["messages"] = span._get_ctx_item(OUTPUT_MESSAGES)
    if span_kind == "embedding" and span._get_ctx_item(INPUT_DOCUMENTS) is not None:
        meta["input"]["documents"] = span._get_ctx_item(INPUT_DOCUMENTS)
    if span._get_ctx_item(OUTPUT_VALUE) is not None:
        meta["output"]["value"] = safe_json(span._get_ctx_item(OUTPUT_VALUE))
    if span_kind == "retrieval" and span._get_ctx_item(OUTPUT_DOCUMENTS) is not None:
        meta["output"]["documents"] = span._get_ctx_item(OUTPUT_DOCUMENTS)
    if span._get_ctx_item(INPUT_PROMPT) is not None:
        prompt_json_str = span._get_ctx_item(INPUT_PROMPT)
        if span_kind != "llm":
            # Prompts are an llm-span-only concept; drop them elsewhere.
            log.warning(
                "Dropping prompt on non-LLM span kind, annotating prompts is only supported for LLM span kinds."
            )
        else:
            meta["input"]["prompt"] = prompt_json_str
    if span.error:
        # Copy error tags off the APM span into the LLMObs event meta.
        meta.update(
            {
                ERROR_MSG: span.get_tag(ERROR_MSG),
                ERROR_STACK: span.get_tag(ERROR_STACK),
                ERROR_TYPE: span.get_tag(ERROR_TYPE),
            }
        )
    # Omit empty input/output sections from the payload entirely.
    if not meta["input"]:
        meta.pop("input")
    if not meta["output"]:
        meta.pop("output")
    metrics = span._get_ctx_item(METRICS) or {}
    ml_app = _get_ml_app(span)

    is_ragas_integration_span = False

    # Spans emitted by the Ragas evaluator integration are flagged so the
    # caller can exclude them from further evaluation.
    if ml_app.startswith(constants.RAGAS_ML_APP_PREFIX):
        is_ragas_integration_span = True

    # Persist derived values back onto the span context so descendants
    # (e.g. parent-id/session-id resolution) can read them.
    span._set_ctx_item(ML_APP, ml_app)
    parent_id = str(_get_llmobs_parent_id(span) or "undefined")

    llmobs_span_event = {
        "trace_id": "{:x}".format(span.trace_id),
        "span_id": str(span.span_id),
        "parent_id": parent_id,
        "name": _get_span_name(span),
        "start_ns": span.start_ns,
        "duration": span.duration_ns,
        "status": "error" if span.error else "ok",
        "meta": meta,
        "metrics": metrics,
    }
    session_id = _get_session_id(span)
    if session_id is not None:
        span._set_ctx_item(SESSION_ID, session_id)
        llmobs_span_event["session_id"] = session_id

    llmobs_span_event["tags"] = cls._llmobs_tags(
        span, ml_app, session_id, is_ragas_integration_span=is_ragas_integration_span
    )
    return llmobs_span_event, is_ragas_integration_span

@staticmethod
def _llmobs_tags(
    span: Span, ml_app: str, session_id: Optional[str] = None, is_ragas_integration_span: bool = False
) -> List[str]:
    """Assemble the list of ``key:value`` tag strings for an LLMObs span event."""
    # Baseline tags; insertion order is preserved in the emitted list.
    tag_map: Dict[str, Any] = {
        "version": config.version or "",
        "env": config.env or "",
        "service": span.service or "",
        "source": "integration",
        "ml_app": ml_app,
        "ddtrace.version": ddtrace.__version__,
        "language": "python",
        "error": span.error,
    }
    error_type = span.get_tag(ERROR_TYPE)
    if error_type:
        tag_map["error_type"] = error_type
    if session_id:
        tag_map["session_id"] = session_id
    if is_ragas_integration_span:
        tag_map[constants.RUNNER_IS_INTEGRATION_SPAN_TAG] = "ragas"
    # User-supplied tags on the span context override the defaults above.
    user_tags = span._get_ctx_item(TAGS)
    if user_tags is not None:
        tag_map.update(user_tags)
    return [f"{key}:{value}" for key, value in tag_map.items()]

def _do_annotations(self, span: Span) -> None:
# get the current span context
# only do the annotations if it matches the context
if span.span_type != SpanTypes.LLM: # do this check to avoid the warning log in `annotate`
Expand All @@ -120,20 +251,14 @@ def _do_annotations(self, span):
if current_context_id == context_id:
self.annotate(span, **annotation_kwargs)

def _child_after_fork(self):
def _child_after_fork(self) -> None:
self._llmobs_span_writer = self._llmobs_span_writer.recreate()
self._llmobs_eval_metric_writer = self._llmobs_eval_metric_writer.recreate()
self._evaluator_runner = self._evaluator_runner.recreate()
self._trace_processor._span_writer = self._llmobs_span_writer
self._trace_processor._evaluator_runner = self._evaluator_runner
if self.enabled:
self._start_service()

def _start_service(self) -> None:
tracer_filters = self.tracer._filters
if not any(isinstance(tracer_filter, LLMObsTraceProcessor) for tracer_filter in tracer_filters):
tracer_filters += [self._trace_processor]
self.tracer.configure(settings={"FILTERS": tracer_filters})
try:
self._llmobs_span_writer.start()
self._llmobs_eval_metric_writer.start()
Expand All @@ -160,11 +285,7 @@ def _stop_service(self) -> None:
except ServiceStatusError:
log.debug("Error stopping LLMObs writers")

try:
forksafe.unregister(self._child_after_fork)
self.tracer.shutdown()
except Exception:
log.warning("Failed to shutdown tracer", exc_info=True)
forksafe.unregister(self._child_after_fork)

@classmethod
def enable(
Expand Down Expand Up @@ -265,7 +386,6 @@ def disable(cls) -> None:

cls._instance.stop()
cls.enabled = False
cls._instance.tracer.deregister_on_start_span(cls._instance._do_annotations)
telemetry_writer.product_activated(TELEMETRY_APM_PRODUCT.LLMOBS, False)

log.debug("%s disabled", cls.__name__)
Expand Down
Loading

0 comments on commit 959432a

Please sign in to comment.