Update langchain callback handler to capture metadata, correlated prompt/responses (#104)
jamie256 authored Jul 24, 2023
1 parent 7f0513a commit 6f2e289
Showing 1 changed file with 40 additions and 6 deletions.
46 changes: 40 additions & 6 deletions langkit/callback_handler.py
@@ -1,6 +1,7 @@
 import inspect
 from functools import partial
 from logging import getLogger
+from time import time
 from typing import Any, Callable, Dict, List, Optional, Union
 from whylogs.api.logger.logger import Logger

@@ -60,24 +61,57 @@ def __init__(self, logger: Logger):
         diagnostic_logger.info(
             f"Initialized LangKitCallback handler with configured whylogs Logger {logger}."
         )
+        self.records: Dict[str, Any] = dict()

-    def _profile_generations(self, generations: List[Any]) -> None:
+    def _extract_generation_responses(
+        self, generations: List[Any]
+    ) -> List[Dict[str, Any]]:
+        responses = list()
         for gen in generations:
             if hasattr(gen, "text"):
-                self._logger.log({"response": gen.text})
+                responses.append({"response": gen.text})
+        return responses

     # Start LLM events
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         """Pass the input prompts to the logger"""
-        for prompt in prompts:
-            self._logger.log({"prompt": prompt})
+        invocation_params = kwargs.get("invocation_params")
+        run_id = kwargs.get("run_id", 0)
+        self.records[run_id] = {"prompts": prompts, "t0": time()}
+        if hasattr(self._logger, "_current_profile"):
+            profile = self._logger._current_profile
+            if invocation_params is not None:
+                profile.track(
+                    {
+                        "invocation_params." + key: value
+                        for key, value in invocation_params.items()
+                    },
+                    execute_udfs=False,
+                )

     def on_llm_end(self, response: Any, **kwargs: Any) -> None:
         """Pass the generated response to the logger."""
-        for generations in response.generations:
-            self._profile_generations(generations)
+        run_id = kwargs.get("run_id", 0)
+        llm_record = self.records.get(run_id)
+        if llm_record is not None:
+            response_latency_s = time() - llm_record["t0"]
+            self._logger.log({"response_latency_s": response_latency_s})
+            index = 0
+            prompts = llm_record["prompts"]
+            for generations in response.generations:
+                responses = self._extract_generation_responses(generations)
+                for response_record in responses:
+                    response_record.update({"prompt": prompts[index]})
+                    self._logger.log(response_record)
+                index = index + 1
+
+        if hasattr(response, "llm_output"):
+            llm_output = response.llm_output
+            token_usage = llm_output.get("token_usage")
+            if token_usage:
+                self._logger.log(token_usage)

     def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         diagnostic_logger.debug(f"on_llm_new_token({token})")
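For readers wiring this up, here is a minimal usage sketch of the handler after this change. The class name LangKitCallback is inferred from the diff's log message; the rolling whylogs logger configuration, the OpenAI wrapper, and the prompts are illustrative assumptions, not part of this commit.

# Hedged usage sketch: attach the callback so on_llm_start/on_llm_end
# capture prompts, correlated responses, latency, and invocation metadata.
import whylogs as why
from langchain.llms import OpenAI  # assumed LLM wrapper, for illustration only
from langkit.callback_handler import LangKitCallback  # class name inferred from the diff's log message

# Any whylogs Logger should work; a rolling logger writes profiles on an interval.
logger = why.logger(mode="rolling", interval=5, when="M", base_name="langkit")

handler = LangKitCallback(logger)
llm = OpenAI(temperature=0.7, callbacks=[handler])

# on_llm_start stores {"prompts": [...], "t0": start_time} keyed by run_id and
# tracks invocation_params.* metadata; on_llm_end logs response_latency_s, one
# {"prompt": ..., "response": ...} record per generation, and token_usage.
result = llm.generate(["Tell me a joke.", "Summarize whylogs in one line."])
logger.close()

Because response.generations is a list of generation lists parallel to the input prompt list, the index-based walk in on_llm_end pairs each response with the prompt that produced it.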
