From 6f2e2897df568fe6b07416b443fdfa2670152f3e Mon Sep 17 00:00:00 2001
From: Jamie Broomall <88007022+jamie256@users.noreply.github.com>
Date: Mon, 24 Jul 2023 19:34:54 +0200
Subject: [PATCH] Update langchain callback handler to capture metadata,
 correlated prompt/responses (#104)

---
 langkit/callback_handler.py | 46 ++++++++++++++++++++++++++++++++-----
 1 file changed, 40 insertions(+), 6 deletions(-)

diff --git a/langkit/callback_handler.py b/langkit/callback_handler.py
index 37fd838d..be327444 100644
--- a/langkit/callback_handler.py
+++ b/langkit/callback_handler.py
@@ -1,6 +1,7 @@
 import inspect
 from functools import partial
 from logging import getLogger
+from time import time
 from typing import Any, Callable, Dict, List, Optional, Union
 
 from whylogs.api.logger.logger import Logger
@@ -60,24 +61,57 @@ def __init__(self, logger: Logger):
         diagnostic_logger.info(
             f"Initialized LangKitCallback handler with configured whylogs Logger {logger}."
         )
+        self.records: Dict[str, Any] = dict()
 
-    def _profile_generations(self, generations: List[Any]) -> None:
+    def _extract_generation_responses(
+        self, generations: List[Any]
+    ) -> List[Dict[str, Any]]:
+        responses = list()
         for gen in generations:
             if hasattr(gen, "text"):
-                self._logger.log({"response": gen.text})
+                responses.append({"response": gen.text})
+        return responses
 
     # Start LLM events
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         """Pass the input prompts to the logger"""
-        for prompt in prompts:
-            self._logger.log({"prompt": prompt})
+        invocation_params = kwargs.get("invocation_params")
+        run_id = kwargs.get("run_id", 0)
+        self.records[run_id] = {"prompts": prompts, "t0": time()}
+        if hasattr(self._logger, "_current_profile"):
+            profile = self._logger._current_profile
+            if invocation_params is not None:
+                profile.track(
+                    {
+                        "invocation_params." + key: value
+                        for key, value in invocation_params.items()
+                    },
+                    execute_udfs=False,
+                )
 
     def on_llm_end(self, response: Any, **kwargs: Any) -> None:
         """Pass the generated response to the logger."""
-        for generations in response.generations:
-            self._profile_generations(generations)
+        run_id = kwargs.get("run_id", 0)
+        llm_record = self.records.get(run_id)
+        if llm_record is not None:
+            response_latency_s = time() - llm_record["t0"]
+            self._logger.log({"response_latency_s": response_latency_s})
+            index = 0
+            prompts = llm_record["prompts"]
+            for generations in response.generations:
+                responses = self._extract_generation_responses(generations)
+                for response_record in responses:
+                    response_record.update({"prompt": prompts[index]})
+                    self._logger.log(response_record)
+                index = index + 1
+
+        if hasattr(response, "llm_output"):
+            llm_output = response.llm_output
+            token_usage = llm_output.get("token_usage")
+            if token_usage:
+                self._logger.log(token_usage)
 
     def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         diagnostic_logger.debug(f"on_llm_new_token({token})")
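
Not part of the patch: a minimal usage sketch showing how the updated handler might be wired into a LangChain LLM so the new correlation logic fires. The handler class name LangKitCallback and the why.logger() rolling-logger setup are assumptions for illustration; check langkit/callback_handler.py and the whylogs docs for the actual exports.

# Usage sketch (assumed names: LangKitCallback export, rolling logger config).
import whylogs as why
from langchain.llms import OpenAI

from langkit.callback_handler import LangKitCallback  # assumed export name

# Any whylogs Logger works; a rolling logger writes profiles periodically.
telemetry_logger = why.logger(mode="rolling", interval=5, when="M")
handler = LangKitCallback(telemetry_logger)

# Registering the handler makes on_llm_start/on_llm_end fire around each call,
# so each prompt is logged with its matching response, the response latency,
# and token usage, correlated by run_id as implemented in the patch above.
llm = OpenAI(temperature=0.7, callbacks=[handler])
llm("What is the capital of France?")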