diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index 20d7dfa4f..72fc5883a 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -288,9 +288,10 @@ private InvocationContext createInvocationContext(InvocationContext parentContex InvocationContext.Builder builder = parentContext.toBuilder(); builder.agent(this); // Check for branch to be truthy (not None, not empty string), - if (parentContext.branch().filter(s -> !s.isEmpty()).isPresent()) { - builder.branch(parentContext.branch().get() + "." + name()); - } + parentContext + .branch() + .filter(s -> !s.isEmpty()) + .ifPresent(branch -> builder.branch(branch + "." + name())); return builder.build(); } @@ -301,6 +302,19 @@ private InvocationContext createInvocationContext(InvocationContext parentContex * @return stream of agent-generated events. */ public Flowable runAsync(InvocationContext parentContext) { + return run(parentContext, this::runAsyncImpl); + } + + /** + * Runs the agent with the given implementation. + * + * @param parentContext Parent context to inherit. + * @param runImplementation The agent-specific logic to run. + * @return stream of agent-generated events. + */ + private Flowable run( + InvocationContext parentContext, + Function> runImplementation) { Tracer tracer = Tracing.getTracer(); return Flowable.defer( () -> { @@ -326,7 +340,7 @@ public Flowable runAsync(InvocationContext parentContext) { Flowable beforeEvents = Flowable.fromOptional(beforeEventOpt); Flowable mainEvents = - Flowable.defer(() -> runAsyncImpl(invocationContext)); + Flowable.defer(() -> runImplementation.apply(invocationContext)); Flowable afterEvents = Flowable.defer( () -> @@ -382,7 +396,7 @@ private ImmutableList>> afterCallbacksT private Single> callCallback( List>> agentCallbacks, InvocationContext invocationContext) { - if (agentCallbacks == null || agentCallbacks.isEmpty()) { + if (agentCallbacks.isEmpty()) { return Single.just(Optional.empty()); } @@ -437,45 +451,7 @@ private Single> callCallback( * @return stream of agent-generated events. */ public Flowable runLive(InvocationContext parentContext) { - Tracer tracer = Tracing.getTracer(); - return Flowable.defer( - () -> { - InvocationContext invocationContext = createInvocationContext(parentContext); - Span span = - tracer.spanBuilder("invoke_agent " + name()).setParent(Context.current()).startSpan(); - Tracing.traceAgentInvocation(span, name(), description(), invocationContext); - Context spanContext = Context.current().with(span); - - return Tracing.traceFlowable( - spanContext, - span, - () -> - callCallback( - beforeCallbacksToFunctions( - invocationContext.pluginManager(), beforeAgentCallback), - invocationContext) - .flatMapPublisher( - beforeEventOpt -> { - if (invocationContext.endInvocation()) { - return Flowable.fromOptional(beforeEventOpt); - } - - Flowable beforeEvents = Flowable.fromOptional(beforeEventOpt); - Flowable mainEvents = - Flowable.defer(() -> runLiveImpl(invocationContext)); - Flowable afterEvents = - Flowable.defer( - () -> - callCallback( - afterCallbacksToFunctions( - invocationContext.pluginManager(), - afterAgentCallback), - invocationContext) - .flatMapPublisher(Flowable::fromOptional)); - - return Flowable.concat(beforeEvents, mainEvents, afterEvents); - })); - }); + return run(parentContext, this::runLiveImpl); } /**