Skip to content

Commit 0d75ed7

Browse files
google-genai-botcopybara-github
authored andcommitted
refactor: Fixing EventAction's artifaceDelta type
This type now matches with Vertex AI Session API and the python implementation. The new implementation unblocks rewind functionality. PiperOrigin-RevId: 868327356
1 parent 968a9a8 commit 0d75ed7

File tree

7 files changed

+25
-33
lines changed

7 files changed

+25
-33
lines changed

core/src/main/java/com/google/adk/agents/CallbackContext.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ public Completable saveArtifact(String filename, Part artifact) {
134134
invocationContext.session().id(),
135135
filename,
136136
artifact)
137-
.doOnSuccess(unusedVersion -> this.eventActions.artifactDelta().put(filename, artifact))
137+
.doOnSuccess(version -> this.eventActions.artifactDelta().put(filename, version))
138138
.ignoreElement();
139139
}
140140
}

core/src/main/java/com/google/adk/events/EventActions.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import com.google.adk.agents.BaseAgentState;
2323
import com.google.adk.sessions.State;
2424
import com.google.errorprone.annotations.CanIgnoreReturnValue;
25-
import com.google.genai.types.Part;
2625
import java.util.HashSet;
2726
import java.util.Objects;
2827
import java.util.Optional;
@@ -38,7 +37,7 @@ public class EventActions extends JsonBaseModel {
3837

3938
private Optional<Boolean> skipSummarization;
4039
private ConcurrentMap<String, Object> stateDelta;
41-
private ConcurrentMap<String, Part> artifactDelta;
40+
private ConcurrentMap<String, Integer> artifactDelta;
4241
private Set<String> deletedArtifactIds;
4342
private Optional<String> transferToAgent;
4443
private Optional<Boolean> escalate;
@@ -120,11 +119,11 @@ public void removeStateByKey(String key) {
120119
}
121120

122121
@JsonProperty("artifactDelta")
123-
public ConcurrentMap<String, Part> artifactDelta() {
122+
public ConcurrentMap<String, Integer> artifactDelta() {
124123
return artifactDelta;
125124
}
126125

127-
public void setArtifactDelta(ConcurrentMap<String, Part> artifactDelta) {
126+
public void setArtifactDelta(ConcurrentMap<String, Integer> artifactDelta) {
128127
this.artifactDelta = artifactDelta;
129128
}
130129

@@ -288,7 +287,7 @@ public int hashCode() {
288287
public static class Builder {
289288
private Optional<Boolean> skipSummarization;
290289
private ConcurrentMap<String, Object> stateDelta;
291-
private ConcurrentMap<String, Part> artifactDelta;
290+
private ConcurrentMap<String, Integer> artifactDelta;
292291
private Set<String> deletedArtifactIds;
293292
private Optional<String> transferToAgent;
294293
private Optional<Boolean> escalate;
@@ -348,7 +347,7 @@ public Builder stateDelta(ConcurrentMap<String, Object> value) {
348347

349348
@CanIgnoreReturnValue
350349
@JsonProperty("artifactDelta")
351-
public Builder artifactDelta(ConcurrentMap<String, Part> value) {
350+
public Builder artifactDelta(ConcurrentMap<String, Integer> value) {
352351
this.artifactDelta = value;
353352
return this;
354353
}

core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -448,11 +448,9 @@ private static Single<Event> postProcessCodeExecutionResult(
448448
.toList()
449449
.map(
450450
versions -> {
451-
ConcurrentMap<String, Part> artifactDelta = new ConcurrentHashMap<>();
451+
ConcurrentMap<String, Integer> artifactDelta = new ConcurrentHashMap<>();
452452
for (int i = 0; i < versions.size(); i++) {
453-
artifactDelta.put(
454-
codeExecutionResult.outputFiles().get(i).name(),
455-
Part.fromText(String.valueOf(versions.get(i))));
453+
artifactDelta.put(codeExecutionResult.outputFiles().get(i).name(), versions.get(i));
456454
}
457455
eventActionsBuilder.artifactDelta(artifactDelta);
458456
return Event.builder()

core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -298,18 +298,19 @@ private static Instant convertToInstant(Object timestampObj) {
298298
* @return A {@link ConcurrentMap} representing the artifact delta.
299299
*/
300300
@SuppressWarnings("unchecked")
301-
private static ConcurrentMap<String, Part> convertToArtifactDeltaMap(Object artifactDeltaObj) {
301+
private static ConcurrentMap<String, Integer> convertToArtifactDeltaMap(Object artifactDeltaObj) {
302302
if (!(artifactDeltaObj instanceof Map)) {
303303
return new ConcurrentHashMap<>();
304304
}
305-
ConcurrentMap<String, Part> artifactDeltaMap = new ConcurrentHashMap<>();
306-
Map<String, Map<String, Object>> rawMap = (Map<String, Map<String, Object>>) artifactDeltaObj;
307-
for (Map.Entry<String, Map<String, Object>> entry : rawMap.entrySet()) {
305+
ConcurrentMap<String, Integer> artifactDeltaMap = new ConcurrentHashMap<>();
306+
Map<String, Object> rawMap = (Map<String, Object>) artifactDeltaObj;
307+
for (Map.Entry<String, Object> entry : rawMap.entrySet()) {
308308
try {
309-
Part part = objectMapper.convertValue(entry.getValue(), Part.class);
310-
artifactDeltaMap.put(entry.getKey(), part);
309+
Integer value = objectMapper.convertValue(entry.getValue(), Integer.class);
310+
artifactDeltaMap.put(entry.getKey(), value);
311311
} catch (IllegalArgumentException e) {
312-
logger.warn("Error converting artifactDelta value to Part for key: {}", entry.getKey(), e);
312+
logger.warn(
313+
"Error converting artifactDelta value to Integer for key: {}", entry.getKey(), e);
313314
}
314315
}
315316
return artifactDeltaMap;

core/src/test/java/com/google/adk/events/EventActionsTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ public void merge_mergesAllFields() {
6363
EventActions.builder()
6464
.skipSummarization(true)
6565
.stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("key1", "value1")))
66-
.artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact1", PART)))
66+
.artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact1", 1)))
6767
.deletedArtifactIds(ImmutableSet.of("deleted1"))
6868
.requestedAuthConfigs(
6969
new ConcurrentHashMap<>(
@@ -75,7 +75,7 @@ public void merge_mergesAllFields() {
7575
EventActions eventActions2 =
7676
EventActions.builder()
7777
.stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("key2", "value2")))
78-
.artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact2", PART)))
78+
.artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact2", 2)))
7979
.deletedArtifactIds(ImmutableSet.of("deleted2"))
8080
.transferToAgent("agentId")
8181
.escalate(true)
@@ -91,7 +91,7 @@ public void merge_mergesAllFields() {
9191

9292
assertThat(merged.skipSummarization()).hasValue(true);
9393
assertThat(merged.stateDelta()).containsExactly("key1", "value1", "key2", "value2");
94-
assertThat(merged.artifactDelta()).containsExactly("artifact1", PART, "artifact2", PART);
94+
assertThat(merged.artifactDelta()).containsExactly("artifact1", 1, "artifact2", 2);
9595
assertThat(merged.deletedArtifactIds()).containsExactly("deleted1", "deleted2");
9696
assertThat(merged.transferToAgent()).hasValue("agentId");
9797
assertThat(merged.escalate()).hasValue(true);

core/src/test/java/com/google/adk/events/EventTest.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,7 @@ public final class EventTest {
4545
EventActions.builder()
4646
.skipSummarization(true)
4747
.stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("key", "value")))
48-
.artifactDelta(
49-
new ConcurrentHashMap<>(
50-
ImmutableMap.of("artifact_key", Part.builder().text("artifact_value").build())))
48+
.artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact_key", 1)))
5149
.transferToAgent("agent_id")
5250
.escalate(true)
5351
.requestedAuthConfigs(

core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,7 @@ public void convertEventToJson_fullEvent_success() throws JsonProcessingExceptio
3939
EventActions.builder()
4040
.skipSummarization(true)
4141
.stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("key", "value")))
42-
.artifactDelta(
43-
new ConcurrentHashMap<>(
44-
ImmutableMap.of("artifact", Part.fromText("artifact_text"))))
42+
.artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact", 1)))
4543
.transferToAgent("agent")
4644
.escalate(true)
4745
.build();
@@ -80,8 +78,7 @@ public void convertEventToJson_fullEvent_success() throws JsonProcessingExceptio
8078
JsonNode actionsNode = jsonNode.get("actions");
8179
assertThat(actionsNode.get("skipSummarization").asBoolean()).isTrue();
8280
assertThat(actionsNode.get("stateDelta").get("key").asText()).isEqualTo("value");
83-
assertThat(actionsNode.get("artifactDelta").get("artifact").get("text").asText())
84-
.isEqualTo("artifact_text");
81+
assertThat(actionsNode.get("artifactDelta").get("artifact").asInt()).isEqualTo(1);
8582
assertThat(actionsNode.get("transferAgent").asText()).isEqualTo("agent");
8683
assertThat(actionsNode.get("escalate").asBoolean()).isTrue();
8784
}
@@ -131,8 +128,7 @@ public void fromApiEvent_fullEvent_success() {
131128
Map<String, Object> actions = new HashMap<>();
132129
actions.put("skipSummarization", true);
133130
actions.put("stateDelta", ImmutableMap.of("key", "value"));
134-
actions.put(
135-
"artifactDelta", ImmutableMap.of("artifact", ImmutableMap.of("text", "artifact_text")));
131+
actions.put("artifactDelta", ImmutableMap.of("artifact", 1));
136132
actions.put("transferAgent", "agent");
137133
actions.put("escalate", true);
138134
apiEvent.put("actions", actions);
@@ -154,7 +150,7 @@ public void fromApiEvent_fullEvent_success() {
154150
EventActions eventActions = event.actions();
155151
assertThat(eventActions.skipSummarization()).hasValue(true);
156152
assertThat(eventActions.stateDelta()).containsEntry("key", "value");
157-
assertThat(eventActions.artifactDelta().get("artifact").text()).hasValue("artifact_text");
153+
assertThat(eventActions.artifactDelta().get("artifact")).isEqualTo(1);
158154
assertThat(eventActions.transferToAgent()).hasValue("agent");
159155
assertThat(eventActions.escalate()).hasValue(true);
160156
}
@@ -383,7 +379,7 @@ public void fromApiEvent_withInvalidArtifactDelta_skipsInvalidEntries() {
383379
apiEvent.put("timestamp", "2023-01-01T00:00:00Z");
384380

385381
Map<String, Object> artifactDelta = new HashMap<>();
386-
artifactDelta.put("valid", ImmutableMap.of("text", "valid_text"));
382+
artifactDelta.put("valid", 1);
387383
artifactDelta.put("invalid", "not-a-map");
388384

389385
Map<String, Object> actions = new HashMap<>();

0 commit comments

Comments
 (0)