diff --git a/scio-core/src/main/java/com/spotify/scio/transforms/BaseAsyncBatchLookupDoFn.java b/scio-core/src/main/java/com/spotify/scio/transforms/BaseAsyncBatchLookupDoFn.java index 9326321488..d7c9e8ed08 100644 --- a/scio-core/src/main/java/com/spotify/scio/transforms/BaseAsyncBatchLookupDoFn.java +++ b/scio-core/src/main/java/com/spotify/scio/transforms/BaseAsyncBatchLookupDoFn.java @@ -17,6 +17,7 @@ package com.spotify.scio.transforms; import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; import com.google.common.cache.Cache; import com.spotify.scio.transforms.BaseAsyncLookupDoFn.CacheSupplier; @@ -26,6 +27,7 @@ import java.util.Collections; import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.Queue; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; @@ -274,48 +276,53 @@ private void createRequest() throws InterruptedException { } private FutureType handleOutput(FutureType future, List batchInput, UUID key) { + final Map keyedInputs = + batchInput.stream().collect(Collectors.toMap(idExtractorFn::apply, identity())); return addCallback( future, response -> { - batchResponseFn - .apply(response) - .forEach( - pair -> { - final String id = pair.getLeft(); - final Output output = pair.getRight(); - final List> processInputs = inputs.remove(id); - if (processInputs == null) { - // no need to fail future here as we're only interested in its completion - // finishBundle will fail the checkState as we do not produce any result - LOG.error( - "The ID '{}' received in the gRPC batch response does not " - + "match any IDs extracted via the idExtractorFn for the requested " - + "batch sent to the gRPC endpoint. Please ensure that the IDs returned " - + "from the gRPC endpoints match the IDs extracted using the provided" - + "idExtractorFn for the same input.", - id); - } else { - final List>> batchResult = - processInputs.stream() - .map( - processInput -> { - final Input i = processInput.getValue(); - final TryWrapper o = success(output); - final Instant ts = processInput.getTimestamp(); - final BoundedWindow w = processInput.getWindow(); - final PaneInfo p = processInput.getPane(); - return ValueInSingleWindow.of(KV.of(i, o), ts, w, p); - }) - .collect(Collectors.toList()); - results.add(Pair.of(key, batchResult)); - } - }); + final Map keyedOutput = + batchResponseFn.apply(response).stream() + .collect(Collectors.toMap(Pair::getKey, Pair::getValue)); + + keyedInputs.forEach( + (id, input) -> { + final List> processInputs = inputs.remove(id); + if (processInputs == null) { + // no need to fail future here as we're only interested in its completion + // finishBundle will fail the checkState as we do not produce any result + LOG.error( + "The ID '{}' received in the gRPC batch response does not " + + "match any IDs extracted via the idExtractorFn for the requested " + + "batch sent to the gRPC endpoint. Please ensure that the IDs returned " + + "from the gRPC endpoints match the IDs extracted using the provided" + + "idExtractorFn for the same input.", + id); + } else { + List>> batchResult = + processInputs.stream() + .map( + processInput -> { + final Input i = processInput.getValue(); + final Output output = keyedOutput.get(id); + final TryWrapper o = + output == null + ? failure(new UnmatchedRequestException(id)) + : success(output); + final Instant ts = processInput.getTimestamp(); + final BoundedWindow w = processInput.getWindow(); + final PaneInfo p = processInput.getPane(); + return ValueInSingleWindow.of(KV.of(i, o), ts, w, p); + }) + .collect(Collectors.toList()); + results.add(Pair.of(key, batchResult)); + } + }); return null; }, throwable -> { - batchInput.forEach( - element -> { - final String id = idExtractorFn.apply(element); + keyedInputs.forEach( + (id, element) -> { final List>> batchResult = inputs.remove(id).stream() .map( diff --git a/scio-core/src/main/java/com/spotify/scio/transforms/UnmatchedRequestException.java b/scio-core/src/main/java/com/spotify/scio/transforms/UnmatchedRequestException.java new file mode 100644 index 0000000000..e6025f6b7f --- /dev/null +++ b/scio-core/src/main/java/com/spotify/scio/transforms/UnmatchedRequestException.java @@ -0,0 +1,45 @@ +/* + * Copyright 2024 Spotify AB + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.spotify.scio.transforms; + +import java.util.Objects; + +public class UnmatchedRequestException extends RuntimeException { + + private final String id; + + public UnmatchedRequestException(String id) { + super("Unmatched batch request for ID: " + id); + this.id = id; + } + + public String getId() { + return id; + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + UnmatchedRequestException that = (UnmatchedRequestException) o; + return Objects.equals(id, that.id); + } + + @Override + public int hashCode() { + return Objects.hashCode(id); + } +} diff --git a/scio-core/src/test/scala/com/spotify/scio/transforms/AsyncBatchLookupDoFnTest.scala b/scio-core/src/test/scala/com/spotify/scio/transforms/AsyncBatchLookupDoFnTest.scala index a620944a4e..2844f88d86 100644 --- a/scio-core/src/test/scala/com/spotify/scio/transforms/AsyncBatchLookupDoFnTest.scala +++ b/scio-core/src/test/scala/com/spotify/scio/transforms/AsyncBatchLookupDoFnTest.scala @@ -61,7 +61,12 @@ class AsyncBatchLookupDoFnTest extends PipelineSpec { doFn: BaseAsyncBatchLookupDoFn[Int, List[Int], List[String], String, AsyncBatchClient, F, T] )(tryFn: T => Try[String]): Unit = { // batches of size 4 and size 3 - val output = runWithData(Seq[Seq[Int]](1 to 4, 8 to 10))(_.flatten.parDo(doFn)).map { kv => + val output = runWithData( + Seq[Seq[Int]]( + 1 to 4, // 1 and 3 are unmatched + 8 to 10 // failure + ) + )(_.flatten.parDo(doFn)).map { kv => val r = tryFn(kv.getValue) match { case Success(v) => v case Failure(e: CompletionException) => e.getCause.getMessage @@ -70,8 +75,9 @@ class AsyncBatchLookupDoFnTest extends PipelineSpec { (kv.getKey, r) } output should contain theSameElementsAs ( - (1 to 4).map(x => x -> x.toString) ++ - (8 to 10).map(x => x -> "failure for 8,9,10") + Seq(1, 3).map(x => x -> s"Unmatched batch request for ID: $x") ++ + Seq(2, 4).map(x => x -> x.toString) ++ + Seq(8, 9, 10).map(x => x -> "failure for 8,9,10") ) } @@ -229,7 +235,7 @@ class FailingGuavaBatchLookupDoFn extends AbstractGuavaAsyncBatchLookupDoFn() { input: List[Int] ): ListenableFuture[List[String]] = if (input.size % 2 == 0) { - Futures.immediateFuture(input.map(_.toString)) + Futures.immediateFuture(input.filter(_ % 2 == 0).map(_.toString)) } else { Futures.immediateFailedFuture(new RuntimeException("failure for " + input.mkString(","))) } @@ -299,7 +305,7 @@ class FailingJavaBatchLookupDoFn extends AbstractJavaAsyncBatchLookupDoFn() { input: List[Int] ): CompletableFuture[List[String]] = if (input.size % 2 == 0) { - CompletableFuture.supplyAsync(() => input.map(_.toString)) + CompletableFuture.supplyAsync(() => input.filter(_ % 2 == 0).map(_.toString)) } else { val f = new CompletableFuture[List[String]]() f.completeExceptionally(new RuntimeException("failure for " + input.mkString(","))) @@ -347,7 +353,7 @@ class FailingScalaBatchLookupDoFn extends AbstractScalaAsyncBatchLookupDoFn() { override protected def newClient(): AsyncBatchClient = null override def asyncLookup(session: AsyncBatchClient, input: List[Int]): Future[List[String]] = if (input.size % 2 == 0) { - Future.successful(input.map(_.toString)) + Future.successful(input.filter(_ % 2 == 0).map(_.toString)) } else { Future.failed(new RuntimeException("failure for " + input.mkString(","))) } diff --git a/scio-grpc/src/test/scala/com/spotify/scio/grpc/GrpcBatchDoFnTest.scala b/scio-grpc/src/test/scala/com/spotify/scio/grpc/GrpcBatchDoFnTest.scala index b97c059ede..9079159589 100644 --- a/scio-grpc/src/test/scala/com/spotify/scio/grpc/GrpcBatchDoFnTest.scala +++ b/scio-grpc/src/test/scala/com/spotify/scio/grpc/GrpcBatchDoFnTest.scala @@ -21,16 +21,16 @@ import com.spotify.concat.v1.ConcatServiceGrpc.{ConcatServiceFutureStub, ConcatS import com.spotify.concat.v1._ import com.spotify.scio.testing.PipelineSpec import com.spotify.scio.transforms.BaseAsyncLookupDoFn.CacheSupplier +import com.spotify.scio.transforms.UnmatchedRequestException import io.grpc.netty.NettyChannelBuilder import io.grpc.stub.StreamObserver import io.grpc.{Server, ServerBuilder} -import org.apache.beam.sdk.Pipeline.PipelineExecutionException import org.scalatest.BeforeAndAfterAll import java.net.ServerSocket import java.util.stream.Collectors import scala.jdk.CollectionConverters._ -import scala.util.{Success, Try} +import scala.util.{Failure, Success, Try} object GrpcBatchDoFnTest { @@ -230,7 +230,7 @@ class GrpcBatchDoFnTest extends PipelineSpec with BeforeAndAfterAll { } } - it should "throw an IllegalStateException if gRPC response contains unknown ids" in { + it should "fail unmatched inputs" in { val input = (0 to 1).map { i => ConcatRequestWithID .newBuilder() @@ -240,30 +240,29 @@ class GrpcBatchDoFnTest extends PipelineSpec with BeforeAndAfterAll { .build() } - assertThrows[IllegalStateException] { - try { - runWithContext { sc => - sc.parallelize(input) - .grpcBatchLookup[ - BatchRequest, - BatchResponse, - ConcatResponseWithID, - ConcatServiceFutureStub - ]( - () => NettyChannelBuilder.forTarget(ServiceUri).usePlaintext().build(), - ConcatServiceGrpc.newFutureStub, - 2, - concatBatchRequest, - r => r.getResponseList.asScala.toSeq.map(e => ("WrongID-" + e.getRequestId, e)), - idExtractor, - 2 - )(_.batchConcat) - } - } catch { - case e: PipelineExecutionException => - e.getMessage should include("Expected requestCount == responseCount") - throw e.getCause - } + val expected: Seq[(ConcatRequestWithID, Try[ConcatResponseWithID])] = input.map { req => + req -> Failure(new UnmatchedRequestException(req.getRequestId)) + } + + runWithContext { sc => + val result = sc + .parallelize(input) + .grpcBatchLookup[ + BatchRequest, + BatchResponse, + ConcatResponseWithID, + ConcatServiceFutureStub + ]( + () => NettyChannelBuilder.forTarget(ServiceUri).usePlaintext().build(), + ConcatServiceGrpc.newFutureStub, + 2, + concatBatchRequest, + r => r.getResponseList.asScala.toSeq.map(e => ("WrongID-" + e.getRequestId, e)), + idExtractor, + 2 + )(_.batchConcat) + + result should containInAnyOrder(expected) } } }