Skip to content

Commit

Permalink
Add DynamicCallbackAdapter and LangkitCallback (#86)
Browse files Browse the repository at this point in the history
* Add DynamicCallbackAdapter and LangkitCallback
* Simplifications, pre-commit fixes
  • Loading branch information
jamie256 authored Jul 11, 2023
1 parent 7906380 commit 78a2011
Show file tree
Hide file tree
Showing 6 changed files with 406 additions and 11 deletions.
221 changes: 221 additions & 0 deletions langkit/callback_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
import inspect
from functools import partial
from logging import getLogger
from typing import Any, Callable, Dict, List, Optional, Union
from whylogs.api.logger.logger import Logger


diagnostic_logger = getLogger(__name__)


def _flex_call(func, *args, **kwargs):
result = None
try:
sig = inspect.signature(func)
params = sig.parameters
# if params has a **kwargs style variable arguments then we don't need to
# remove extra parameters in the filtered_kwargs below.
has_varargs = any(param.kind == param.VAR_KEYWORD for param in params.values())

# Helper to map position args to keyword args, so we can then check for missing arguments.
positional_to_named_args = dict(zip(params.keys(), args))
all_kwargs = {**positional_to_named_args, **kwargs}
# Also remove arguments passed in that the func cannot accept
filtered_kwargs = (
all_kwargs
if has_varargs
else {k: v for k, v in all_kwargs.items() if k in params}
)

for key, param in params.items():
if key not in all_kwargs and param.default is inspect.Parameter.empty:
filtered_kwargs[key] = None
diagnostic_logger.info(f"missing arg {key}, passing in {key}=None")

result = func(**filtered_kwargs)
except Exception as e:
diagnostic_logger.warning(
f"Error calling {func}(args{args}, kwargs{kwargs}) -> error: {e}"
)
return result


def _generate_callback_wrapper(handler) -> Dict[str, partial]:
public_methods = [
method
for method in dir(handler)
if callable(getattr(handler, method)) and not method.startswith("_")
]
callbacks = {
method: partial(_flex_call, getattr(handler, method))
for method in public_methods
}
return callbacks


class LangKitCallback:
def __init__(self, logger: Logger):
"""Bind the configured logger for this langKit callback handler."""
self._logger = logger
diagnostic_logger.info(
f"Initialized LangKitCallback handler with configured whylogs Logger {logger}."
)

def _profile_generations(self, generations: List[Any]) -> None:
for gen in generations:
if hasattr(gen, "text"):
self._logger.log({"response": gen.text})

# Start LLM events
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Pass the input prompts to the logger"""
for prompt in prompts:
self._logger.log({"prompt": prompt})

def on_llm_end(self, response: Any, **kwargs: Any) -> None:
"""Pass the generated response to the logger."""
for generations in response.generations:
self._profile_generations(generations)

def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
diagnostic_logger.debug(f"on_llm_new_token({token})")

def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
diagnostic_logger.debug(f"on_llm_error(error={error}, kwargs={kwargs})")

def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
diagnostic_logger.debug(
f"on_chain_start(serialized={serialized}, inputs={inputs}, kwargs={kwargs})"
)

def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
diagnostic_logger.debug(f"on_chain_end(outputs={outputs}, kwargs={kwargs})")

def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
diagnostic_logger.debug(f"on_chain_error(error={error}, kwargs={kwargs})")

def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
diagnostic_logger.debug(
f"on_chain_start(serialized={serialized}, input_str={input_str}, kwargs={kwargs})"
)

def on_agent_action(
self, action: Any, color: Optional[str] = None, **kwargs: Any
) -> Any:
diagnostic_logger.debug(f"on_agent_action(action={action}, kwargs={kwargs})")

def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
diagnostic_logger.debug(f"on_tool_end(output={output}, kwargs={kwargs})")

def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
diagnostic_logger.debug(f"on_tool_error(error={error}, kwargs={kwargs})")

def on_text(self, text: str, **kwargs: Any) -> None:
diagnostic_logger.debug(f"on_text(text={text}, kwargs={kwargs})")

def on_agent_finish(
self, finish: Any, color: Optional[str] = None, **kwargs: Any
) -> None:
diagnostic_logger.debug(f"on_agent_finish(finish={finish}, kwargs={kwargs})")

# End LLM events

def _get_callbacks(self) -> Dict[str, partial]:
return _generate_callback_wrapper(self)


class DynamicCallbackMeta(type):
def __new__(mcs, name, bases, attrs):
cls = super().__new__(mcs, name, bases, attrs)

def implement_interface(name):
def method(self, *args, **kwargs):
if name in self._callbacks:
return self._callbacks[name](*args, **kwargs)
else:
return getattr(super(cls, self), name)(*args, **kwargs)

return method

for base in bases:
for name, attr in base.__dict__.items():
if callable(attr) and not name.startswith("_"):
setattr(cls, name, implement_interface(name))

return cls


def DynamicCallbackAdapter(Base):
class DynamicCallbackAdapterClass(Base, metaclass=DynamicCallbackMeta):
# This is called by external integrations,
# do not remove any of these parameters or add new required ones without defaults.
def __init__(self, whylabs_logger: Logger, handler: Any):
if hasattr(handler, "init"):
handler.init(self)
if hasattr(handler, "_get_callbacks"):
self._callbacks = handler._get_callbacks()
diagnostic_logger.debug(
f"initialized LangKit handler with {self._callbacks}."
)
else:
self._callbacks = dict()
diagnostic_logger.warning(
"initialized LangKit handler without callbacks."
)
self._methods: Dict[str, Callable] = dict()
self._logger = whylabs_logger

def __getattr__(self, name):
if name in self._callbacks:
return self._callbacks[name]

if name in self._methods:
return self._methods[name]

def no_op_method(*args, **kwargs):
diagnostic_logger.debug(
f"no passthrough for '{name}' this event, args={args},kwargs={kwargs}."
)

self._methods[name] = no_op_method
return no_op_method

return DynamicCallbackAdapterClass


def get_callback_instance(*args, **kwargs):
handler = kwargs.get("handler")
logger = kwargs.get("logger")
if handler is None:
logger = kwargs.get("logger")
handler = LangKitCallback(logger=logger)
elif logger is None:
logger = handler._logger
base_class = handler.__class__
impl = kwargs.get("impl")
LangKitCallbackImplementation = DynamicCallbackAdapter(base_class)
if impl:
LangKitCallbackImplementation.__bases__ += (impl,)
return LangKitCallbackImplementation(logger, handler=handler)
39 changes: 39 additions & 0 deletions langkit/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from dataclasses import dataclass, field

import pkg_resources


@dataclass
class LangKitConfig:
pattern_file_path: str = pkg_resources.resource_filename(
__name__, "pattern_groups.json"
)
transformer_name: str = "sentence-transformers/all-MiniLM-L6-v2"
theme_file_path: str = pkg_resources.resource_filename(__name__, "themes.json")
prompt_column: str = "prompt"
response_column: str = "response"
topics: list = field(
default_factory=lambda: [
"law",
"finance",
"medical",
"education",
"politics",
"support",
]
)


def package_version(package: str = __package__) -> str:
"""Calculate version number based on pyproject.toml"""
try:
from importlib import metadata

version = metadata.version(package)
except metadata.PackageNotFoundError:
version = f"{package} is not installed."

return version


__version__ = package_version()
133 changes: 133 additions & 0 deletions langkit/tests/test_callback_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
from logging import getLogger
from typing import Any, Dict, List
from langkit.callback_handler import LangKitCallback, get_callback_instance


TEST_LOGGER = getLogger(__name__)


class MockLogger:
def __getattr__(self, name):
def method(*args, **kwargs):
TEST_LOGGER.info(f"logger called {name}(*args={args}, **kwargs={kwargs})")

return method


class MockCallbackOnStartMixin1:
def on_llm_start(self, prompts: List[str]):
TEST_LOGGER.info(
f"MockCallbackOnStartMixin1.on_llm_start called on_llm_start with {prompts}"
)


class MockCallbackOnStartMixin2:
def on_llm_end(self, response):
TEST_LOGGER.info(
f"MockCallbackOnStartMixin2.on_llm_end called on_llm_start with response={response}"
)


class MockCallbackOnStartMixin3:
def on_text(self, prompts: List[str]):
TEST_LOGGER.info(
f"MockCallbackOnStartMixin3.on_text called on_llm_start with {prompts}"
)


class ComplexBaseHandler(
MockCallbackOnStartMixin1, MockCallbackOnStartMixin2, MockCallbackOnStartMixin3
):
def ignore_llm(self):
TEST_LOGGER.info("Calling ignore_llm LangChainBaseHandler")


class MockBaseHandler:
def close(self):
TEST_LOGGER.info("Calling close in test MockBaseHandler")


class MockBaseHandler2:
def close(self):
TEST_LOGGER.info("Calling close in test MockBaseHandler2")


def test_callback_passthroughs_undefined_ok():
universal_callback = get_callback_instance()
universal_callback.foo(a="hi", b=True)
foo1 = universal_callback.foo
foo2 = universal_callback.foo
assert foo1 is foo2


def test_callback_passthroughs_undefined_no_args():
universal_callback = get_callback_instance()
universal_callback.bar()
universal_callback.baz()


def test_callback_passthroughs_defined_functions():
universal_callback = get_callback_instance()
universal_callback.on_text(text="Hello texty text!")


def test_callback_passthroughs_defined_logging_functions():
universal_callback = get_callback_instance(
logger=MockLogger(), impl=MockBaseHandler, interface=MockCallbackOnStartMixin1
)
test_prompts = ["hi"]
default_serialized: Dict[str, Any] = {"test": "serialized"}
on_llm_start = universal_callback.on_llm_start
universal_callback.on_llm_start(serialized=default_serialized, prompts=test_prompts)
on_llm_start(default_serialized, prompts=test_prompts)
test_response = type("", (object,), {"generations": [{"text": "No"}]})()
universal_callback.on_llm_end(response=test_response)
universal_callback.close()


def test_callback_instance_handler_defined():
callback_handler = LangKitCallback(logger=MockLogger())
universal_callback = get_callback_instance(
handler=callback_handler, impl=MockBaseHandler2
)
test_prompts = ["goodbye!"]
universal_callback.on_llm_start(prompts=test_prompts)
universal_callback.close()


def test_callback_instance_handler_defined_getattr():
callback_handler = LangKitCallback(logger=MockLogger())
universal_callback = get_callback_instance(
handler=callback_handler, impl=MockBaseHandler2, base=ComplexBaseHandler
)
test_prompts = ["goodbye variations!"]
method_name = "on_llm_start"

assert hasattr(universal_callback, method_name)
getattr_method = getattr(universal_callback, method_name)
direct_method = universal_callback.on_llm_start
TEST_LOGGER.info(
f"comparing getattr with method name {getattr_method} vs {direct_method}"
)
getattr_method(prompts=test_prompts)
direct_method(prompts=test_prompts)
universal_callback.close()


def test_callback_instance_three_ply_class_hierarchy():
callback_handler = LangKitCallback(logger=MockLogger())
universal_callback = get_callback_instance(
handler=callback_handler, impl=MockBaseHandler2, base=ComplexBaseHandler
)
test_prompts = ["goodbye variations!"]
method_name = "on_llm_start"

assert hasattr(universal_callback, method_name)
getattr_method = getattr(universal_callback, method_name)
direct_method = universal_callback.on_llm_start
TEST_LOGGER.info(
f"comparing getattr with method name {getattr_method} vs {direct_method}"
)
getattr_method(prompts=test_prompts)
direct_method(prompts=test_prompts)
universal_callback.close()
Loading

0 comments on commit 78a2011

Please sign in to comment.