Skip to content

Commit

Permalink
Make response handling more generic, resulting in generic cost handling.
Browse files Browse the repository at this point in the history
  • Loading branch information
alesj committed Dec 5, 2024
1 parent 2e8615a commit 1ca082d
Show file tree
Hide file tree
Showing 22 changed files with 356 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

import java.util.Optional;

import jakarta.inject.Singleton;

import org.jboss.jandex.DotName;

import io.quarkiverse.langchain4j.cost.CostEstimatorResponseListener;
import io.quarkiverse.langchain4j.deployment.config.LangChain4jBuildConfig;
import io.quarkiverse.langchain4j.runtime.listeners.MetricsChatModelListener;
import io.quarkiverse.langchain4j.runtime.listeners.SpanChatModelListener;
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
Expand All @@ -14,6 +20,20 @@

public class ListenersProcessor {

/**
 * Registers {@link CostEstimatorResponseListener} as an additional, unremovable bean
 * (default scope {@link Singleton}) when the cost listener is enabled in the build config.
 * Produces nothing when the property is disabled.
 *
 * @param config build-time configuration holding the {@code costListener} switch
 * @param additionalBeanProducer sink for the additional bean build item
 */
@BuildStep
public void costListener(
        LangChain4jBuildConfig config,
        BuildProducer<AdditionalBeanBuildItem> additionalBeanProducer) {
    if (!config.costListener()) {
        return;
    }
    AdditionalBeanBuildItem costListenerBean = AdditionalBeanBuildItem.builder()
            .addBeanClass(CostEstimatorResponseListener.class)
            .setDefaultScope(DotName.createSimple(Singleton.class))
            .setUnremovable()
            .build();
    additionalBeanProducer.produce(costListenerBean);
}

@BuildStep
public void spanListeners(Capabilities capabilities,
Optional<MetricsCapabilityBuildItem> metricsCapability,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ public interface LangChain4jBuildConfig {
@WithDefault("true")
boolean responseSchema();

/**
 * Whether the generic cost listener is enabled. When {@code true}, the
 * {@code CostEstimatorResponseListener} bean is registered at build time.
 */
@WithDefault("false")
boolean costListener();

interface BaseConfig {
/**
* Chat model
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package io.quarkiverse.langchain4j.cost;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

import jakarta.inject.Inject;

import dev.langchain4j.model.output.TokenUsage;
import io.quarkiverse.langchain4j.response.ResponseListener;
import io.quarkiverse.langchain4j.response.ResponseRecord;
import io.quarkus.arc.All;
import io.smallrye.common.annotation.Experimental;

/**
 * {@link ResponseListener} that estimates the cost of every reported model response via
 * {@link CostEstimatorService} and forwards the estimate to all {@link CostListener}s,
 * in ascending {@link CostListener#order()}.
 * <p>
 * Responses that carry no token usage are skipped, as are responses for which the
 * service returns no estimate.
 */
@Experimental("This feature is experimental and the API is subject to change")
public class CostEstimatorResponseListener implements ResponseListener {

    private final CostEstimatorService service;
    private final List<CostListener> listeners;

    /**
     * @param service the service used to compute the cost estimate
     * @param listeners all available cost listeners; copied and sorted by {@link CostListener#order()}
     */
    @Inject
    public CostEstimatorResponseListener(CostEstimatorService service, @All List<CostListener> listeners) {
        this.service = service;
        // Sort a private copy so the injected list itself is never modified.
        this.listeners = new ArrayList<>(listeners);
        this.listeners.sort(Comparator.comparingInt(CostListener::order));
    }

    /**
     * Estimates the cost of the given response and notifies every cost listener.
     *
     * @param rr the response to account for
     */
    @Override
    public void onResponse(ResponseRecord rr) {
        TokenUsage tokenUsage = rr.tokenUsage();
        if (tokenUsage == null) {
            // Without token counts there is nothing to estimate; dereferencing the
            // missing usage inside the cost context would otherwise fail with an NPE.
            return;
        }
        String model = rr.model();
        Cost cost = service.estimate(new MyCostContext(tokenUsage, model));
        if (cost == null) {
            return;
        }
        for (CostListener cl : listeners) {
            cl.handleCost(model, tokenUsage, cost);
        }
    }

    /**
     * Cost context backed by the token usage and model name of a single response.
     */
    private record MyCostContext(TokenUsage tokenUsage, String model) implements CostEstimator.CostContext {
        @Override
        public Integer inputTokens() {
            return tokenUsage().inputTokenCount();
        }

        @Override
        public Integer outputTokens() {
            return tokenUsage().outputTokenCount();
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@ public CostEstimatorService(@All List<CostEstimator> costEstimators) {
/**
 * Estimates the cost of a chat model response by wrapping its token usage in a
 * cost context and delegating to {@link #estimate(CostEstimator.CostContext)}.
 */
public Cost estimate(ChatModelResponseContext response) {
    return estimate(new MyCostContext(response.response().tokenUsage(), response));
}

public Cost estimate(CostEstimator.CostContext context) {
for (CostEstimator costEstimator : costEstimators) {
if (costEstimator.supports(costContext)) {
CostEstimator.CostResult costResult = costEstimator.estimate(costContext);
if (costEstimator.supports(context)) {
CostEstimator.CostResult costResult = costEstimator.estimate(context);
if (costResult != null) {
BigDecimal totalCost = costResult.inputTokensCost().add(costResult.outputTokensCost());
return new Cost(totalCost, costResult.currency());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package io.quarkiverse.langchain4j.cost;

import dev.langchain4j.model.output.TokenUsage;

/**
 * Allows user code to handle the estimated cost of a model call; e.g. for some simple accounting.
 */
public interface CostListener {
    /**
     * Called with the cost estimated for a single model response.
     *
     * @param model the name of the model that produced the response
     * @param tokenUsage the token usage reported for the response
     * @param cost the estimated cost
     */
    void handleCost(String model, TokenUsage tokenUsage, Cost cost);

    /**
     * Ordering hint: {@code CostEstimatorResponseListener} sorts listeners by this value in
     * ascending order before invoking them. Defaults to {@code 0}.
     */
    default int order() {
        return 0;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package io.quarkiverse.langchain4j.response;

import java.util.Map;

import jakarta.annotation.Priority;
import jakarta.interceptor.AroundInvoke;
import jakarta.interceptor.Interceptor;
import jakarta.interceptor.InvocationContext;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.output.Response;

/**
 * Simple (Chat)Response interceptor, to be applied directly on the model.
 * <p>
 * After the intercepted method returns, a recognized result ({@link Response} with
 * {@link AiMessage} content, {@link ChatResponse} or {@link ChatModelResponse}) is
 * converted into a {@link ResponseRecord} and handed to every {@link ResponseListener}.
 */
@Interceptor
@ResponseInterceptorBinding
@Priority(0)
public class ResponseInterceptor extends ResponseInterceptorBase {

    /**
     * Proceeds with the invocation and notifies the listeners about the result.
     *
     * @param context the invocation being intercepted
     * @return the unmodified result of the intercepted method
     * @throws Exception whatever the intercepted method or a listener throws
     */
    @AroundInvoke
    public Object intercept(InvocationContext context) throws Exception {
        Object result = context.proceed();
        ResponseRecord rr = toRecord(result, context.getTarget());
        if (rr != null) {
            for (ResponseListener l : getListeners()) {
                l.onResponse(rr);
            }
        }
        return result;
    }

    /**
     * Converts a supported result type into a {@link ResponseRecord}.
     *
     * @return the record, or {@code null} when the result type is not recognized
     */
    private ResponseRecord toRecord(Object result, Object target) {
        if (result instanceof Response<?> response) {
            // Only AI message responses are reported; the model name is resolved lazily,
            // i.e. only once we know a record will actually be created.
            if (response.content() instanceof AiMessage am) {
                return new ResponseRecord(getModel(target), am, response.tokenUsage(), response.finishReason(),
                        response.metadata());
            }
            return null;
        }
        if (result instanceof ChatResponse response) {
            return new ResponseRecord(getModel(target), response.aiMessage(), response.tokenUsage(),
                    response.finishReason(), Map.of());
        }
        if (result instanceof ChatModelResponse response) {
            // Map.of rejects null values, so only add the id when one is present.
            String id = response.id();
            Map<String, Object> metadata = id == null ? Map.of() : Map.of("id", id);
            return new ResponseRecord(response.model(), response.aiMessage(), response.tokenUsage(),
                    response.finishReason(), metadata);
        }
        return null;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package io.quarkiverse.langchain4j.response;

import java.lang.reflect.Method;
import java.util.Comparator;
import java.util.List;

import jakarta.enterprise.inject.Any;
import jakarta.enterprise.inject.spi.CDI;

/**
 * Simple (Chat)Response interceptor base, to be applied directly on the model.
 * <p>
 * Lazily resolves and caches the model name of the intercepted target and the ordered
 * list of {@link ResponseListener}s.
 */
public abstract class ResponseInterceptorBase {

    private volatile String model;
    private volatile List<ResponseListener> listeners;

    /**
     * Returns the model name of the intercepted target, resolving it on first use.
     * <p>
     * TODO: the name is read reflectively from a public {@code modelName()} method;
     * replace this with a typed accessor once one is available.
     *
     * @param target the intercepted model instance
     * @return the value returned by {@code target.modelName()}
     * @throws RuntimeException if the method is missing or cannot be invoked
     */
    protected String getModel(Object target) {
        String resolved = model; // single volatile read on the fast path
        if (resolved == null) {
            try {
                Method method = target.getClass().getMethod("modelName");
                resolved = (String) method.invoke(target);
            } catch (Exception e) {
                // Broad catch kept on purpose: callers have always seen a wrapping RuntimeException.
                String targetType = target == null ? "null" : target.getClass().getName();
                throw new RuntimeException("Cannot resolve model name via modelName() on " + targetType, e);
            }
            model = resolved;
        }
        return resolved;
    }

    /**
     * Returns all {@link ResponseListener} beans, sorted by ascending
     * {@link ResponseListener#order()}. The list is looked up once and then cached.
     */
    protected List<ResponseListener> getListeners() {
        List<ResponseListener> resolved = listeners;
        if (resolved == null) {
            resolved = CDI.current().select(ResponseListener.class, Any.Literal.INSTANCE)
                    .stream()
                    .sorted(Comparator.comparingInt(ResponseListener::order))
                    .toList();
            listeners = resolved;
        }
        return resolved;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package io.quarkiverse.langchain4j.response;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

import jakarta.interceptor.InterceptorBinding;

/**
 * Interceptor binding that associates a type or method with {@code ResponseInterceptor}.
 */
@InterceptorBinding
@Target({ ElementType.TYPE, ElementType.METHOD })
@Retention(RetentionPolicy.RUNTIME)
public @interface ResponseInterceptorBinding {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package io.quarkiverse.langchain4j.response;

/**
 * Empty class that merely carries the {@link ResponseInterceptorBinding} annotation.
 * NOTE(review): presumably used as a source from which the binding annotation is read
 * when it has to be applied to model classes - confirm against its usages.
 */
@ResponseInterceptorBinding
public abstract class ResponseInterceptorBindingSource {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package io.quarkiverse.langchain4j.response;

/**
 * Simple {@link ResponseRecord} listener, to be implemented by (advanced) users.
 */
public interface ResponseListener {
    /**
     * Called with the record created for a model response.
     *
     * @param response the response record
     */
    void onResponse(ResponseRecord response);

    /**
     * Ordering hint: {@code ResponseInterceptorBase} sorts listeners by this value in
     * ascending order before they are invoked. Defaults to {@code 0}.
     */
    default int order() {
        return 0;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package io.quarkiverse.langchain4j.response;

import java.util.Map;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.TokenUsage;

/**
 * Uniform view of a model response that abstracts away the differences between
 * Response and ChatResponse.
 *
 * @param model the name of the model that produced the response
 * @param content the AI message of the response
 * @param tokenUsage the token usage reported for the response
 * @param finishReason the reason the model finished generating
 * @param metadata additional response metadata
 */
public record ResponseRecord(
        String model,
        AiMessage content,
        TokenUsage tokenUsage,
        FinishReason finishReason,
        Map<String, Object> metadata) {
}
17 changes: 17 additions & 0 deletions docs/modules/ROOT/pages/includes/quarkus-langchain4j-core.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,23 @@ endif::add-copy-button-to-env-var[]
|boolean
|`true`

a|icon:lock[title=Fixed at build time] [[quarkus-langchain4j-core_quarkus-langchain4j-cost-listener]] [.property-path]##link:#quarkus-langchain4j-core_quarkus-langchain4j-cost-listener[`quarkus.langchain4j.cost-listener`]##

[.description]
--
Configuration property to enable or disable generic cost listener


ifdef::add-copy-button-to-env-var[]
Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_COST_LISTENER+++[]
endif::add-copy-button-to-env-var[]
ifndef::add-copy-button-to-env-var[]
Environment variable: `+++QUARKUS_LANGCHAIN4J_COST_LISTENER+++`
endif::add-copy-button-to-env-var[]
--
|boolean
|`false`

a| [[quarkus-langchain4j-core_quarkus-langchain4j-chat-memory-memory-window-max-messages]] [.property-path]##link:#quarkus-langchain4j-core_quarkus-langchain4j-chat-memory-memory-window-max-messages[`quarkus.langchain4j.chat-memory.memory-window.max-messages`]##

[.description]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,23 @@ endif::add-copy-button-to-env-var[]
|boolean
|`true`

a|icon:lock[title=Fixed at build time] [[quarkus-langchain4j-core_quarkus-langchain4j-cost-listener]] [.property-path]##link:#quarkus-langchain4j-core_quarkus-langchain4j-cost-listener[`quarkus.langchain4j.cost-listener`]##

[.description]
--
Configuration property to enable or disable generic cost listener


ifdef::add-copy-button-to-env-var[]
Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_COST_LISTENER+++[]
endif::add-copy-button-to-env-var[]
ifndef::add-copy-button-to-env-var[]
Environment variable: `+++QUARKUS_LANGCHAIN4J_COST_LISTENER+++`
endif::add-copy-button-to-env-var[]
--
|boolean
|`false`

a| [[quarkus-langchain4j-core_quarkus-langchain4j-chat-memory-memory-window-max-messages]] [.property-path]##link:#quarkus-langchain4j-core_quarkus-langchain4j-chat-memory-memory-window-max-messages[`quarkus.langchain4j.chat-memory.memory-window.max-messages`]##

[.description]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ public class MultipleChatProvidersTest {

@Test
void defaultModel() {
assertThat(ClientProxy.unwrap(defaultModel)).isInstanceOf(OpenAiChatModel.class);
assertThat(SubclassUtil.unwrap(defaultModel)).isInstanceOf(OpenAiChatModel.class);
}

@Test
void firstNamedModel() {
assertThat(ClientProxy.unwrap(firstNamedModel)).isInstanceOf(OpenAiChatModel.class);
assertThat(SubclassUtil.unwrap(firstNamedModel)).isInstanceOf(OpenAiChatModel.class);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public class MultipleEmbeddingModelsTest {

@Test
void firstNamedModel() {
assertThat(ClientProxy.unwrap(firstNamedModel)).isInstanceOf(OpenAiEmbeddingModel.class);
assertThat(SubclassUtil.unwrap(firstNamedModel)).isInstanceOf(OpenAiEmbeddingModel.class);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import dev.langchain4j.model.moderation.ModerationModel;
import dev.langchain4j.model.openai.OpenAiModerationModel;
import io.quarkiverse.langchain4j.ModelName;
import io.quarkus.arc.ClientProxy;
import io.quarkus.test.junit.QuarkusTest;

@QuarkusTest
Expand All @@ -30,12 +29,12 @@ public class MultipleModerationProvidersTest {

@Test
void defaultModel() {
assertThat(ClientProxy.unwrap(defaultModel)).isInstanceOf(OpenAiModerationModel.class);
assertThat(SubclassUtil.unwrap(defaultModel)).isInstanceOf(OpenAiModerationModel.class);
}

@Test
void firstNamedModel() {
assertThat(ClientProxy.unwrap(firstNamedModel)).isInstanceOf(OpenAiModerationModel.class);
assertThat(SubclassUtil.unwrap(firstNamedModel)).isInstanceOf(OpenAiModerationModel.class);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package org.acme.example.multiple;

import java.lang.reflect.Field;

import io.quarkus.arc.ClientProxy;
import io.quarkus.arc.Subclass;

/**
 * Test helper that strips the Arc wrappers from a bean reference.
 */
public class SubclassUtil {

    /**
     * Unwraps the client proxy of {@code target} and, if the result is an Arc
     * {@link Subclass}, returns the value of its {@code delegate} field instead.
     *
     * @param target the bean reference to unwrap
     * @return the innermost instance found
     * @throws RuntimeException if the {@code delegate} field cannot be read
     */
    public static <T> T unwrap(T target) {
        T unwrapped = ClientProxy.unwrap(target);
        if (!(unwrapped instanceof Subclass)) {
            return unwrapped;
        }
        try {
            // The delegate is not exposed through any API, hence the reflective field access.
            Field delegateField = unwrapped.getClass().getDeclaredField("delegate");
            delegateField.setAccessible(true);
            return (T) delegateField.get(unwrapped);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

}
Loading

0 comments on commit 1ca082d

Please sign in to comment.