diff --git a/build.sbt b/build.sbt
index a9aeeab0d9..ec2e13db95 100644
--- a/build.sbt
+++ b/build.sbt
@@ -432,6 +432,10 @@ ThisBuild / mimaBinaryIssueFilters ++= Seq(
   // tf-metadata upgrade
   ProblemFilters.exclude[Problem](
     "org.tensorflow.metadata.v0.*"
+  ),
+  // relax type hierarchy for batch stream
+  ProblemFilters.exclude[IncompatibleMethTypeProblem](
+    "com.spotify.scio.grpc.GrpcBatchDoFn.asyncLookup"
   )
 )
diff --git a/scio-core/src/main/java/com/spotify/scio/transforms/BaseAsyncBatchLookupDoFn.java b/scio-core/src/main/java/com/spotify/scio/transforms/BaseAsyncBatchLookupDoFn.java
index 699b88b310..9ac82678c2 100644
--- a/scio-core/src/main/java/com/spotify/scio/transforms/BaseAsyncBatchLookupDoFn.java
+++ b/scio-core/src/main/java/com/spotify/scio/transforms/BaseAsyncBatchLookupDoFn.java
@@ -17,6 +17,7 @@
 package com.spotify.scio.transforms;
 
 import static java.util.Objects.requireNonNull;
+import static java.util.function.Function.identity;
 
 import com.google.common.cache.Cache;
 import com.spotify.scio.transforms.BaseAsyncLookupDoFn.CacheSupplier;
@@ -26,6 +27,7 @@
 import java.util.Collections;
 import java.util.LinkedList;
 import java.util.List;
+import java.util.Map;
 import java.util.Queue;
 import java.util.UUID;
 import java.util.concurrent.ConcurrentHashMap;
@@ -275,48 +277,53 @@ private void createRequest() throws InterruptedException {
   }
 
   private FutureType handleOutput(FutureType future, List<Input> batchInput, UUID key) {
+    final Map<String, Input> keyedInputs =
+        batchInput.stream().collect(Collectors.toMap(idExtractorFn::apply, identity()));
     return addCallback(
         future,
         response -> {
-          batchResponseFn
-              .apply(response)
-              .forEach(
-                  pair -> {
-                    final String id = pair.getLeft();
-                    final Output output = pair.getRight();
-                    final List<ValueInSingleWindow<Input>> processInputs = inputs.remove(id);
-                    if (processInputs == null) {
-                      // no need to fail future here as we're only interested in its completion
-                      // finishBundle will fail the checkState as we do not produce any result
-                      LOG.error(
-                          "The ID '{}' received in the gRPC batch response does not "
-                              + "match any IDs extracted via the idExtractorFn for the requested "
-                              + "batch sent to the gRPC endpoint. Please ensure that the IDs returned "
-                              + "from the gRPC endpoints match the IDs extracted using the provided"
-                              + "idExtractorFn for the same input.",
-                          id);
-                    } else {
-                      final List<ValueInSingleWindow<KV<Input, TryWrapper>>> batchResult =
-                          processInputs.stream()
-                              .map(
-                                  processInput -> {
-                                    final Input i = processInput.getValue();
-                                    final TryWrapper o = success(output);
-                                    final Instant ts = processInput.getTimestamp();
-                                    final BoundedWindow w = processInput.getWindow();
-                                    final PaneInfo p = processInput.getPane();
-                                    return ValueInSingleWindow.of(KV.of(i, o), ts, w, p);
-                                  })
-                              .collect(Collectors.toList());
-                      results.add(Pair.of(key, batchResult));
-                    }
-                  });
+          final Map<String, Output> keyedOutput =
+              batchResponseFn.apply(response).stream()
+                  .collect(Collectors.toMap(Pair::getKey, Pair::getValue));
+
+          keyedInputs.forEach(
+              (id, input) -> {
+                final List<ValueInSingleWindow<Input>> processInputs = inputs.remove(id);
+                if (processInputs == null) {
+                  // no need to fail future here as we're only interested in its completion
+                  // finishBundle will fail the checkState as we do not produce any result
+                  LOG.error(
+                      "The ID '{}' received in the gRPC batch response does not "
+                          + "match any IDs extracted via the idExtractorFn for the requested "
+                          + "batch sent to the gRPC endpoint. Please ensure that the IDs returned "
+                          + "from the gRPC endpoints match the IDs extracted using the provided "
+                          + "idExtractorFn for the same input.",
+                      id);
+                } else {
+                  List<ValueInSingleWindow<KV<Input, TryWrapper>>> batchResult =
+                      processInputs.stream()
+                          .map(
+                              processInput -> {
+                                final Input i = processInput.getValue();
+                                final Output output = keyedOutput.get(id);
+                                final TryWrapper o =
+                                    output == null
+                                        ? failure(new UnmatchedRequestException(id))
+                                        : success(output);
+                                final Instant ts = processInput.getTimestamp();
+                                final BoundedWindow w = processInput.getWindow();
+                                final PaneInfo p = processInput.getPane();
+                                return ValueInSingleWindow.of(KV.of(i, o), ts, w, p);
+                              })
+                          .collect(Collectors.toList());
+                  results.add(Pair.of(key, batchResult));
+                }
+              });
           return null;
         },
         throwable -> {
-          batchInput.forEach(
-              element -> {
-                final String id = idExtractorFn.apply(element);
+          keyedInputs.forEach(
+              (id, element) -> {
                 final List<ValueInSingleWindow<KV<Input, TryWrapper>>> batchResult =
                     inputs.remove(id).stream()
                         .map(
diff --git a/scio-core/src/main/java/com/spotify/scio/transforms/UnmatchedRequestException.java b/scio-core/src/main/java/com/spotify/scio/transforms/UnmatchedRequestException.java
new file mode 100644
index 0000000000..e6025f6b7f
--- /dev/null
+++ b/scio-core/src/main/java/com/spotify/scio/transforms/UnmatchedRequestException.java
@@ -0,0 +1,45 @@
+/*
+ * Copyright 2024 Spotify AB
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.spotify.scio.transforms;
+
+import java.util.Objects;
+
+public class UnmatchedRequestException extends RuntimeException {
+
+  private final String id;
+
+  public UnmatchedRequestException(String id) {
+    super("Unmatched batch request for ID: " + id);
+    this.id = id;
+  }
+
+  public String getId() {
+    return id;
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (o == null || getClass() != o.getClass()) return false;
+    UnmatchedRequestException that = (UnmatchedRequestException) o;
+    return Objects.equals(id, that.id);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(id);
+  }
+}
diff --git a/scio-core/src/test/scala/com/spotify/scio/transforms/AsyncBatchLookupDoFnTest.scala b/scio-core/src/test/scala/com/spotify/scio/transforms/AsyncBatchLookupDoFnTest.scala
index a620944a4e..2844f88d86 100644
--- a/scio-core/src/test/scala/com/spotify/scio/transforms/AsyncBatchLookupDoFnTest.scala
+++ b/scio-core/src/test/scala/com/spotify/scio/transforms/AsyncBatchLookupDoFnTest.scala
@@ -61,7 +61,12 @@
     doFn: BaseAsyncBatchLookupDoFn[Int, List[Int], List[String], String, AsyncBatchClient, F, T]
   )(tryFn: T => Try[String]): Unit = {
     // batches of size 4 and size 3
-    val output = runWithData(Seq[Seq[Int]](1 to 4, 8 to 10))(_.flatten.parDo(doFn)).map { kv =>
+    val output = runWithData(
+      Seq[Seq[Int]](
+        1 to 4, // 1 and 3 are unmatched
+        8 to 10 // failure
+      )
+    )(_.flatten.parDo(doFn)).map { kv =>
       val r = tryFn(kv.getValue) match {
         case Success(v)                      => v
         case Failure(e: CompletionException) => e.getCause.getMessage
@@ -70,8 +75,9 @@
       (kv.getKey, r)
     }
     output should contain theSameElementsAs (
-      (1 to 4).map(x => x -> x.toString) ++
-        (8 to 10).map(x => x -> "failure for 8,9,10")
+      Seq(1, 3).map(x => x -> s"Unmatched batch request for ID: $x") ++
+        Seq(2, 4).map(x => x -> x.toString) ++
+        Seq(8, 9, 10).map(x => x -> "failure for 8,9,10")
     )
   }
 
@@ -229,7 +235,7 @@ class FailingGuavaBatchLookupDoFn extends AbstractGuavaAsyncBatchLookupDoFn() {
     input: List[Int]
   ): ListenableFuture[List[String]] =
     if (input.size % 2 == 0) {
-      Futures.immediateFuture(input.map(_.toString))
+      Futures.immediateFuture(input.filter(_ % 2 == 0).map(_.toString))
     } else {
       Futures.immediateFailedFuture(new RuntimeException("failure for " + input.mkString(",")))
     }
@@ -299,7 +305,7 @@ class FailingJavaBatchLookupDoFn extends AbstractJavaAsyncBatchLookupDoFn() {
     input: List[Int]
   ): CompletableFuture[List[String]] =
     if (input.size % 2 == 0) {
-      CompletableFuture.supplyAsync(() => input.map(_.toString))
+      CompletableFuture.supplyAsync(() => input.filter(_ % 2 == 0).map(_.toString))
     } else {
       val f = new CompletableFuture[List[String]]()
       f.completeExceptionally(new RuntimeException("failure for " + input.mkString(",")))
@@ -347,7 +353,7 @@ class FailingScalaBatchLookupDoFn extends AbstractScalaAsyncBatchLookupDoFn() {
   override protected def newClient(): AsyncBatchClient = null
   override def asyncLookup(session: AsyncBatchClient, input: List[Int]): Future[List[String]] =
     if (input.size % 2 == 0) {
-      Future.successful(input.map(_.toString))
+      Future.successful(input.filter(_ % 2 == 0).map(_.toString))
     } else {
       Future.failed(new RuntimeException("failure for " + input.mkString(",")))
     }
diff --git a/scio-google-cloud-platform/src/main/java/com/spotify/scio/bigtable/BigtableBatchDoFn.java b/scio-google-cloud-platform/src/main/java/com/spotify/scio/bigtable/BigtableBatchDoFn.java
new file mode 100644
index 0000000000..f57e748ba0
--- /dev/null
+++ b/scio-google-cloud-platform/src/main/java/com/spotify/scio/bigtable/BigtableBatchDoFn.java
@@ -0,0 +1,126 @@
+/*
+ * Copyright 2024 Spotify AB.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package com.spotify.scio.bigtable;
+
+import com.google.cloud.bigtable.config.BigtableOptions;
+import com.google.cloud.bigtable.grpc.BigtableSession;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.spotify.scio.transforms.BaseAsyncLookupDoFn;
+import com.spotify.scio.transforms.GuavaAsyncBatchLookupDoFn;
+import java.io.IOException;
+import java.util.List;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.commons.lang3.tuple.Pair;
+
+/**
+ * A {@link DoFn} which batches elements and performs asynchronous lookup for them using Google
+ * Cloud Bigtable.
+ *
+ * @param <Input> input element type.
+ * @param <BatchRequest> batched input element type
+ * @param <BatchResponse> batched response from Bigtable type
+ * @param <Output> Bigtable lookup value type.
+ */
+public abstract class BigtableBatchDoFn<Input, BatchRequest, BatchResponse, Output>
+    extends GuavaAsyncBatchLookupDoFn<Input, BatchRequest, BatchResponse, Output, BigtableSession> {
+
+  private final BigtableOptions options;
+
+  /** Perform asynchronous Bigtable lookup. */
+  public abstract ListenableFuture<BatchResponse> asyncLookup(
+      BigtableSession session, BatchRequest batchRequest);
+
+  /**
+   * Create a {@link BigtableBatchDoFn} instance.
+   *
+   * @param options Bigtable options.
+   */
+  public BigtableBatchDoFn(
+      BigtableOptions options,
+      int batchSize,
+      SerializableFunction<List<Input>, BatchRequest> batchRequestFn,
+      SerializableFunction<BatchResponse, List<Pair<String, Output>>> batchResponseFn,
+      SerializableFunction<Input, String> idExtractorFn) {
+    this(options, batchSize, batchRequestFn, batchResponseFn, idExtractorFn, 1000);
+  }
+
+  /**
+   * Create a {@link BigtableBatchDoFn} instance.
+   *
+   * @param options Bigtable options.
+   * @param maxPendingRequests maximum number of pending requests on every cloned DoFn. This
+   *     prevents the runner from timing out and retrying bundles.
+   */
+  public BigtableBatchDoFn(
+      BigtableOptions options,
+      int batchSize,
+      SerializableFunction<List<Input>, BatchRequest> batchRequestFn,
+      SerializableFunction<BatchResponse, List<Pair<String, Output>>> batchResponseFn,
+      SerializableFunction<Input, String> idExtractorFn,
+      int maxPendingRequests) {
+    this(
+        options,
+        batchSize,
+        batchRequestFn,
+        batchResponseFn,
+        idExtractorFn,
+        maxPendingRequests,
+        new BaseAsyncLookupDoFn.NoOpCacheSupplier<>());
+  }
+
+  /**
+   * Create a {@link BigtableBatchDoFn} instance.
+   *
+   * @param options Bigtable options.
+   * @param maxPendingRequests maximum number of pending requests on every cloned DoFn. This
+   *     prevents the runner from timing out and retrying bundles.
+   * @param cacheSupplier supplier for lookup cache.
+   */
+  public BigtableBatchDoFn(
+      BigtableOptions options,
+      int batchSize,
+      SerializableFunction<List<Input>, BatchRequest> batchRequestFn,
+      SerializableFunction<BatchResponse, List<Pair<String, Output>>> batchResponseFn,
+      SerializableFunction<Input, String> idExtractorFn,
+      int maxPendingRequests,
+      BaseAsyncLookupDoFn.CacheSupplier<String, Output> cacheSupplier) {
+    super(
+        batchSize,
+        batchRequestFn,
+        batchResponseFn,
+        idExtractorFn,
+        maxPendingRequests,
+        cacheSupplier);
+    this.options = options;
+  }
+
+  @Override
+  public ResourceType getResourceType() {
+    // BigtableSession is backed by a gRPC thread safe client
+    return ResourceType.PER_INSTANCE;
+  }
+
+  protected BigtableSession newClient() {
+    try {
+      return new BigtableSession(options);
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
+}
diff --git a/scio-google-cloud-platform/src/test/scala/com/spotify/scio/bigtable/BigTableBatchDoFnTest.scala b/scio-google-cloud-platform/src/test/scala/com/spotify/scio/bigtable/BigTableBatchDoFnTest.scala
new file mode 100644
index 0000000000..fd579b42b9
--- /dev/null
+++ b/scio-google-cloud-platform/src/test/scala/com/spotify/scio/bigtable/BigTableBatchDoFnTest.scala
@@ -0,0 +1,136 @@
+/*
+ * Copyright 2024 Spotify AB.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package com.spotify.scio.bigtable
+
+import java.util.concurrent.{CompletionException, ConcurrentLinkedQueue}
+
+import com.google.cloud.bigtable.grpc.BigtableSession
+import com.google.common.cache.{Cache, CacheBuilder}
+import com.google.common.util.concurrent.{Futures, ListenableFuture}
+import com.spotify.scio.testing._
+import com.spotify.scio.transforms.BaseAsyncLookupDoFn.CacheSupplier
+import com.spotify.scio.transforms.JavaAsyncConverters._
+import org.apache.commons.lang3.tuple.Pair
+
+import scala.jdk.CollectionConverters._
+import scala.util.{Failure, Success}
+
+class BigtableBatchDoFnTest extends PipelineSpec {
+  "BigtableBatchDoFn" should "work" in {
+    val fn = new TestBigtableBatchDoFn
+    val output = runWithData(1 to 10)(_.parDo(fn))
+      .map(kv => (kv.getKey, kv.getValue.get()))
+    output should contain theSameElementsAs (1 to 10).map(x => (x, x.toString))
+  }
+
+  it should "work with cache" in {
+    val fn = new TestCachingBigtableBatchDoFn
+    val output = runWithData((1 to 10) ++ (6 to 15))(_.parDo(fn))
+      .map(kv => (kv.getKey, kv.getValue.get()))
+    output should have size 20
+    output should contain theSameElementsAs ((1 to 10) ++ (6 to 15)).map(x => (x, x.toString))
+    BigtableBatchDoFnTest.queue.asScala.toSet should contain theSameElementsAs (1 to 15)
+    BigtableBatchDoFnTest.queue.size() should be <= 20
+  }
+
+  it should "work with failures" in {
+    val fn = new TestFailingBigtableBatchDoFn
+
+    val output = runWithData(Seq[Seq[Int]](1 to 4, 8 to 10))(_.flatten.parDo(fn)).map { kv =>
+      val r = kv.getValue.asScala match {
+        case Success(v)                      => v
+        case Failure(e: CompletionException) => e.getCause.getMessage
+        case Failure(e)                      => e.getMessage
+      }
+      (kv.getKey, r)
+    }
+    output should contain theSameElementsAs (
+      (1 to 4).map(x => x -> x.toString) ++
+        (8 to 10).map(x => x -> "failure for 8,9,10")
+    )
+  }
+}
+
+object BigtableBatchDoFnTest {
+  val queue: ConcurrentLinkedQueue[Int] = new ConcurrentLinkedQueue[Int]()
+
+  def batchRequest(input: java.util.List[Int]): List[Int] = input.asScala.toList
+  def batchResponse(input: List[String]): java.util.List[Pair[String, String]] =
+    input.map(x => Pair.of(x, x)).asJava
+  def idExtractor(input: Int): String = input.toString
+}
+
+class TestBigtableBatchDoFn
+    extends BigtableBatchDoFn[Int, List[Int], List[String], String](
+      null,
+      2,
+      BigtableBatchDoFnTest.batchRequest,
+      BigtableBatchDoFnTest.batchResponse,
+      BigtableBatchDoFnTest.idExtractor
+    ) {
+  override def newClient(): BigtableSession = null
+  override def asyncLookup(
+    session: BigtableSession,
+    input: List[Int]
+  ): ListenableFuture[List[String]] =
+    Futures.immediateFuture(input.map(_.toString))
+}
+
+class TestCachingBigtableBatchDoFn
+    extends BigtableBatchDoFn[Int, List[Int], List[String], String](
+      null,
+      2,
+      BigtableBatchDoFnTest.batchRequest,
+      BigtableBatchDoFnTest.batchResponse,
+      BigtableBatchDoFnTest.idExtractor,
+      100,
+      new TestCacheBatchSupplier
+    ) {
+  override def newClient(): BigtableSession = null
+
+  override def asyncLookup(
+    session: BigtableSession,
+    input: List[Int]
+  ): ListenableFuture[List[String]] = {
+    input.foreach(BigtableBatchDoFnTest.queue.add)
+    Futures.immediateFuture(input.map(_.toString))
+  }
+}
+
+class TestFailingBigtableBatchDoFn
+    extends BigtableBatchDoFn[Int, List[Int], List[String], String](
+      null,
+      4,
+      BigtableBatchDoFnTest.batchRequest,
+      BigtableBatchDoFnTest.batchResponse,
+      BigtableBatchDoFnTest.idExtractor
+    ) {
+  override def newClient(): BigtableSession = null
+  override def asyncLookup(
+    session: BigtableSession,
+    input: List[Int]
+  ): ListenableFuture[List[String]] =
+    if (input.size % 2 == 0) {
+      Futures.immediateFuture(input.map(_.toString))
+    } else {
+      Futures.immediateFailedFuture(new RuntimeException("failure for " + input.mkString(",")))
+    }
+}
+
+class TestCacheBatchSupplier extends CacheSupplier[String, String] {
+  override def get(): Cache[String, String] = CacheBuilder.newBuilder().build()
+}
diff --git a/scio-grpc/src/main/java/com/spotify/scio/grpc/GrpcBatchDoFn.java b/scio-grpc/src/main/java/com/spotify/scio/grpc/GrpcBatchDoFn.java
index 9ffe255d84..b668191303 100644
--- a/scio-grpc/src/main/java/com/spotify/scio/grpc/GrpcBatchDoFn.java
+++ b/scio-grpc/src/main/java/com/spotify/scio/grpc/GrpcBatchDoFn.java
@@ -25,7 +25,7 @@
 import com.spotify.scio.transforms.BaseAsyncLookupDoFn.CacheSupplier;
 import com.spotify.scio.transforms.GuavaAsyncBatchLookupDoFn;
 import io.grpc.Channel;
-import io.grpc.stub.AbstractFutureStub;
+import io.grpc.stub.AbstractStub;
 import java.io.Serializable;
 import java.util.List;
 import org.apache.beam.sdk.transforms.SerializableBiFunction;
@@ -43,7 +43,7 @@
  * @param <Client> client type.
  */
 public class GrpcBatchDoFn<
-    Input, BatchRequest, BatchResponse, Output, Client extends AbstractFutureStub<Client>>
+    Input, BatchRequest, BatchResponse, Output, Client extends AbstractStub<Client>>
     extends GuavaAsyncBatchLookupDoFn<Input, BatchRequest, BatchResponse, Output, Client> {
   private final ChannelSupplier channelSupplier;
   private final SerializableFunction<Channel, Client> newClientFn;
@@ -104,21 +104,13 @@ protected Client newClient() {
   }
 
   public static <
-      Input,
-      BatchRequest,
-      BatchResponse,
-      Output,
-      ClientType extends AbstractFutureStub<ClientType>>
+      Input, BatchRequest, BatchResponse, Output, ClientType extends AbstractStub<ClientType>>
       Builder<Input, BatchRequest, BatchResponse, Output, ClientType> newBuilder() {
     return new Builder<>();
   }
 
   public static class Builder<
-      Input,
-      BatchRequest,
-      BatchResponse,
-      Output,
-      ClientType extends AbstractFutureStub<ClientType>>
+      Input, BatchRequest, BatchResponse, Output, ClientType extends AbstractStub<ClientType>>
       implements Serializable {
 
     private ChannelSupplier channelSupplier;
diff --git a/scio-grpc/src/main/scala/com/spotify/scio/grpc/SCollectionSyntax.scala b/scio-grpc/src/main/scala/com/spotify/scio/grpc/SCollectionSyntax.scala
index 121f76bf3f..3752af62d5 100644
--- a/scio-grpc/src/main/scala/com/spotify/scio/grpc/SCollectionSyntax.scala
+++ b/scio-grpc/src/main/scala/com/spotify/scio/grpc/SCollectionSyntax.scala
@@ -147,6 +147,64 @@ class GrpcSCollectionOps[Request](private val self: SCollection[Request]) extends AnyVal {
     ).map(kvToTuple _)
       .mapValues(_.asScala)
   }
+
+  def grpcLookupBatchStream[
+    BatchRequest,
+    Response,
+    Result: Coder,
+    Client <: AbstractStub[Client]
+  ](
+    channelSupplier: () => Channel,
+    clientFactory: Channel => Client,
+    batchSize: Int,
+    batchRequestFn: Seq[Request] => BatchRequest,
+    batchResponseFn: List[Response] => Seq[(String, Result)],
+    idExtractorFn: Request => String,
+    maxPendingRequests: Int,
+    cacheSupplier: CacheSupplier[String, Result] = new NoOpCacheSupplier[String, Result]()
+  )(
+    f: Client => (BatchRequest, StreamObserver[Response]) => Unit
+  ): SCollection[(Request, Try[Result])] = self.transform { in =>
+    import self.coder
+    val cleanedChannelSupplier = ClosureCleaner.clean(channelSupplier)
+    val serializableClientFactory = Functions.serializableFn(clientFactory)
+    val serializableLookupFn =
+      Functions.serializableBiFn[Client, BatchRequest, ListenableFuture[JIterable[Response]]] {
+        (client, request) =>
+          val observer = new StreamObservableFuture[Response]()
+          f(client)(request, observer)
+          observer
+      }
+    val serializableBatchRequestFn =
+      Functions.serializableFn[java.util.List[Request], BatchRequest] { inputs =>
+        batchRequestFn(inputs.asScala.toSeq)
+      }
+
+    val serializableBatchResponseFn =
+      Functions.serializableFn[JIterable[Response], java.util.List[Pair[String, Result]]] {
+        batchResponse =>
+          batchResponseFn(batchResponse.asScala.toList).map { case (input, output) =>
+            Pair.of(input, output)
+          }.asJava
+      }
+    val serializableIdExtractorFn = Functions.serializableFn(idExtractorFn)
+    in.parDo(
+      GrpcBatchDoFn
+        .newBuilder[Request, BatchRequest, JIterable[Response], Result, Client]()
+        .withChannelSupplier(() => cleanedChannelSupplier())
+        .withNewClientFn(serializableClientFactory)
+        .withLookupFn(serializableLookupFn)
+        .withMaxPendingRequests(maxPendingRequests)
+        .withBatchSize(batchSize)
+        .withBatchRequestFn(serializableBatchRequestFn)
+        .withBatchResponseFn(serializableBatchResponseFn)
+        .withIdExtractorFn(serializableIdExtractorFn)
+        .withCacheSupplier(cacheSupplier)
+        .build()
+    ).map(kvToTuple _)
+      .mapValues(_.asScala)
+  }
+
 }
 
 trait SCollectionSyntax {
diff --git a/scio-grpc/src/test/protobuf/service.proto b/scio-grpc/src/test/protobuf/service.proto
index 6b669a6da0..c8612b0cfb 100644
--- a/scio-grpc/src/test/protobuf/service.proto
+++ b/scio-grpc/src/test/protobuf/service.proto
@@ -43,6 +43,7 @@ service ConcatService {
   rpc ConcatClientStreaming(stream ConcatRequest) returns (ConcatResponse) {}
   rpc ConcatFullStreaming(stream ConcatRequest) returns (stream ConcatResponse) {}
   rpc BatchConcat(BatchRequest) returns (BatchResponse) {}
+  rpc BatchConcatServerStreaming(BatchRequest) returns (stream ConcatResponseWithID) {}
 
   rpc Ping(google.protobuf.Empty) returns (google.protobuf.Empty);
 }
diff --git a/scio-grpc/src/test/scala/com/spotify/scio/grpc/GrpcBatchDoFnTest.scala b/scio-grpc/src/test/scala/com/spotify/scio/grpc/GrpcBatchDoFnTest.scala
index b97c059ede..be6abfe1e7 100644
--- a/scio-grpc/src/test/scala/com/spotify/scio/grpc/GrpcBatchDoFnTest.scala
+++ b/scio-grpc/src/test/scala/com/spotify/scio/grpc/GrpcBatchDoFnTest.scala
@@ -17,20 +17,24 @@
 package com.spotify.scio.grpc
 
 import com.google.common.cache.{Cache, CacheBuilder}
-import com.spotify.concat.v1.ConcatServiceGrpc.{ConcatServiceFutureStub, ConcatServiceImplBase}
+import com.spotify.concat.v1.ConcatServiceGrpc.{
+  ConcatServiceFutureStub,
+  ConcatServiceImplBase,
+  ConcatServiceStub
+}
 import com.spotify.concat.v1._
 import com.spotify.scio.testing.PipelineSpec
 import com.spotify.scio.transforms.BaseAsyncLookupDoFn.CacheSupplier
+import com.spotify.scio.transforms.UnmatchedRequestException
 import io.grpc.netty.NettyChannelBuilder
 import io.grpc.stub.StreamObserver
 import io.grpc.{Server, ServerBuilder}
-import org.apache.beam.sdk.Pipeline.PipelineExecutionException
 import org.scalatest.BeforeAndAfterAll
 
 import java.net.ServerSocket
 import java.util.stream.Collectors
 import scala.jdk.CollectionConverters._
-import scala.util.{Success, Try}
+import scala.util.{Failure, Success, Try}
 
 object GrpcBatchDoFnTest {
 
@@ -66,6 +70,11 @@
   def concatBatchResponse(response: BatchResponse): Seq[(String, ConcatResponseWithID)] =
     response.getResponseList.asScala.toSeq.map(e => (e.getRequestId, e))
 
+  def concatListResponse(
+    response: List[ConcatResponseWithID]
+  ): Seq[(String, ConcatResponseWithID)] =
+    response.map(e => (e.getRequestId, e))
+
   def idExtractor(concatRequest: ConcatRequestWithID): String =
     concatRequest.getRequestId
 
@@ -83,6 +92,15 @@
       responseObserver.onNext(processBatch(request))
       responseObserver.onCompleted()
     }
+
+    override def batchConcatServerStreaming(
+      request: BatchRequest,
+      responseObserver: StreamObserver[ConcatResponseWithID]
+    ): Unit = {
+      val batchResponse = processBatch(request)
+      batchResponse.getResponseList.forEach(responseObserver.onNext)
+      responseObserver.onCompleted()
+    }
   }
 }
 
@@ -138,6 +156,43 @@
     }
   }
 
+  it should "issue request and propagate streamed responses" in {
+    val input = (0 to 10).map { i =>
+      ConcatRequestWithID
+        .newBuilder()
+        .setRequestId(i.toString)
+        .setStringOne(i.toString)
+        .setStringTwo(i.toString)
+        .build()
+    }
+
+    val expected: Seq[(ConcatRequestWithID, Try[ConcatResponseWithID])] = input.map { req =>
+      val resp = concat(req)
+      req -> Success(resp)
+    }
+
+    runWithContext { sc =>
+      val result = sc
+        .parallelize(input)
+        .grpcLookupBatchStream[
+          BatchRequest,
+          ConcatResponseWithID,
+          ConcatResponseWithID,
+          ConcatServiceStub
+        ](
+          () => NettyChannelBuilder.forTarget(ServiceUri).usePlaintext().build(),
+          ConcatServiceGrpc.newStub,
+          2,
+          concatBatchRequest,
+          concatListResponse,
+          idExtractor,
+          2
+        )(_.batchConcatServerStreaming)
+
+      result should containInAnyOrder(expected)
+    }
+  }
+
   it should "return cached responses" in {
     val request = ConcatRequestWithID
       .newBuilder()
@@ -230,7 +285,7 @@
     }
   }
 
-  it should "throw an IllegalStateException if gRPC response contains unknown ids" in {
+  it should "fail unmatched inputs" in {
     val input = (0 to 1).map { i =>
       ConcatRequestWithID
         .newBuilder()
@@ -240,30 +295,29 @@
         .build()
     }
 
-    assertThrows[IllegalStateException] {
-      try {
-        runWithContext { sc =>
-          sc.parallelize(input)
-            .grpcBatchLookup[
-              BatchRequest,
-              BatchResponse,
-              ConcatResponseWithID,
-              ConcatServiceFutureStub
-            ](
-              () => NettyChannelBuilder.forTarget(ServiceUri).usePlaintext().build(),
-              ConcatServiceGrpc.newFutureStub,
-              2,
-              concatBatchRequest,
-              r => r.getResponseList.asScala.toSeq.map(e => ("WrongID-" + e.getRequestId, e)),
-              idExtractor,
-              2
-            )(_.batchConcat)
-        }
-      } catch {
-        case e: PipelineExecutionException =>
-          e.getMessage should include("Expected requestCount == responseCount")
-          throw e.getCause
-      }
+    val expected: Seq[(ConcatRequestWithID, Try[ConcatResponseWithID])] = input.map { req =>
+      req -> Failure(new UnmatchedRequestException(req.getRequestId))
+    }
+
+    runWithContext { sc =>
+      val result = sc
+        .parallelize(input)
+        .grpcBatchLookup[
+          BatchRequest,
+          BatchResponse,
+          ConcatResponseWithID,
+          ConcatServiceFutureStub
+        ](
+          () => NettyChannelBuilder.forTarget(ServiceUri).usePlaintext().build(),
+          ConcatServiceGrpc.newFutureStub,
+          2,
+          concatBatchRequest,
+          r => r.getResponseList.asScala.toSeq.map(e => ("WrongID-" + e.getRequestId, e)),
+          idExtractor,
+          2
+        )(_.batchConcat)
+
+      result should containInAnyOrder(expected)
     }
   }
 }
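
For reference, a minimal sketch of how the new grpcLookupBatchStream API introduced by this patch can be wired up, reusing the ConcatService helpers defined in GrpcBatchDoFnTest; the channel target and the `requests` collection are hypothetical placeholders, not part of the diff:

// Sketch only. Assumes `requests: SCollection[ConcatRequestWithID]` and a
// ConcatService server reachable at "localhost:50051" (hypothetical address).
val results: SCollection[(ConcatRequestWithID, Try[ConcatResponseWithID])] =
  requests.grpcLookupBatchStream[
    BatchRequest,
    ConcatResponseWithID, // element type streamed back by the server
    ConcatResponseWithID, // per-request result type, must have a Coder
    ConcatServiceStub     // async stub, required for server-streaming RPCs
  ](
    () => NettyChannelBuilder.forTarget("localhost:50051").usePlaintext().build(),
    ConcatServiceGrpc.newStub,
    10,                  // batchSize: elements bundled into one BatchRequest
    concatBatchRequest,  // Seq[ConcatRequestWithID] => BatchRequest
    concatListResponse,  // List[ConcatResponseWithID] => Seq[(id, result)]
    idExtractor,         // ConcatRequestWithID => String, pairs results back to inputs
    10                   // maxPendingRequests
  )(_.batchConcatServerStreaming)
// With this patch, a request whose ID never appears in the batch response
// surfaces as Failure(UnmatchedRequestException) for that element alone,
// instead of failing the whole bundle with an IllegalStateException.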