diff --git a/scio-core/src/main/java/com/spotify/scio/transforms/BaseAsyncBatchLookupDoFn.java b/scio-core/src/main/java/com/spotify/scio/transforms/BaseAsyncBatchLookupDoFn.java
index 9326321488..d7c9e8ed08 100644
--- a/scio-core/src/main/java/com/spotify/scio/transforms/BaseAsyncBatchLookupDoFn.java
+++ b/scio-core/src/main/java/com/spotify/scio/transforms/BaseAsyncBatchLookupDoFn.java
@@ -17,6 +17,7 @@
package com.spotify.scio.transforms;
import static java.util.Objects.requireNonNull;
+import static java.util.function.Function.identity;
import com.google.common.cache.Cache;
import com.spotify.scio.transforms.BaseAsyncLookupDoFn.CacheSupplier;
@@ -26,6 +27,7 @@
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
+import java.util.Map;
import java.util.Queue;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
@@ -274,48 +276,53 @@ private void createRequest() throws InterruptedException {
}
private FutureType handleOutput(FutureType future, List batchInput, UUID key) {
+ final Map keyedInputs =
+ batchInput.stream().collect(Collectors.toMap(idExtractorFn::apply, identity()));
return addCallback(
future,
response -> {
- batchResponseFn
- .apply(response)
- .forEach(
- pair -> {
- final String id = pair.getLeft();
- final Output output = pair.getRight();
- final List> processInputs = inputs.remove(id);
- if (processInputs == null) {
- // no need to fail future here as we're only interested in its completion
- // finishBundle will fail the checkState as we do not produce any result
- LOG.error(
- "The ID '{}' received in the gRPC batch response does not "
- + "match any IDs extracted via the idExtractorFn for the requested "
- + "batch sent to the gRPC endpoint. Please ensure that the IDs returned "
- + "from the gRPC endpoints match the IDs extracted using the provided"
- + "idExtractorFn for the same input.",
- id);
- } else {
- final List>> batchResult =
- processInputs.stream()
- .map(
- processInput -> {
- final Input i = processInput.getValue();
- final TryWrapper o = success(output);
- final Instant ts = processInput.getTimestamp();
- final BoundedWindow w = processInput.getWindow();
- final PaneInfo p = processInput.getPane();
- return ValueInSingleWindow.of(KV.of(i, o), ts, w, p);
- })
- .collect(Collectors.toList());
- results.add(Pair.of(key, batchResult));
- }
- });
+ final Map keyedOutput =
+ batchResponseFn.apply(response).stream()
+ .collect(Collectors.toMap(Pair::getKey, Pair::getValue));
+
+ keyedInputs.forEach(
+ (id, input) -> {
+ final List> processInputs = inputs.remove(id);
+ if (processInputs == null) {
+ // no need to fail future here as we're only interested in its completion
+ // finishBundle will fail the checkState as we do not produce any result
+ LOG.error(
+ "The ID '{}' received in the gRPC batch response does not "
+ + "match any IDs extracted via the idExtractorFn for the requested "
+ + "batch sent to the gRPC endpoint. Please ensure that the IDs returned "
+ + "from the gRPC endpoints match the IDs extracted using the provided"
+ + "idExtractorFn for the same input.",
+ id);
+ } else {
+ List>> batchResult =
+ processInputs.stream()
+ .map(
+ processInput -> {
+ final Input i = processInput.getValue();
+ final Output output = keyedOutput.get(id);
+ final TryWrapper o =
+ output == null
+ ? failure(new UnmatchedRequestException(id))
+ : success(output);
+ final Instant ts = processInput.getTimestamp();
+ final BoundedWindow w = processInput.getWindow();
+ final PaneInfo p = processInput.getPane();
+ return ValueInSingleWindow.of(KV.of(i, o), ts, w, p);
+ })
+ .collect(Collectors.toList());
+ results.add(Pair.of(key, batchResult));
+ }
+ });
return null;
},
throwable -> {
- batchInput.forEach(
- element -> {
- final String id = idExtractorFn.apply(element);
+ keyedInputs.forEach(
+ (id, element) -> {
final List>> batchResult =
inputs.remove(id).stream()
.map(
diff --git a/scio-core/src/main/java/com/spotify/scio/transforms/UnmatchedRequestException.java b/scio-core/src/main/java/com/spotify/scio/transforms/UnmatchedRequestException.java
new file mode 100644
index 0000000000..e6025f6b7f
--- /dev/null
+++ b/scio-core/src/main/java/com/spotify/scio/transforms/UnmatchedRequestException.java
@@ -0,0 +1,45 @@
+/*
+ * Copyright 2024 Spotify AB
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.spotify.scio.transforms;
+
+import java.util.Objects;
+
+public class UnmatchedRequestException extends RuntimeException {
+
+ private final String id;
+
+ public UnmatchedRequestException(String id) {
+ super("Unmatched batch request for ID: " + id);
+ this.id = id;
+ }
+
+ public String getId() {
+ return id;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (o == null || getClass() != o.getClass()) return false;
+ UnmatchedRequestException that = (UnmatchedRequestException) o;
+ return Objects.equals(id, that.id);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(id);
+ }
+}
diff --git a/scio-core/src/test/scala/com/spotify/scio/transforms/AsyncBatchLookupDoFnTest.scala b/scio-core/src/test/scala/com/spotify/scio/transforms/AsyncBatchLookupDoFnTest.scala
index a620944a4e..2844f88d86 100644
--- a/scio-core/src/test/scala/com/spotify/scio/transforms/AsyncBatchLookupDoFnTest.scala
+++ b/scio-core/src/test/scala/com/spotify/scio/transforms/AsyncBatchLookupDoFnTest.scala
@@ -61,7 +61,12 @@ class AsyncBatchLookupDoFnTest extends PipelineSpec {
doFn: BaseAsyncBatchLookupDoFn[Int, List[Int], List[String], String, AsyncBatchClient, F, T]
)(tryFn: T => Try[String]): Unit = {
// batches of size 4 and size 3
- val output = runWithData(Seq[Seq[Int]](1 to 4, 8 to 10))(_.flatten.parDo(doFn)).map { kv =>
+ val output = runWithData(
+ Seq[Seq[Int]](
+ 1 to 4, // 1 and 3 are unmatched
+ 8 to 10 // failure
+ )
+ )(_.flatten.parDo(doFn)).map { kv =>
val r = tryFn(kv.getValue) match {
case Success(v) => v
case Failure(e: CompletionException) => e.getCause.getMessage
@@ -70,8 +75,9 @@ class AsyncBatchLookupDoFnTest extends PipelineSpec {
(kv.getKey, r)
}
output should contain theSameElementsAs (
- (1 to 4).map(x => x -> x.toString) ++
- (8 to 10).map(x => x -> "failure for 8,9,10")
+ Seq(1, 3).map(x => x -> s"Unmatched batch request for ID: $x") ++
+ Seq(2, 4).map(x => x -> x.toString) ++
+ Seq(8, 9, 10).map(x => x -> "failure for 8,9,10")
)
}
@@ -229,7 +235,7 @@ class FailingGuavaBatchLookupDoFn extends AbstractGuavaAsyncBatchLookupDoFn() {
input: List[Int]
): ListenableFuture[List[String]] =
if (input.size % 2 == 0) {
- Futures.immediateFuture(input.map(_.toString))
+ Futures.immediateFuture(input.filter(_ % 2 == 0).map(_.toString))
} else {
Futures.immediateFailedFuture(new RuntimeException("failure for " + input.mkString(",")))
}
@@ -299,7 +305,7 @@ class FailingJavaBatchLookupDoFn extends AbstractJavaAsyncBatchLookupDoFn() {
input: List[Int]
): CompletableFuture[List[String]] =
if (input.size % 2 == 0) {
- CompletableFuture.supplyAsync(() => input.map(_.toString))
+ CompletableFuture.supplyAsync(() => input.filter(_ % 2 == 0).map(_.toString))
} else {
val f = new CompletableFuture[List[String]]()
f.completeExceptionally(new RuntimeException("failure for " + input.mkString(",")))
@@ -347,7 +353,7 @@ class FailingScalaBatchLookupDoFn extends AbstractScalaAsyncBatchLookupDoFn() {
override protected def newClient(): AsyncBatchClient = null
override def asyncLookup(session: AsyncBatchClient, input: List[Int]): Future[List[String]] =
if (input.size % 2 == 0) {
- Future.successful(input.map(_.toString))
+ Future.successful(input.filter(_ % 2 == 0).map(_.toString))
} else {
Future.failed(new RuntimeException("failure for " + input.mkString(",")))
}
diff --git a/scio-grpc/src/test/scala/com/spotify/scio/grpc/GrpcBatchDoFnTest.scala b/scio-grpc/src/test/scala/com/spotify/scio/grpc/GrpcBatchDoFnTest.scala
index b97c059ede..9079159589 100644
--- a/scio-grpc/src/test/scala/com/spotify/scio/grpc/GrpcBatchDoFnTest.scala
+++ b/scio-grpc/src/test/scala/com/spotify/scio/grpc/GrpcBatchDoFnTest.scala
@@ -21,16 +21,16 @@ import com.spotify.concat.v1.ConcatServiceGrpc.{ConcatServiceFutureStub, ConcatS
import com.spotify.concat.v1._
import com.spotify.scio.testing.PipelineSpec
import com.spotify.scio.transforms.BaseAsyncLookupDoFn.CacheSupplier
+import com.spotify.scio.transforms.UnmatchedRequestException
import io.grpc.netty.NettyChannelBuilder
import io.grpc.stub.StreamObserver
import io.grpc.{Server, ServerBuilder}
-import org.apache.beam.sdk.Pipeline.PipelineExecutionException
import org.scalatest.BeforeAndAfterAll
import java.net.ServerSocket
import java.util.stream.Collectors
import scala.jdk.CollectionConverters._
-import scala.util.{Success, Try}
+import scala.util.{Failure, Success, Try}
object GrpcBatchDoFnTest {
@@ -230,7 +230,7 @@ class GrpcBatchDoFnTest extends PipelineSpec with BeforeAndAfterAll {
}
}
- it should "throw an IllegalStateException if gRPC response contains unknown ids" in {
+ it should "fail unmatched inputs" in {
val input = (0 to 1).map { i =>
ConcatRequestWithID
.newBuilder()
@@ -240,30 +240,29 @@ class GrpcBatchDoFnTest extends PipelineSpec with BeforeAndAfterAll {
.build()
}
- assertThrows[IllegalStateException] {
- try {
- runWithContext { sc =>
- sc.parallelize(input)
- .grpcBatchLookup[
- BatchRequest,
- BatchResponse,
- ConcatResponseWithID,
- ConcatServiceFutureStub
- ](
- () => NettyChannelBuilder.forTarget(ServiceUri).usePlaintext().build(),
- ConcatServiceGrpc.newFutureStub,
- 2,
- concatBatchRequest,
- r => r.getResponseList.asScala.toSeq.map(e => ("WrongID-" + e.getRequestId, e)),
- idExtractor,
- 2
- )(_.batchConcat)
- }
- } catch {
- case e: PipelineExecutionException =>
- e.getMessage should include("Expected requestCount == responseCount")
- throw e.getCause
- }
+ val expected: Seq[(ConcatRequestWithID, Try[ConcatResponseWithID])] = input.map { req =>
+ req -> Failure(new UnmatchedRequestException(req.getRequestId))
+ }
+
+ runWithContext { sc =>
+ val result = sc
+ .parallelize(input)
+ .grpcBatchLookup[
+ BatchRequest,
+ BatchResponse,
+ ConcatResponseWithID,
+ ConcatServiceFutureStub
+ ](
+ () => NettyChannelBuilder.forTarget(ServiceUri).usePlaintext().build(),
+ ConcatServiceGrpc.newFutureStub,
+ 2,
+ concatBatchRequest,
+ r => r.getResponseList.asScala.toSeq.map(e => ("WrongID-" + e.getRequestId, e)),
+ idExtractor,
+ 2
+ )(_.batchConcat)
+
+ result should containInAnyOrder(expected)
}
}
}