diff --git a/build.sbt b/build.sbt
index a9aeeab0d9..ec2e13db95 100644
--- a/build.sbt
+++ b/build.sbt
@@ -432,6 +432,10 @@ ThisBuild / mimaBinaryIssueFilters ++= Seq(
// tf-metadata upgrade
ProblemFilters.exclude[Problem](
"org.tensorflow.metadata.v0.*"
+ ),
+ // relax GrpcBatchDoFn client type bound from AbstractFutureStub to AbstractStub
+ ProblemFilters.exclude[IncompatibleMethTypeProblem](
+ "com.spotify.scio.grpc.GrpcBatchDoFn.asyncLookup"
)
)
diff --git a/scio-core/src/main/java/com/spotify/scio/transforms/BaseAsyncBatchLookupDoFn.java b/scio-core/src/main/java/com/spotify/scio/transforms/BaseAsyncBatchLookupDoFn.java
index 699b88b310..9ac82678c2 100644
--- a/scio-core/src/main/java/com/spotify/scio/transforms/BaseAsyncBatchLookupDoFn.java
+++ b/scio-core/src/main/java/com/spotify/scio/transforms/BaseAsyncBatchLookupDoFn.java
@@ -17,6 +17,7 @@
package com.spotify.scio.transforms;
import static java.util.Objects.requireNonNull;
+import static java.util.function.Function.identity;
import com.google.common.cache.Cache;
import com.spotify.scio.transforms.BaseAsyncLookupDoFn.CacheSupplier;
@@ -26,6 +27,7 @@
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
+import java.util.Map;
import java.util.Queue;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
@@ -275,48 +277,53 @@ private void createRequest() throws InterruptedException {
}
private FutureType handleOutput(FutureType future, List<Input> batchInput, UUID key) {
+ final Map<String, Input> keyedInputs =
+ batchInput.stream().collect(Collectors.toMap(idExtractorFn::apply, identity()));
return addCallback(
future,
response -> {
- batchResponseFn
- .apply(response)
- .forEach(
- pair -> {
- final String id = pair.getLeft();
- final Output output = pair.getRight();
- final List<ValueInSingleWindow<Input>> processInputs = inputs.remove(id);
- if (processInputs == null) {
- // no need to fail future here as we're only interested in its completion
- // finishBundle will fail the checkState as we do not produce any result
- LOG.error(
- "The ID '{}' received in the gRPC batch response does not "
- + "match any IDs extracted via the idExtractorFn for the requested "
- + "batch sent to the gRPC endpoint. Please ensure that the IDs returned "
- + "from the gRPC endpoints match the IDs extracted using the provided"
- + "idExtractorFn for the same input.",
- id);
- } else {
- final List<ValueInSingleWindow<KV<Input, TryWrapper>>> batchResult =
- processInputs.stream()
- .map(
- processInput -> {
- final Input i = processInput.getValue();
- final TryWrapper o = success(output);
- final Instant ts = processInput.getTimestamp();
- final BoundedWindow w = processInput.getWindow();
- final PaneInfo p = processInput.getPane();
- return ValueInSingleWindow.of(KV.of(i, o), ts, w, p);
- })
- .collect(Collectors.toList());
- results.add(Pair.of(key, batchResult));
- }
- });
+ final Map<String, Output> keyedOutput =
+ batchResponseFn.apply(response).stream()
+ .collect(Collectors.toMap(Pair::getKey, Pair::getValue));
+
+ keyedInputs.forEach(
+ (id, input) -> {
+ final List<ValueInSingleWindow<Input>> processInputs = inputs.remove(id);
+ if (processInputs == null) {
+ // no need to fail the future here as we're only interested in its completion;
+ // finishBundle will fail the checkState as we do not produce any result
+ LOG.error(
+ "The ID '{}' extracted via the provided idExtractorFn does not match "
+ + "any pending request in this batch. Please ensure that the idExtractorFn "
+ + "returns a stable, unique ID for every input element.",
+ id);
+ } else {
+ final List<ValueInSingleWindow<KV<Input, TryWrapper>>> batchResult =
+ processInputs.stream()
+ .map(
+ processInput -> {
+ final Input i = processInput.getValue();
+ final Output output = keyedOutput.get(id);
+ final TryWrapper o =
+ output == null
+ ? failure(new UnmatchedRequestException(id))
+ : success(output);
+ final Instant ts = processInput.getTimestamp();
+ final BoundedWindow w = processInput.getWindow();
+ final PaneInfo p = processInput.getPane();
+ return ValueInSingleWindow.of(KV.of(i, o), ts, w, p);
+ })
+ .collect(Collectors.toList());
+ results.add(Pair.of(key, batchResult));
+ }
+ });
return null;
},
throwable -> {
- batchInput.forEach(
- element -> {
- final String id = idExtractorFn.apply(element);
+ keyedInputs.forEach(
+ (id, element) -> {
final List<ValueInSingleWindow<KV<Input, TryWrapper>>> batchResult =
inputs.remove(id).stream()
.map(
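
Note on the new semantics: matching is now driven by the request side. Every element of the batch is emitted exactly once; an input whose ID is missing from the batch response completes with Failure(UnmatchedRequestException) instead of only being logged and later tripping the "Expected requestCount == responseCount" checkState in finishBundle. A minimal sketch of the per-element behavior, in plain Scala with illustrative names (the real DoFn also carries windows, timestamps, and panes):

    import com.spotify.scio.transforms.UnmatchedRequestException
    import scala.util.{Failure, Success, Try}

    // `backend` answers an arbitrary subset of the batch, keyed by extracted ID.
    def matchBatch(
      batch: Seq[Int],
      backend: Seq[Int] => Map[String, String]
    ): Seq[(Int, Try[String])] = {
      val responses = backend(batch)
      batch.map { input =>
        val id = input.toString // stand-in for idExtractorFn
        input -> responses
          .get(id)
          .map(v => Success(v): Try[String])
          .getOrElse(Failure(new UnmatchedRequestException(id)))
      }
    }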
diff --git a/scio-core/src/main/java/com/spotify/scio/transforms/UnmatchedRequestException.java b/scio-core/src/main/java/com/spotify/scio/transforms/UnmatchedRequestException.java
new file mode 100644
index 0000000000..e6025f6b7f
--- /dev/null
+++ b/scio-core/src/main/java/com/spotify/scio/transforms/UnmatchedRequestException.java
@@ -0,0 +1,45 @@
+/*
+ * Copyright 2024 Spotify AB
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.spotify.scio.transforms;
+
+import java.util.Objects;
+
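+/**
+ * Exception signaling that an element of a batched lookup request had no matching entry in the
+ * batched response, keyed by the ID produced by the idExtractorFn.
+ */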
+public class UnmatchedRequestException extends RuntimeException {
+
+ private final String id;
+
+ public UnmatchedRequestException(String id) {
+ super("Unmatched batch request for ID: " + id);
+ this.id = id;
+ }
+
+ public String getId() {
+ return id;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (o == null || getClass() != o.getClass()) return false;
+ UnmatchedRequestException that = (UnmatchedRequestException) o;
+ return Objects.equals(id, that.id);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(id);
+ }
+}
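
Since UnmatchedRequestException exposes the offending ID and defines value-based equality, expected failures compare equal in test assertions (as the gRPC test below relies on) and can be pattern matched downstream. A hedged handling sketch with illustrative names:

    import com.spotify.scio.transforms.UnmatchedRequestException
    import scala.util.{Failure, Success, Try}

    // Splits lookup results into matched values and the IDs that got no response.
    def partitionResults[A](
      results: Seq[(A, Try[String])]
    ): (Seq[(A, String)], Seq[String]) = {
      val matched = results.collect { case (k, Success(v)) => k -> v }
      val unmatched = results.collect { case (_, Failure(e: UnmatchedRequestException)) => e.getId }
      (matched, unmatched)
    }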
diff --git a/scio-core/src/test/scala/com/spotify/scio/transforms/AsyncBatchLookupDoFnTest.scala b/scio-core/src/test/scala/com/spotify/scio/transforms/AsyncBatchLookupDoFnTest.scala
index a620944a4e..2844f88d86 100644
--- a/scio-core/src/test/scala/com/spotify/scio/transforms/AsyncBatchLookupDoFnTest.scala
+++ b/scio-core/src/test/scala/com/spotify/scio/transforms/AsyncBatchLookupDoFnTest.scala
@@ -61,7 +61,12 @@ class AsyncBatchLookupDoFnTest extends PipelineSpec {
doFn: BaseAsyncBatchLookupDoFn[Int, List[Int], List[String], String, AsyncBatchClient, F, T]
)(tryFn: T => Try[String]): Unit = {
// batches of size 4 and size 3
- val output = runWithData(Seq[Seq[Int]](1 to 4, 8 to 10))(_.flatten.parDo(doFn)).map { kv =>
+ val output = runWithData(
+ Seq[Seq[Int]](
+ 1 to 4, // 1 and 3 are unmatched
+ 8 to 10 // failure
+ )
+ )(_.flatten.parDo(doFn)).map { kv =>
val r = tryFn(kv.getValue) match {
case Success(v) => v
case Failure(e: CompletionException) => e.getCause.getMessage
@@ -70,8 +75,9 @@ class AsyncBatchLookupDoFnTest extends PipelineSpec {
(kv.getKey, r)
}
output should contain theSameElementsAs (
- (1 to 4).map(x => x -> x.toString) ++
- (8 to 10).map(x => x -> "failure for 8,9,10")
+ Seq(1, 3).map(x => x -> s"Unmatched batch request for ID: $x") ++
+ Seq(2, 4).map(x => x -> x.toString) ++
+ Seq(8, 9, 10).map(x => x -> "failure for 8,9,10")
)
}
@@ -229,7 +235,7 @@ class FailingGuavaBatchLookupDoFn extends AbstractGuavaAsyncBatchLookupDoFn() {
input: List[Int]
): ListenableFuture[List[String]] =
if (input.size % 2 == 0) {
- Futures.immediateFuture(input.map(_.toString))
+ Futures.immediateFuture(input.filter(_ % 2 == 0).map(_.toString))
} else {
Futures.immediateFailedFuture(new RuntimeException("failure for " + input.mkString(",")))
}
@@ -299,7 +305,7 @@ class FailingJavaBatchLookupDoFn extends AbstractJavaAsyncBatchLookupDoFn() {
input: List[Int]
): CompletableFuture[List[String]] =
if (input.size % 2 == 0) {
- CompletableFuture.supplyAsync(() => input.map(_.toString))
+ CompletableFuture.supplyAsync(() => input.filter(_ % 2 == 0).map(_.toString))
} else {
val f = new CompletableFuture[List[String]]()
f.completeExceptionally(new RuntimeException("failure for " + input.mkString(",")))
@@ -347,7 +353,7 @@ class FailingScalaBatchLookupDoFn extends AbstractScalaAsyncBatchLookupDoFn() {
override protected def newClient(): AsyncBatchClient = null
override def asyncLookup(session: AsyncBatchClient, input: List[Int]): Future[List[String]] =
if (input.size % 2 == 0) {
- Future.successful(input.map(_.toString))
+ Future.successful(input.filter(_ % 2 == 0).map(_.toString))
} else {
Future.failed(new RuntimeException("failure for " + input.mkString(",")))
}
diff --git a/scio-google-cloud-platform/src/main/java/com/spotify/scio/bigtable/BigtableBatchDoFn.java b/scio-google-cloud-platform/src/main/java/com/spotify/scio/bigtable/BigtableBatchDoFn.java
new file mode 100644
index 0000000000..f57e748ba0
--- /dev/null
+++ b/scio-google-cloud-platform/src/main/java/com/spotify/scio/bigtable/BigtableBatchDoFn.java
@@ -0,0 +1,126 @@
+/*
+ * Copyright 2024 Spotify AB.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package com.spotify.scio.bigtable;
+
+import com.google.cloud.bigtable.config.BigtableOptions;
+import com.google.cloud.bigtable.grpc.BigtableSession;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.spotify.scio.transforms.BaseAsyncLookupDoFn;
+import com.spotify.scio.transforms.GuavaAsyncBatchLookupDoFn;
+import java.io.IOException;
+import java.util.List;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.commons.lang3.tuple.Pair;
+
+/**
+ * A {@link DoFn} which batches elements and performs asynchronous lookup for them using Google
+ * Cloud Bigtable.
+ *
+ * @param <Input> input element type.
+ * @param <BatchRequest> batched input element type.
+ * @param <BatchResponse> batched response from Bigtable type.
+ * @param <Output> Bigtable lookup value type.
+ */
+public abstract class BigtableBatchDoFn<Input, BatchRequest, BatchResponse, Output>
+ extends GuavaAsyncBatchLookupDoFn<Input, BatchRequest, BatchResponse, Output, BigtableSession> {
+
+ private final BigtableOptions options;
+
+ /** Perform asynchronous Bigtable lookup. */
+ public abstract ListenableFuture<BatchResponse> asyncLookup(
+ BigtableSession session, BatchRequest batchRequest);
+
+ /**
+ * Create a {@link BigtableBatchDoFn} instance.
+ *
+ * @param options Bigtable options.
+ */
+ public BigtableBatchDoFn(
+ BigtableOptions options,
+ int batchSize,
+ SerializableFunction<List<Input>, BatchRequest> batchRequestFn,
+ SerializableFunction<BatchResponse, List<Pair<String, Output>>> batchResponseFn,
+ SerializableFunction<Input, String> idExtractorFn) {
+ this(options, batchSize, batchRequestFn, batchResponseFn, idExtractorFn, 1000);
+ }
+
+ /**
+ * Create a {@link BigtableBatchDoFn} instance.
+ *
+ * @param options Bigtable options.
+ * @param maxPendingRequests maximum number of pending requests on every cloned DoFn. This
+ * prevents runner from timing out and retrying bundles.
+ */
+ public BigtableBatchDoFn(
+ BigtableOptions options,
+ int batchSize,
+ SerializableFunction<List<Input>, BatchRequest> batchRequestFn,
+ SerializableFunction<BatchResponse, List<Pair<String, Output>>> batchResponseFn,
+ SerializableFunction<Input, String> idExtractorFn,
+ int maxPendingRequests) {
+ this(
+ options,
+ batchSize,
+ batchRequestFn,
+ batchResponseFn,
+ idExtractorFn,
+ maxPendingRequests,
+ new BaseAsyncLookupDoFn.NoOpCacheSupplier<>());
+ }
+
+ /**
+ * Create a {@link BigtableBatchDoFn} instance.
+ *
+ * @param options Bigtable options.
+ * @param maxPendingRequests maximum number of pending requests on every cloned DoFn. This
+ * prevents runner from timing out and retrying bundles.
+ * @param cacheSupplier supplier for lookup cache.
+ */
+ public BigtableBatchDoFn(
+ BigtableOptions options,
+ int batchSize,
+ SerializableFunction<List<Input>, BatchRequest> batchRequestFn,
+ SerializableFunction<BatchResponse, List<Pair<String, Output>>> batchResponseFn,
+ SerializableFunction<Input, String> idExtractorFn,
+ int maxPendingRequests,
+ BaseAsyncLookupDoFn.CacheSupplier<String, Output> cacheSupplier) {
+ super(
+ batchSize,
+ batchRequestFn,
+ batchResponseFn,
+ idExtractorFn,
+ maxPendingRequests,
+ cacheSupplier);
+ this.options = options;
+ }
+
+ @Override
+ public ResourceType getResourceType() {
+ // BigtableSession is backed by a thread-safe gRPC client
+ return ResourceType.PER_INSTANCE;
+ }
+
+ protected BigtableSession newClient() {
+ try {
+ return new BigtableSession(options);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/scio-google-cloud-platform/src/test/scala/com/spotify/scio/bigtable/BigTableBatchDoFnTest.scala b/scio-google-cloud-platform/src/test/scala/com/spotify/scio/bigtable/BigTableBatchDoFnTest.scala
new file mode 100644
index 0000000000..fd579b42b9
--- /dev/null
+++ b/scio-google-cloud-platform/src/test/scala/com/spotify/scio/bigtable/BigTableBatchDoFnTest.scala
@@ -0,0 +1,136 @@
+/*
+ * Copyright 2024 Spotify AB.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package com.spotify.scio.bigtable
+
+import java.util.concurrent.{CompletionException, ConcurrentLinkedQueue}
+import com.google.cloud.bigtable.grpc.BigtableSession
+import com.google.common.cache.{Cache, CacheBuilder}
+import com.google.common.util.concurrent.{Futures, ListenableFuture}
+import com.spotify.scio.testing._
+import com.spotify.scio.transforms.BaseAsyncLookupDoFn.CacheSupplier
+import com.spotify.scio.transforms.JavaAsyncConverters._
+import org.apache.commons.lang3.tuple.Pair
+
+import scala.jdk.CollectionConverters._
+import scala.util.{Failure, Success}
+
+class BigtableBatchDoFnTest extends PipelineSpec {
+ "BigtableDoFn" should "work" in {
+ val fn = new TestBigtableBatchDoFn
+ val output = runWithData(1 to 10)(_.parDo(fn))
+ .map(kv => (kv.getKey, kv.getValue.get()))
+ output should contain theSameElementsAs (1 to 10).map(x => (x, x.toString))
+ }
+
+ it should "work with cache" in {
+ val fn = new TestCachingBigtableBatchDoFn
+ val output = runWithData((1 to 10) ++ (6 to 15))(_.parDo(fn))
+ .map(kv => (kv.getKey, kv.getValue.get()))
+ output should have size 20
+ output should contain theSameElementsAs ((1 to 10) ++ (6 to 15)).map(x => (x, x.toString))
+ BigtableBatchDoFnTest.queue.asScala.toSet should contain theSameElementsAs (1 to 15)
+ BigtableBatchDoFnTest.queue.size() should be <= 20
+ }
+
+ it should "work with failures" in {
+ val fn = new TestFailingBigtableBatchDoFn
+
+ val output = runWithData(Seq[Seq[Int]](1 to 4, 8 to 10))(_.flatten.parDo(fn)).map { kv =>
+ val r = kv.getValue.asScala match {
+ case Success(v) => v
+ case Failure(e: CompletionException) => e.getCause.getMessage
+ case Failure(e) => e.getMessage
+ }
+ (kv.getKey, r)
+ }
+ output should contain theSameElementsAs (
+ (1 to 4).map(x => x -> x.toString) ++
+ (8 to 10).map(x => x -> "failure for 8,9,10")
+ )
+ }
+}
+
+object BigtableBatchDoFnTest {
+ val queue: ConcurrentLinkedQueue[Int] = new ConcurrentLinkedQueue[Int]()
+
+ def batchRequest(input: java.util.List[Int]): List[Int] = input.asScala.toList
+ def batchResponse(input: List[String]): java.util.List[Pair[String, String]] =
+ input.map(x => Pair.of(x, x)).asJava
+ def idExtractor(input: Int): String = input.toString
+}
+
+class TestBigtableBatchDoFn
+ extends BigtableBatchDoFn[Int, List[Int], List[String], String](
+ null,
+ 2,
+ BigtableBatchDoFnTest.batchRequest,
+ BigtableBatchDoFnTest.batchResponse,
+ BigtableBatchDoFnTest.idExtractor
+ ) {
+ override def newClient(): BigtableSession = null
+ override def asyncLookup(
+ session: BigtableSession,
+ input: List[Int]
+ ): ListenableFuture[List[String]] =
+ Futures.immediateFuture(input.map(_.toString))
+}
+
+class TestCachingBigtableBatchDoFn
+ extends BigtableBatchDoFn[Int, List[Int], List[String], String](
+ null,
+ 2,
+ BigtableBatchDoFnTest.batchRequest,
+ BigtableBatchDoFnTest.batchResponse,
+ BigtableBatchDoFnTest.idExtractor,
+ 100,
+ new TestCacheBatchSupplier
+ ) {
+ override def newClient(): BigtableSession = null
+
+ override def asyncLookup(
+ session: BigtableSession,
+ input: List[Int]
+ ): ListenableFuture[List[String]] = {
+ input.foreach(BigtableBatchDoFnTest.queue.add)
+ Futures.immediateFuture(input.map(_.toString))
+ }
+}
+
+class TestFailingBigtableBatchDoFn
+ extends BigtableBatchDoFn[Int, List[Int], List[String], String](
+ null,
+ 4,
+ BigtableBatchDoFnTest.batchRequest,
+ BigtableBatchDoFnTest.batchResponse,
+ BigtableBatchDoFnTest.idExtractor
+ ) {
+ override def newClient(): BigtableSession = null
+ override def asyncLookup(
+ session: BigtableSession,
+ input: List[Int]
+ ): ListenableFuture[List[String]] =
+ if (input.size % 2 == 0) {
+ Futures.immediateFuture(input.map(_.toString))
+ } else {
+ Futures.immediateFailedFuture(new RuntimeException("failure for " + input.mkString(",")))
+ }
+}
+
+class TestCacheBatchSupplier extends CacheSupplier[String, String] {
+ override def get(): Cache[String, String] = CacheBuilder.newBuilder().build()
+}
diff --git a/scio-grpc/src/main/java/com/spotify/scio/grpc/GrpcBatchDoFn.java b/scio-grpc/src/main/java/com/spotify/scio/grpc/GrpcBatchDoFn.java
index 9ffe255d84..b668191303 100644
--- a/scio-grpc/src/main/java/com/spotify/scio/grpc/GrpcBatchDoFn.java
+++ b/scio-grpc/src/main/java/com/spotify/scio/grpc/GrpcBatchDoFn.java
@@ -25,7 +25,7 @@
import com.spotify.scio.transforms.BaseAsyncLookupDoFn.CacheSupplier;
import com.spotify.scio.transforms.GuavaAsyncBatchLookupDoFn;
import io.grpc.Channel;
-import io.grpc.stub.AbstractFutureStub;
+import io.grpc.stub.AbstractStub;
import java.io.Serializable;
import java.util.List;
import org.apache.beam.sdk.transforms.SerializableBiFunction;
@@ -43,7 +43,7 @@
* @param client type.
*/
public class GrpcBatchDoFn<
- Input, BatchRequest, BatchResponse, Output, Client extends AbstractFutureStub<Client>>
+ Input, BatchRequest, BatchResponse, Output, Client extends AbstractStub<Client>>
extends GuavaAsyncBatchLookupDoFn<Input, BatchRequest, BatchResponse, Output, Client> {
private final ChannelSupplier channelSupplier;
private final SerializableFunction<Channel, Client> newClientFn;
@@ -104,21 +104,13 @@ protected Client newClient() {
}
public static <
- Input,
- BatchRequest,
- BatchResponse,
- Output,
- ClientType extends AbstractFutureStub<ClientType>>
+ Input, BatchRequest, BatchResponse, Output, ClientType extends AbstractStub<ClientType>>
Builder<Input, BatchRequest, BatchResponse, Output, ClientType> newBuilder() {
return new Builder<>();
}
public static class Builder<
- Input,
- BatchRequest,
- BatchResponse,
- Output,
- ClientType extends AbstractFutureStub<ClientType>>
+ Input, BatchRequest, BatchResponse, Output, ClientType extends AbstractStub<ClientType>>
implements Serializable {
private ChannelSupplier channelSupplier;
diff --git a/scio-grpc/src/main/scala/com/spotify/scio/grpc/SCollectionSyntax.scala b/scio-grpc/src/main/scala/com/spotify/scio/grpc/SCollectionSyntax.scala
index 121f76bf3f..3752af62d5 100644
--- a/scio-grpc/src/main/scala/com/spotify/scio/grpc/SCollectionSyntax.scala
+++ b/scio-grpc/src/main/scala/com/spotify/scio/grpc/SCollectionSyntax.scala
@@ -147,6 +147,64 @@ class GrpcSCollectionOps[Request](private val self: SCollection[Request]) extend
).map(kvToTuple _)
.mapValues(_.asScala)
}
+
+ def grpcLookupBatchStream[
+ BatchRequest,
+ Response,
+ Result: Coder,
+ Client <: AbstractStub[Client]
+ ](
+ channelSupplier: () => Channel,
+ clientFactory: Channel => Client,
+ batchSize: Int,
+ batchRequestFn: Seq[Request] => BatchRequest,
+ batchResponseFn: List[Response] => Seq[(String, Result)],
+ idExtractorFn: Request => String,
+ maxPendingRequests: Int,
+ cacheSupplier: CacheSupplier[String, Result] = new NoOpCacheSupplier[String, Result]()
+ )(
+ f: Client => (BatchRequest, StreamObserver[Response]) => Unit
+ ): SCollection[(Request, Try[Result])] = self.transform { in =>
+ import self.coder
+ val cleanedChannelSupplier = ClosureCleaner.clean(channelSupplier)
+ val serializableClientFactory = Functions.serializableFn(clientFactory)
+ val serializableLookupFn =
+ Functions.serializableBiFn[Client, BatchRequest, ListenableFuture[JIterable[Response]]] {
+ (client, request) =>
+ val observer = new StreamObservableFuture[Response]()
+ f(client)(request, observer)
+ observer
+ }
+ val serializableBatchRequestFn =
+ Functions.serializableFn[java.util.List[Request], BatchRequest] { inputs =>
+ batchRequestFn(inputs.asScala.toSeq)
+ }
+
+ val serializableBatchResponseFn =
+ Functions.serializableFn[JIterable[Response], java.util.List[Pair[String, Result]]] {
+ batchResponse =>
+ batchResponseFn(batchResponse.asScala.toList).map { case (input, output) =>
+ Pair.of(input, output)
+ }.asJava
+ }
+ val serializableIdExtractorFn = Functions.serializableFn(idExtractorFn)
+ in.parDo(
+ GrpcBatchDoFn
+ .newBuilder[Request, BatchRequest, JIterable[Response], Result, Client]()
+ .withChannelSupplier(() => cleanedChannelSupplier())
+ .withNewClientFn(serializableClientFactory)
+ .withLookupFn(serializableLookupFn)
+ .withMaxPendingRequests(maxPendingRequests)
+ .withBatchSize(batchSize)
+ .withBatchRequestFn(serializableBatchRequestFn)
+ .withBatchResponseFn(serializableBatchResponseFn)
+ .withIdExtractorFn(serializableIdExtractorFn)
+ .withCacheSupplier(cacheSupplier)
+ .build()
+ ).map(kvToTuple _)
+ .mapValues(_.asScala)
+ }
+
}
trait SCollectionSyntax {
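
For reference, the call shape of the new grpcLookupBatchStream, as a hedged sketch reusing the helpers from GrpcBatchDoFnTest further down (the endpoint address is illustrative; the runnable version is the "issue request and propagate streamed responses" test):

    sc.parallelize(requests)
      .grpcLookupBatchStream[
        BatchRequest,         // request batched over the wire
        ConcatResponseWithID, // streamed response element
        ConcatResponseWithID, // result type after matching by ID
        ConcatServiceStub     // any AbstractStub now qualifies, not only future stubs
      ](
        () => NettyChannelBuilder.forTarget("localhost:9000").usePlaintext().build(),
        ConcatServiceGrpc.newStub,
        4, // batchSize
        concatBatchRequest, // Seq[ConcatRequestWithID] => BatchRequest
        concatListResponse, // List[ConcatResponseWithID] => Seq[(String, ConcatResponseWithID)]
        idExtractor, // ConcatRequestWithID => String
        10 // maxPendingRequests
      )(_.batchConcatServerStreaming)
      // yields SCollection[(ConcatRequestWithID, Try[ConcatResponseWithID])]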
diff --git a/scio-grpc/src/test/protobuf/service.proto b/scio-grpc/src/test/protobuf/service.proto
index 6b669a6da0..c8612b0cfb 100644
--- a/scio-grpc/src/test/protobuf/service.proto
+++ b/scio-grpc/src/test/protobuf/service.proto
@@ -43,6 +43,7 @@ service ConcatService {
rpc ConcatClientStreaming(stream ConcatRequest) returns (ConcatResponse) {}
rpc ConcatFullStreaming(stream ConcatRequest) returns (stream ConcatResponse) {}
rpc BatchConcat(BatchRequest) returns (BatchResponse) {}
+ rpc BatchConcatServerStreaming(BatchRequest) returns (stream ConcatResponseWithID) {}
rpc Ping(google.protobuf.Empty) returns (google.protobuf.Empty);
}
diff --git a/scio-grpc/src/test/scala/com/spotify/scio/grpc/GrpcBatchDoFnTest.scala b/scio-grpc/src/test/scala/com/spotify/scio/grpc/GrpcBatchDoFnTest.scala
index b97c059ede..be6abfe1e7 100644
--- a/scio-grpc/src/test/scala/com/spotify/scio/grpc/GrpcBatchDoFnTest.scala
+++ b/scio-grpc/src/test/scala/com/spotify/scio/grpc/GrpcBatchDoFnTest.scala
@@ -17,20 +17,24 @@
package com.spotify.scio.grpc
import com.google.common.cache.{Cache, CacheBuilder}
-import com.spotify.concat.v1.ConcatServiceGrpc.{ConcatServiceFutureStub, ConcatServiceImplBase}
+import com.spotify.concat.v1.ConcatServiceGrpc.{
+ ConcatServiceFutureStub,
+ ConcatServiceImplBase,
+ ConcatServiceStub
+}
import com.spotify.concat.v1._
import com.spotify.scio.testing.PipelineSpec
import com.spotify.scio.transforms.BaseAsyncLookupDoFn.CacheSupplier
+import com.spotify.scio.transforms.UnmatchedRequestException
import io.grpc.netty.NettyChannelBuilder
import io.grpc.stub.StreamObserver
import io.grpc.{Server, ServerBuilder}
-import org.apache.beam.sdk.Pipeline.PipelineExecutionException
import org.scalatest.BeforeAndAfterAll
import java.net.ServerSocket
import java.util.stream.Collectors
import scala.jdk.CollectionConverters._
-import scala.util.{Success, Try}
+import scala.util.{Failure, Success, Try}
object GrpcBatchDoFnTest {
@@ -66,6 +70,11 @@ object GrpcBatchDoFnTest {
def concatBatchResponse(response: BatchResponse): Seq[(String, ConcatResponseWithID)] =
response.getResponseList.asScala.toSeq.map(e => (e.getRequestId, e))
+ def concatListResponse(
+ response: List[ConcatResponseWithID]
+ ): Seq[(String, ConcatResponseWithID)] =
+ response.map(e => (e.getRequestId, e))
+
def idExtractor(concatRequest: ConcatRequestWithID): String =
concatRequest.getRequestId
@@ -83,6 +92,15 @@ object GrpcBatchDoFnTest {
responseObserver.onNext(processBatch(request))
responseObserver.onCompleted()
}
+
+ override def batchConcatServerStreaming(
+ request: BatchRequest,
+ responseObserver: StreamObserver[ConcatResponseWithID]
+ ): Unit = {
+ val batchResponse = processBatch(request)
+ batchResponse.getResponseList.forEach(responseObserver.onNext)
+ responseObserver.onCompleted()
+ }
}
}
@@ -138,6 +156,43 @@ class GrpcBatchDoFnTest extends PipelineSpec with BeforeAndAfterAll {
}
}
+ it should "issue request and propagate streamed responses" in {
+ val input = (0 to 10).map { i =>
+ ConcatRequestWithID
+ .newBuilder()
+ .setRequestId(i.toString)
+ .setStringOne(i.toString)
+ .setStringTwo(i.toString)
+ .build()
+ }
+
+ val expected: Seq[(ConcatRequestWithID, Try[ConcatResponseWithID])] = input.map { req =>
+ val resp = concat(req)
+ req -> Success(resp)
+ }
+
+ runWithContext { sc =>
+ val result = sc
+ .parallelize(input)
+ .grpcLookupBatchStream[
+ BatchRequest,
+ ConcatResponseWithID,
+ ConcatResponseWithID,
+ ConcatServiceStub
+ ](
+ () => NettyChannelBuilder.forTarget(ServiceUri).usePlaintext().build(),
+ ConcatServiceGrpc.newStub,
+ 2,
+ concatBatchRequest,
+ concatListResponse,
+ idExtractor,
+ 2
+ )(_.batchConcatServerStreaming)
+
+ result should containInAnyOrder(expected)
+ }
+ }
+
it should "return cached responses" in {
val request = ConcatRequestWithID
.newBuilder()
@@ -230,7 +285,7 @@ class GrpcBatchDoFnTest extends PipelineSpec with BeforeAndAfterAll {
}
}
- it should "throw an IllegalStateException if gRPC response contains unknown ids" in {
+ it should "fail unmatched inputs" in {
val input = (0 to 1).map { i =>
ConcatRequestWithID
.newBuilder()
@@ -240,30 +295,29 @@ class GrpcBatchDoFnTest extends PipelineSpec with BeforeAndAfterAll {
.build()
}
- assertThrows[IllegalStateException] {
- try {
- runWithContext { sc =>
- sc.parallelize(input)
- .grpcBatchLookup[
- BatchRequest,
- BatchResponse,
- ConcatResponseWithID,
- ConcatServiceFutureStub
- ](
- () => NettyChannelBuilder.forTarget(ServiceUri).usePlaintext().build(),
- ConcatServiceGrpc.newFutureStub,
- 2,
- concatBatchRequest,
- r => r.getResponseList.asScala.toSeq.map(e => ("WrongID-" + e.getRequestId, e)),
- idExtractor,
- 2
- )(_.batchConcat)
- }
- } catch {
- case e: PipelineExecutionException =>
- e.getMessage should include("Expected requestCount == responseCount")
- throw e.getCause
- }
+ val expected: Seq[(ConcatRequestWithID, Try[ConcatResponseWithID])] = input.map { req =>
+ req -> Failure(new UnmatchedRequestException(req.getRequestId))
+ }
+
+ runWithContext { sc =>
+ val result = sc
+ .parallelize(input)
+ .grpcBatchLookup[
+ BatchRequest,
+ BatchResponse,
+ ConcatResponseWithID,
+ ConcatServiceFutureStub
+ ](
+ () => NettyChannelBuilder.forTarget(ServiceUri).usePlaintext().build(),
+ ConcatServiceGrpc.newFutureStub,
+ 2,
+ concatBatchRequest,
+ r => r.getResponseList.asScala.toSeq.map(e => ("WrongID-" + e.getRequestId, e)),
+ idExtractor,
+ 2
+ )(_.batchConcat)
+
+ result should containInAnyOrder(expected)
}
}
}