Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 20 additions & 44 deletions core/src/main/java/com/google/adk/agents/BaseAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,10 @@ private InvocationContext createInvocationContext(InvocationContext parentContex
InvocationContext.Builder builder = parentContext.toBuilder();
builder.agent(this);
// Check for branch to be truthy (not None, not empty string),
if (parentContext.branch().filter(s -> !s.isEmpty()).isPresent()) {
builder.branch(parentContext.branch().get() + "." + name());
}
parentContext
.branch()
.filter(s -> !s.isEmpty())
.ifPresent(branch -> builder.branch(branch + "." + name()));
return builder.build();
}

Expand All @@ -301,6 +302,19 @@ private InvocationContext createInvocationContext(InvocationContext parentContex
* @return stream of agent-generated events.
*/
public Flowable<Event> runAsync(InvocationContext parentContext) {
return run(parentContext, this::runAsyncImpl);
}

/**
* Runs the agent with the given implementation.
*
* @param parentContext Parent context to inherit.
* @param runImplementation The agent-specific logic to run.
* @return stream of agent-generated events.
*/
private Flowable<Event> run(
InvocationContext parentContext,
Function<InvocationContext, Flowable<Event>> runImplementation) {
Tracer tracer = Tracing.getTracer();
return Flowable.defer(
() -> {
Expand All @@ -326,7 +340,7 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {

Flowable<Event> beforeEvents = Flowable.fromOptional(beforeEventOpt);
Flowable<Event> mainEvents =
Flowable.defer(() -> runAsyncImpl(invocationContext));
Flowable.defer(() -> runImplementation.apply(invocationContext));
Flowable<Event> afterEvents =
Flowable.defer(
() ->
Expand Down Expand Up @@ -382,7 +396,7 @@ private ImmutableList<Function<CallbackContext, Maybe<Content>>> afterCallbacksT
private Single<Optional<Event>> callCallback(
List<Function<CallbackContext, Maybe<Content>>> agentCallbacks,
InvocationContext invocationContext) {
if (agentCallbacks == null || agentCallbacks.isEmpty()) {
if (agentCallbacks.isEmpty()) {
return Single.just(Optional.empty());
}

Expand Down Expand Up @@ -437,45 +451,7 @@ private Single<Optional<Event>> callCallback(
* @return stream of agent-generated events.
*/
public Flowable<Event> runLive(InvocationContext parentContext) {
Tracer tracer = Tracing.getTracer();
return Flowable.defer(
() -> {
InvocationContext invocationContext = createInvocationContext(parentContext);
Span span =
tracer.spanBuilder("invoke_agent " + name()).setParent(Context.current()).startSpan();
Tracing.traceAgentInvocation(span, name(), description(), invocationContext);
Context spanContext = Context.current().with(span);

return Tracing.traceFlowable(
spanContext,
span,
() ->
callCallback(
beforeCallbacksToFunctions(
invocationContext.pluginManager(), beforeAgentCallback),
invocationContext)
.flatMapPublisher(
beforeEventOpt -> {
if (invocationContext.endInvocation()) {
return Flowable.fromOptional(beforeEventOpt);
}

Flowable<Event> beforeEvents = Flowable.fromOptional(beforeEventOpt);
Flowable<Event> mainEvents =
Flowable.defer(() -> runLiveImpl(invocationContext));
Flowable<Event> afterEvents =
Flowable.defer(
() ->
callCallback(
afterCallbacksToFunctions(
invocationContext.pluginManager(),
afterAgentCallback),
invocationContext)
.flatMapPublisher(Flowable::fromOptional));

return Flowable.concat(beforeEvents, mainEvents, afterEvents);
}));
});
return run(parentContext, this::runLiveImpl);
}

/**
Expand Down