
Add BigtableBatchDoFn
lofifnc committed Nov 29, 2024
1 parent 067634c commit 2bfe0f1
Showing 5 changed files with 294 additions and 16 deletions.
AsyncBatchLookupDoFnTest.scala
@@ -57,6 +57,18 @@ class AsyncBatchLookupDoFnTest extends PipelineSpec {
()
}

/** Tests that the doFn can handle missing keys in the batch response. */
private def testMissing[F, T: Coder, C <: AsyncBatchClient](
doFn: BaseAsyncBatchLookupDoFn[Int, List[Int], List[String], String, C, F, T]
)(tryFn: T => Option[String]): Unit = {
val output = runWithData(Seq[Seq[Int]](-4 to 10))(_.flatten.parDo(doFn))
.map(kv => (kv.getKey, tryFn(kv.getValue)))
output should contain theSameElementsAs (-4 to 10).map(x =>
(x, if (x > 0) Some(x.toString) else None)
)
()
}

private def testFailure[F, T: Coder](
doFn: BaseAsyncBatchLookupDoFn[Int, List[Int], List[String], String, AsyncBatchClient, F, T]
)(tryFn: T => Try[String]): Unit = {
@@ -122,6 +134,10 @@ class AsyncBatchLookupDoFnTest extends PipelineSpec {
testCache(new CachingGuavaBatchLookupDoFn)(_.get())
}

it should "work with partial batch responses" in {
testMissing(new GuavaBatchLookupDoFn)(_.get().asScala.headOption)
}

it should "work with failures" in {
testFailure(new FailingGuavaBatchLookupDoFn)(_.asScala)
}
@@ -134,6 +150,10 @@ class AsyncBatchLookupDoFnTest extends PipelineSpec {
testCache(new CachingJavaBatchLookupDoFn)(_.get())
}

it should "work with partial batch responses" in {
testMissing(new JavaBatchLookupDoFn)(_.get().asScala.headOption)
}

it should "work with failures" in {
testFailure(new FailingJavaBatchLookupDoFn)(_.asScala)
}
@@ -146,6 +166,10 @@ class AsyncBatchLookupDoFnTest extends PipelineSpec {
testCache(new CachingScalaBatchLookupDoFn)(_.get)
}

it should "work with partial batch responses" in {
testMissing(new ScalaBatchLookupDoFn)(_.get)
}

it should "work with failures" in {
testFailure(new FailingScalaBatchLookupDoFn)(identity)
}
@@ -207,7 +231,7 @@ class GuavaBatchLookupDoFn extends AbstractGuavaAsyncBatchLookupDoFn() {
session: AsyncBatchClient,
input: List[Int]
): ListenableFuture[List[String]] =
-    Futures.immediateFuture(input.map(_.toString))
+    Futures.immediateFuture(input.filter(_ > 0).map(_.toString))
}

class CachingGuavaBatchLookupDoFn
Expand Down Expand Up @@ -277,7 +301,7 @@ class JavaBatchLookupDoFn extends AbstractJavaAsyncBatchLookupDoFn() {
session: AsyncBatchClient,
input: List[Int]
): CompletableFuture[List[String]] =
-    CompletableFuture.supplyAsync(() => input.map(_.toString))
+    CompletableFuture.supplyAsync(() => input.filter(_ > 0).map(_.toString))
}

class CachingJavaBatchLookupDoFn
Expand Down Expand Up @@ -331,7 +355,7 @@ class ScalaBatchLookupDoFn extends AbstractScalaAsyncBatchLookupDoFn() {
override def getResourceType: ResourceType = ResourceType.PER_INSTANCE
override protected def newClient(): AsyncBatchClient = null
override def asyncLookup(session: AsyncBatchClient, input: List[Int]): Future[List[String]] =
-    Future.successful(input.map(_.toString))
+    Future.successful(input.filter(_ > 0).map(_.toString))
}

class CachingScalaBatchLookupDoFn
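Taken together, the filter(_ > 0) changes above simulate a backend that returns only a subset of the requested ids, and testMissing pins down the caller-side contract: an id absent from a batch response completes with an empty value rather than an error. A condensed, self-contained restatement of that contract in plain Scala (the object and value names are illustrative, not from the commit):

// Model of the semantics testMissing verifies: the backend resolves only
// positive ids, so non-positive ids are missing from the batch response
// and must surface as empty lookups, not failures.
object PartialBatchContract {
  def main(args: Array[String]): Unit = {
    val requested = (-4 to 10).toList
    // The backend drops non-positive ids, as the lookup fns above now do.
    val response = requested.filter(_ > 0).map(_.toString)
    // Re-key the response the way a batchResponseFn would (id -> value).
    val byId = response.map(v => v.toInt -> v).toMap
    val resolved = requested.map(id => id -> byId.get(id))
    assert(resolved.forall { case (id, r) =>
      if (id > 0) r.contains(id.toString) else r.isEmpty
    })
  }
}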
New file: BigtableBatchDoFn.java
@@ -0,0 +1,126 @@
/*
* Copyright 2017 Spotify AB.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package com.spotify.scio.bigtable;

import com.google.cloud.bigtable.config.BigtableOptions;
import com.google.cloud.bigtable.grpc.BigtableSession;
import com.google.common.util.concurrent.ListenableFuture;
import com.spotify.scio.transforms.BaseAsyncLookupDoFn;
import com.spotify.scio.transforms.GuavaAsyncBatchLookupDoFn;
import java.io.IOException;
import java.util.List;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.commons.lang3.tuple.Pair;

/**
* A {@link DoFn} that batches input elements and performs asynchronous lookups for them against
* Google Cloud Bigtable.
*
* @param <Input> input element type.
* @param <BatchRequest> batched request type sent to Bigtable.
* @param <BatchResponse> batched response type returned by Bigtable.
* @param <Result> Bigtable lookup value type.
*/
public abstract class BigtableBatchDoFn<Input, BatchRequest, BatchResponse, Result>
extends GuavaAsyncBatchLookupDoFn<Input, BatchRequest, BatchResponse, Result, BigtableSession> {

private final BigtableOptions options;

/** Perform asynchronous Bigtable lookup. */
public abstract ListenableFuture<BatchResponse> asyncLookup(
BigtableSession session, BatchRequest batchRequest);

/**
* Create a {@link BigtableBatchDoFn} instance with a default maximum of 1000 pending requests.
*
* @param options Bigtable options.
*/
public BigtableBatchDoFn(
BigtableOptions options,
int batchSize,
SerializableFunction<List<Input>, BatchRequest> batchRequestFn,
SerializableFunction<BatchResponse, List<Pair<String, Result>>> batchResponseFn,
SerializableFunction<Input, String> idExtractorFn) {
this(options, batchSize, batchRequestFn, batchResponseFn, idExtractorFn, 1000);
}

/**
* Create a {@link BigtableBatchDoFn} instance.
*
* @param options Bigtable options.
* @param maxPendingRequests maximum number of pending requests on every cloned DoFn. This
*     prevents the runner from timing out and retrying bundles.
*/
public BigtableBatchDoFn(
BigtableOptions options,
int batchSize,
SerializableFunction<List<Input>, BatchRequest> batchRequestFn,
SerializableFunction<BatchResponse, List<Pair<String, Result>>> batchResponseFn,
SerializableFunction<Input, String> idExtractorFn,
int maxPendingRequests) {
this(
options,
batchSize,
batchRequestFn,
batchResponseFn,
idExtractorFn,
maxPendingRequests,
new BaseAsyncLookupDoFn.NoOpCacheSupplier<>());
}

/**
* Create a {@link BigtableBatchDoFn} instance.
*
* @param options Bigtable options.
* @param maxPendingRequests maximum number of pending requests on every cloned DoFn. This
*     prevents the runner from timing out and retrying bundles.
* @param cacheSupplier supplier for the lookup cache.
*/
public BigtableBatchDoFn(
BigtableOptions options,
int batchSize,
SerializableFunction<List<Input>, BatchRequest> batchRequestFn,
SerializableFunction<BatchResponse, List<Pair<String, Result>>> batchResponseFn,
SerializableFunction<Input, String> idExtractorFn,
int maxPendingRequests,
BaseAsyncLookupDoFn.CacheSupplier<String, Result> cacheSupplier) {
super(
batchSize,
batchRequestFn,
batchResponseFn,
idExtractorFn,
maxPendingRequests,
cacheSupplier);
this.options = options;
}

@Override
public ResourceType getResourceType() {
// BigtableSession is backed by a thread-safe gRPC client
return ResourceType.PER_INSTANCE;
}

@Override
protected BigtableSession newClient() {
try {
return new BigtableSession(options);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}
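For orientation (not part of this commit): a minimal sketch of a concrete subclass, written against the constructor above. The key scheme, batch size, and the stubbed lookup body are assumptions for illustration; a real implementation would issue reads through the BigtableSession.

import com.google.cloud.bigtable.config.BigtableOptions
import com.google.cloud.bigtable.grpc.BigtableSession
import com.google.common.util.concurrent.{Futures, ListenableFuture}
import com.spotify.scio.bigtable.BigtableBatchDoFn
import org.apache.commons.lang3.tuple.Pair
import scala.jdk.CollectionConverters._

// Groups up to 10 string keys per Bigtable call. batchResponseFn must map
// every returned value back to the id that idExtractorFn assigns its input;
// here the id of a key is the key itself, recovered by lower-casing the value.
class UppercaseLookupDoFn(options: BigtableOptions)
    extends BigtableBatchDoFn[String, List[String], List[String], String](
      options,
      10, // batchSize
      (keys: java.util.List[String]) => keys.asScala.toList, // batchRequestFn
      (values: List[String]) => values.map(v => Pair.of(v.toLowerCase, v)).asJava, // batchResponseFn
      (key: String) => key // idExtractorFn
    ) {
  override def asyncLookup(
    session: BigtableSession,
    keys: List[String]
  ): ListenableFuture[List[String]] =
    // Stub: a real subclass would read these row keys from Bigtable here.
    Futures.immediateFuture(keys.map(_.toUpperCase))
}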
New file: BigtableBatchDoFnTest.scala
@@ -0,0 +1,136 @@
/*
* Copyright 2019 Spotify AB.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package com.spotify.scio.bigtable

import java.util.concurrent.{CompletionException, ConcurrentLinkedQueue}
import com.google.cloud.bigtable.grpc.BigtableSession
import com.google.common.cache.{Cache, CacheBuilder}
import com.google.common.util.concurrent.{Futures, ListenableFuture}
import com.spotify.scio.testing._
import com.spotify.scio.transforms.BaseAsyncLookupDoFn.CacheSupplier
import com.spotify.scio.transforms.JavaAsyncConverters._
import org.apache.commons.lang3.tuple.Pair

import scala.jdk.CollectionConverters._
import scala.util.{Failure, Success}

class BigtableBatchDoFnTest extends PipelineSpec {
  "BigtableBatchDoFn" should "work" in {
val fn = new TestBigtableBatchDoFn
val output = runWithData(1 to 10)(_.parDo(fn))
.map(kv => (kv.getKey, kv.getValue.get().iterator().next()))
output should contain theSameElementsAs (1 to 10).map(x => (x, x.toString))
}

it should "work with cache" in {
val fn = new TestCachingBigtableBatchDoFn
val output = runWithData((1 to 10) ++ (6 to 15))(_.parDo(fn))
.map(kv => (kv.getKey, kv.getValue.get().iterator().next()))
output should have size 20
output should contain theSameElementsAs ((1 to 10) ++ (6 to 15)).map(x => (x, x.toString))
BigtableBatchDoFnTest.queue.asScala.toSet should contain theSameElementsAs (1 to 15)
BigtableBatchDoFnTest.queue.size() should be <= 20
}

it should "work with failures" in {
val fn = new TestFailingBigtableBatchDoFn

val output = runWithData(Seq[Seq[Int]](1 to 4, 8 to 10))(_.flatten.parDo(fn)).map { kv =>
val r = kv.getValue.asScala match {
case Success(v) => v.asScala.head
case Failure(e: CompletionException) => e.getCause.getMessage
case Failure(e) => e.getMessage
}
(kv.getKey, r)
}
output should contain theSameElementsAs (
(1 to 4).map(x => x -> x.toString) ++
(8 to 10).map(x => x -> "failure for 8,9,10")
)
}
}

object BigtableBatchDoFnTest {
val queue: ConcurrentLinkedQueue[Int] = new ConcurrentLinkedQueue[Int]()

def batchRequest(input: java.util.List[Int]): List[Int] = input.asScala.toList
def batchResponse(input: List[String]): java.util.List[Pair[String, String]] =
input.map(x => Pair.of(x, x)).asJava
def idExtractor(input: Int): String = input.toString
}

class TestBigtableBatchDoFn
extends BigtableBatchDoFn[Int, List[Int], List[String], String](
null,
2,
BigtableBatchDoFnTest.batchRequest,
BigtableBatchDoFnTest.batchResponse,
BigtableBatchDoFnTest.idExtractor
) {
override def newClient(): BigtableSession = null
override def asyncLookup(
session: BigtableSession,
input: List[Int]
): ListenableFuture[List[String]] =
Futures.immediateFuture(input.map(_.toString))
}

class TestCachingBigtableBatchDoFn
extends BigtableBatchDoFn[Int, List[Int], List[String], String](
null,
2,
BigtableBatchDoFnTest.batchRequest,
BigtableBatchDoFnTest.batchResponse,
BigtableBatchDoFnTest.idExtractor,
100,
new TestCacheBatchSupplier
) {
override def newClient(): BigtableSession = null

override def asyncLookup(
session: BigtableSession,
input: List[Int]
): ListenableFuture[List[String]] = {
input.foreach(BigtableBatchDoFnTest.queue.add)
Futures.immediateFuture(input.map(_.toString))
}
}

class TestFailingBigtableBatchDoFn
extends BigtableBatchDoFn[Int, List[Int], List[String], String](
null,
4,
BigtableBatchDoFnTest.batchRequest,
BigtableBatchDoFnTest.batchResponse,
BigtableBatchDoFnTest.idExtractor
) {
override def newClient(): BigtableSession = null
override def asyncLookup(
session: BigtableSession,
input: List[Int]
): ListenableFuture[List[String]] =
if (input.size % 2 == 0) {
Futures.immediateFuture(input.map(_.toString))
} else {
Futures.immediateFailedFuture(new RuntimeException("failure for " + input.mkString(",")))
}
}

class TestCacheBatchSupplier extends CacheSupplier[String, String] {
override def get(): Cache[String, String] = CacheBuilder.newBuilder().build()
}
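Finally, a hedged sketch of driving one of the test DoFns above end-to-end in a Scio pipeline, mirroring the parDo usage in these tests; the job object and output path are placeholders:

import com.spotify.scio.ContextAndArgs
import com.spotify.scio.transforms._ // assumed to bring .parDo into scope, as used by the tests above
import scala.jdk.CollectionConverters._

object BigtableBatchLookupJob {
  def main(cmdlineArgs: Array[String]): Unit = {
    val (sc, _) = ContextAndArgs(cmdlineArgs)
    sc.parallelize(1 to 100)
      .parDo(new TestBigtableBatchDoFn) // emits KV[Int, Try[java.util.List[String]]]
      .map(kv => s"${kv.getKey} -> ${kv.getValue.get().asScala.mkString(",")}")
      .saveAsTextFile("gs://my-bucket/lookups") // placeholder path
    sc.run().waitUntilFinish()
    ()
  }
}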
15 changes: 3 additions & 12 deletions scio-grpc/src/main/java/com/spotify/scio/grpc/GrpcBatchDoFn.java
@@ -25,10 +25,9 @@
import com.spotify.scio.transforms.BaseAsyncLookupDoFn.CacheSupplier;
import com.spotify.scio.transforms.GuavaAsyncBatchLookupDoFn;
import io.grpc.Channel;
+import io.grpc.stub.AbstractStub;
import java.io.Serializable;
import java.util.List;
-
-import io.grpc.stub.AbstractStub;
import org.apache.beam.sdk.transforms.SerializableBiFunction;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.commons.lang3.tuple.Pair;
@@ -105,21 +104,13 @@ protected Client newClient() {
}

public static <
-      Input,
-      BatchRequest,
-      BatchResponse,
-      Output,
-      ClientType extends AbstractStub<ClientType>>
+      Input, BatchRequest, BatchResponse, Output, ClientType extends AbstractStub<ClientType>>
Builder<Input, BatchRequest, BatchResponse, Output, ClientType> newBuilder() {
return new Builder<>();
}

public static class Builder<
-      Input,
-      BatchRequest,
-      BatchResponse,
-      Output,
-      ClientType extends AbstractStub<ClientType>>
+      Input, BatchRequest, BatchResponse, Output, ClientType extends AbstractStub<ClientType>>
implements Serializable {

private ChannelSupplier channelSupplier;
GrpcBatchDoFnTest.scala
@@ -292,7 +292,8 @@ class GrpcBatchDoFnTest extends PipelineSpec with BeforeAndAfterAll {
.setStringTwo(i.toString)
.build()
}
-    val expected: Seq[(ConcatRequestWithID, Try[Option[ConcatResponseWithID]])] = input.map(req => req -> Success(None))
+    val expected: Seq[(ConcatRequestWithID, Try[Option[ConcatResponseWithID]])] =
+      input.map(req => req -> Success(None))
try {
runWithContext { sc =>
sc.parallelize(input)
