Extend batched lookup functionality
Add batched versions of grpcLookupStream and BigtableDoFn.
lofifnc committed Nov 29, 2024
1 parent 4427ece commit d39cba0
Showing 6 changed files with 381 additions and 13 deletions.
com/spotify/scio/bigtable/BigtableBatchDoFn.java (new file)
@@ -0,0 +1,126 @@
/*
* Copyright 2024 Spotify AB.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package com.spotify.scio.bigtable;

import com.google.cloud.bigtable.config.BigtableOptions;
import com.google.cloud.bigtable.grpc.BigtableSession;
import com.google.common.util.concurrent.ListenableFuture;
import com.spotify.scio.transforms.BaseAsyncLookupDoFn;
import com.spotify.scio.transforms.GuavaAsyncBatchLookupDoFn;
import java.io.IOException;
import java.util.List;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.commons.lang3.tuple.Pair;

/**
* A {@link DoFn} which batches elements and performs asynchronous lookup for them using Google
* Cloud Bigtable.
*
* @param <Input> input element type.
* @param <BatchRequest> batched request type.
* @param <BatchResponse> batched response type returned by Bigtable.
* @param <Result> Bigtable lookup value type.
*/
public abstract class BigtableBatchDoFn<Input, BatchRequest, BatchResponse, Result>
extends GuavaAsyncBatchLookupDoFn<Input, BatchRequest, BatchResponse, Result, BigtableSession> {

private final BigtableOptions options;

/** Perform an asynchronous batched Bigtable lookup. */
public abstract ListenableFuture<BatchResponse> asyncLookup(
BigtableSession session, BatchRequest batchRequest);

/**
 * Create a {@link BigtableBatchDoFn} instance with a default maximum of 1000 pending requests.
 *
 * @param options Bigtable options.
 * @param batchSize maximum number of input elements batched into a single request.
 * @param batchRequestFn function mapping a batch of inputs to a single batched request.
 * @param batchResponseFn function mapping a batched response to a list of (id, result) pairs.
 * @param idExtractorFn function extracting a unique string id from each input element.
 */
public BigtableBatchDoFn(
BigtableOptions options,
int batchSize,
SerializableFunction<List<Input>, BatchRequest> batchRequestFn,
SerializableFunction<BatchResponse, List<Pair<String, Result>>> batchResponseFn,
SerializableFunction<Input, String> idExtractorFn) {
this(options, batchSize, batchRequestFn, batchResponseFn, idExtractorFn, 1000);
}

/**
 * Create a {@link BigtableBatchDoFn} instance.
 *
 * @param options Bigtable options.
 * @param batchSize maximum number of input elements batched into a single request.
 * @param batchRequestFn function mapping a batch of inputs to a single batched request.
 * @param batchResponseFn function mapping a batched response to a list of (id, result) pairs.
 * @param idExtractorFn function extracting a unique string id from each input element.
 * @param maxPendingRequests maximum number of pending requests on every cloned DoFn. This
 *     prevents the runner from timing out and retrying bundles.
 */
public BigtableBatchDoFn(
BigtableOptions options,
int batchSize,
SerializableFunction<List<Input>, BatchRequest> batchRequestFn,
SerializableFunction<BatchResponse, List<Pair<String, Result>>> batchResponseFn,
SerializableFunction<Input, String> idExtractorFn,
int maxPendingRequests) {
this(
options,
batchSize,
batchRequestFn,
batchResponseFn,
idExtractorFn,
maxPendingRequests,
new BaseAsyncLookupDoFn.NoOpCacheSupplier<>());
}

/**
 * Create a {@link BigtableBatchDoFn} instance.
 *
 * @param options Bigtable options.
 * @param batchSize maximum number of input elements batched into a single request.
 * @param batchRequestFn function mapping a batch of inputs to a single batched request.
 * @param batchResponseFn function mapping a batched response to a list of (id, result) pairs.
 * @param idExtractorFn function extracting a unique string id from each input element.
 * @param maxPendingRequests maximum number of pending requests on every cloned DoFn. This
 *     prevents the runner from timing out and retrying bundles.
 * @param cacheSupplier supplier for the lookup cache.
 */
public BigtableBatchDoFn(
BigtableOptions options,
int batchSize,
SerializableFunction<List<Input>, BatchRequest> batchRequestFn,
SerializableFunction<BatchResponse, List<Pair<String, Result>>> batchResponseFn,
SerializableFunction<Input, String> idExtractorFn,
int maxPendingRequests,
BaseAsyncLookupDoFn.CacheSupplier<String, Result> cacheSupplier) {
super(
batchSize,
batchRequestFn,
batchResponseFn,
idExtractorFn,
maxPendingRequests,
cacheSupplier);
this.options = options;
}

@Override
public ResourceType getResourceType() {
// BigtableSession is backed by a thread-safe gRPC client
return ResourceType.PER_INSTANCE;
}

@Override
protected BigtableSession newClient() {
try {
return new BigtableSession(options);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}
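
For illustration, a minimal Scala sketch of a concrete subclass, in the same spirit as the tests below. The UserLookupDoFn name and the immediate-future lookup are hypothetical stand-ins, not part of this commit; a real implementation would issue a multi-row read against Bigtable inside asyncLookup.

import com.google.cloud.bigtable.config.BigtableOptions
import com.google.cloud.bigtable.grpc.BigtableSession
import com.google.common.util.concurrent.{Futures, ListenableFuture}
import org.apache.commons.lang3.tuple.Pair

import scala.jdk.CollectionConverters._

// Hypothetical subclass: looks up a string value per key, batching up to 100
// keys per Bigtable call with at most 1000 requests in flight (the default).
class UserLookupDoFn(options: BigtableOptions)
    extends BigtableBatchDoFn[String, List[String], List[(String, String)], String](
      options,
      100, // batchSize: input elements per batched request
      (keys: java.util.List[String]) => keys.asScala.toList, // batchRequestFn
      (rows: List[(String, String)]) => // batchResponseFn: (id, result) pairs
        rows.map { case (k, v) => Pair.of(k, v) }.asJava,
      (key: String) => key // idExtractorFn: each key is its own id
    ) {

  override def asyncLookup(
    session: BigtableSession,
    keys: List[String]
  ): ListenableFuture[List[(String, String)]] =
    // Placeholder: a real implementation would issue one multi-row read via
    // the session's data client and pair each returned row with its key.
    Futures.immediateFuture(keys.map(k => (k, s"value-$k")))
}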
com/spotify/scio/bigtable/BigtableBatchDoFnTest.scala (new file)
@@ -0,0 +1,136 @@
/*
* Copyright 2024 Spotify AB.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package com.spotify.scio.bigtable

import java.util.concurrent.{CompletionException, ConcurrentLinkedQueue}
import com.google.cloud.bigtable.grpc.BigtableSession
import com.google.common.cache.{Cache, CacheBuilder}
import com.google.common.util.concurrent.{Futures, ListenableFuture}
import com.spotify.scio.testing._
import com.spotify.scio.transforms.BaseAsyncLookupDoFn.CacheSupplier
import com.spotify.scio.transforms.JavaAsyncConverters._
import org.apache.commons.lang3.tuple.Pair

import scala.jdk.CollectionConverters._
import scala.util.{Failure, Success}

class BigtableBatchDoFnTest extends PipelineSpec {
"BigtableDoFn" should "work" in {
val fn = new TestBigtableBatchDoFn
val output = runWithData(1 to 10)(_.parDo(fn))
.map(kv => (kv.getKey, kv.getValue.get()))
output should contain theSameElementsAs (1 to 10).map(x => (x, x.toString))
}

it should "work with cache" in {
val fn = new TestCachingBigtableBatchDoFn
val output = runWithData((1 to 10) ++ (6 to 15))(_.parDo(fn))
.map(kv => (kv.getKey, kv.getValue.get()))
output should have size 20
output should contain theSameElementsAs ((1 to 10) ++ (6 to 15)).map(x => (x, x.toString))
BigtableBatchDoFnTest.queue.asScala.toSet should contain theSameElementsAs (1 to 15)
BigtableBatchDoFnTest.queue.size() should be <= 20
}

it should "work with failures" in {
val fn = new TestFailingBigtableBatchDoFn

val output = runWithData(Seq[Seq[Int]](1 to 4, 8 to 10))(_.flatten.parDo(fn)).map { kv =>
val r = kv.getValue.asScala match {
case Success(v) => v
case Failure(e: CompletionException) => e.getCause.getMessage
case Failure(e) => e.getMessage
}
(kv.getKey, r)
}
output should contain theSameElementsAs (
(1 to 4).map(x => x -> x.toString) ++
(8 to 10).map(x => x -> "failure for 8,9,10")
)
}
}

object BigtableBatchDoFnTest {
val queue: ConcurrentLinkedQueue[Int] = new ConcurrentLinkedQueue[Int]()

def batchRequest(input: java.util.List[Int]): List[Int] = input.asScala.toList
def batchResponse(input: List[String]): java.util.List[Pair[String, String]] =
input.map(x => Pair.of(x, x)).asJava
def idExtractor(input: Int): String = input.toString
}

class TestBigtableBatchDoFn
extends BigtableBatchDoFn[Int, List[Int], List[String], String](
null,
2,
BigtableBatchDoFnTest.batchRequest,
BigtableBatchDoFnTest.batchResponse,
BigtableBatchDoFnTest.idExtractor
) {
override def newClient(): BigtableSession = null
override def asyncLookup(
session: BigtableSession,
input: List[Int]
): ListenableFuture[List[String]] =
Futures.immediateFuture(input.map(_.toString))
}

class TestCachingBigtableBatchDoFn
extends BigtableBatchDoFn[Int, List[Int], List[String], String](
null,
2,
BigtableBatchDoFnTest.batchRequest,
BigtableBatchDoFnTest.batchResponse,
BigtableBatchDoFnTest.idExtractor,
100,
new TestCacheBatchSupplier
) {
override def newClient(): BigtableSession = null

override def asyncLookup(
session: BigtableSession,
input: List[Int]
): ListenableFuture[List[String]] = {
input.foreach(BigtableBatchDoFnTest.queue.add)
Futures.immediateFuture(input.map(_.toString))
}
}

class TestFailingBigtableBatchDoFn
extends BigtableBatchDoFn[Int, List[Int], List[String], String](
null,
4,
BigtableBatchDoFnTest.batchRequest,
BigtableBatchDoFnTest.batchResponse,
BigtableBatchDoFnTest.idExtractor
) {
override def newClient(): BigtableSession = null
override def asyncLookup(
session: BigtableSession,
input: List[Int]
): ListenableFuture[List[String]] =
if (input.size % 2 == 0) {
Futures.immediateFuture(input.map(_.toString))
} else {
Futures.immediateFailedFuture(new RuntimeException("failure for " + input.mkString(",")))
}
}

class TestCacheBatchSupplier extends CacheSupplier[String, String] {
override def get(): Cache[String, String] = CacheBuilder.newBuilder().build()
}
16 changes: 4 additions & 12 deletions scio-grpc/src/main/java/com/spotify/scio/grpc/GrpcBatchDoFn.java
@@ -25,7 +25,7 @@
 import com.spotify.scio.transforms.BaseAsyncLookupDoFn.CacheSupplier;
 import com.spotify.scio.transforms.GuavaAsyncBatchLookupDoFn;
 import io.grpc.Channel;
-import io.grpc.stub.AbstractFutureStub;
+import io.grpc.stub.AbstractStub;
 import java.io.Serializable;
 import java.util.List;
 import org.apache.beam.sdk.transforms.SerializableBiFunction;
@@ -43,7 +43,7 @@
  * @param <Client> client type.
  */
 public class GrpcBatchDoFn<
-        Input, BatchRequest, BatchResponse, Output, Client extends AbstractFutureStub<Client>>
+        Input, BatchRequest, BatchResponse, Output, Client extends AbstractStub<Client>>
     extends GuavaAsyncBatchLookupDoFn<Input, BatchRequest, BatchResponse, Output, Client> {
   private final ChannelSupplier channelSupplier;
   private final SerializableFunction<Channel, Client> newClientFn;
@@ -104,21 +104,13 @@ protected Client newClient() {
   }
 
   public static <
-          Input,
-          BatchRequest,
-          BatchResponse,
-          Output,
-          ClientType extends AbstractFutureStub<ClientType>>
+          Input, BatchRequest, BatchResponse, Output, ClientType extends AbstractStub<ClientType>>
       Builder<Input, BatchRequest, BatchResponse, Output, ClientType> newBuilder() {
     return new Builder<>();
   }
 
   public static class Builder<
-          Input,
-          BatchRequest,
-          BatchResponse,
-          Output,
-          ClientType extends AbstractFutureStub<ClientType>>
+          Input, BatchRequest, BatchResponse, Output, ClientType extends AbstractStub<ClientType>>
       implements Serializable {
 
     private ChannelSupplier channelSupplier;
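
Note: widening the Client bound from AbstractFutureStub to AbstractStub is what lets server-streaming stubs be used here, since generated async stubs extend AbstractStub but not AbstractFutureStub; the new grpcLookupBatchStream below relies on exactly that.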
@@ -147,6 +147,64 @@ class GrpcSCollectionOps[Request](private val self: SCollection[Request]) extend
).map(kvToTuple _)
.mapValues(_.asScala)
}

def grpcLookupBatchStream[
BatchRequest,
Response,
Result: Coder,
Client <: AbstractStub[Client]
](
channelSupplier: () => Channel,
clientFactory: Channel => Client,
batchSize: Int,
batchRequestFn: Seq[Request] => BatchRequest,
batchResponseFn: List[Response] => Seq[(String, Result)],
idExtractorFn: Request => String,
maxPendingRequests: Int,
cacheSupplier: CacheSupplier[String, Result] = new NoOpCacheSupplier[String, Result]()
)(
f: Client => (BatchRequest, StreamObserver[Response]) => Unit
): SCollection[(Request, Try[Result])] = self.transform { in =>
import self.coder
val cleanedChannelSupplier = ClosureCleaner.clean(channelSupplier)
val serializableClientFactory = Functions.serializableFn(clientFactory)
val serializableLookupFn =
Functions.serializableBiFn[Client, BatchRequest, ListenableFuture[JIterable[Response]]] {
(client, request) =>
val observer = new StreamObservableFuture[Response]()
f(client)(request, observer)
observer
}
val serializableBatchRequestFn =
Functions.serializableFn[java.util.List[Request], BatchRequest] { inputs =>
batchRequestFn(inputs.asScala.toSeq)
}

val serializableBatchResponseFn =
Functions.serializableFn[JIterable[Response], java.util.List[Pair[String, Result]]] {
batchResponse =>
batchResponseFn(batchResponse.asScala.toList).map { case (input, output) =>
Pair.of(input, output)
}.asJava
}
val serializableIdExtractorFn = Functions.serializableFn(idExtractorFn)
in.parDo(
GrpcBatchDoFn
.newBuilder[Request, BatchRequest, JIterable[Response], Result, Client]()
.withChannelSupplier(() => cleanedChannelSupplier())
.withNewClientFn(serializableClientFactory)
.withLookupFn(serializableLookupFn)
.withMaxPendingRequests(maxPendingRequests)
.withBatchSize(batchSize)
.withBatchRequestFn(serializableBatchRequestFn)
.withBatchResponseFn(serializableBatchResponseFn)
.withIdExtractorFn(serializableIdExtractorFn)
.withCacheSupplier(cacheSupplier)
.build()
).map(kvToTuple _)
.mapValues(_.asScala)
}

}

trait SCollectionSyntax {
1 change: 1 addition & 0 deletions scio-grpc/src/test/protobuf/service.proto
@@ -43,6 +43,7 @@ service ConcatService {
 rpc ConcatClientStreaming(stream ConcatRequest) returns (ConcatResponse) {}
 rpc ConcatFullStreaming(stream ConcatRequest) returns (stream ConcatResponse) {}
 rpc BatchConcat(BatchRequest) returns (BatchResponse) {}
+rpc BatchConcatServerStreaming(BatchRequest) returns (stream ConcatResponseWithID) {}
 
 rpc Ping(google.protobuf.Empty) returns (google.protobuf.Empty);
 }
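
A hedged sketch of how the new server-streaming RPC could be consumed through grpcLookupBatchStream. The generated names (ConcatServiceGrpc, batchConcatServerStreaming), the message accessors (addAllRequest, getId, getResponse), and the target address are assumptions about the generated protobuf/gRPC code, not part of this diff.

import com.spotify.scio.grpc._
import com.spotify.scio.values.SCollection
import io.grpc.netty.NettyChannelBuilder

import scala.jdk.CollectionConverters._
import scala.util.Try

def lookupAll(requests: SCollection[String]): SCollection[(String, Try[String])] =
  requests.grpcLookupBatchStream[
    BatchRequest,
    ConcatResponseWithID,
    String,
    ConcatServiceGrpc.ConcatServiceStub
  ](
    // Channel and stub suppliers; the target address is a placeholder.
    () => NettyChannelBuilder.forTarget("dns:///concat-service:443").usePlaintext().build(),
    ConcatServiceGrpc.newStub(_),
    batchSize = 10,
    // Assumed builder/accessor names on the generated messages.
    batchRequestFn = ids => BatchRequest.newBuilder().addAllRequest(ids.asJava).build(),
    batchResponseFn = _.map(r => r.getId -> r.getResponse),
    idExtractorFn = identity,
    maxPendingRequests = 100
  ) { client => (request, observer) =>
    client.batchConcatServerStreaming(request, observer)
  }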
(1 of the 6 changed files is not shown.)
