
Commit 1983b61

spkrka and claude committed
Add SMBCollection: unified fluent API for Sort-Merge Bucket operations
This introduces SMBCollection, a new fluent API that unifies and improves all SMB operations in Scio.

## Key Improvements

### 1. Unified API

Traditional SMB operations are fragmented across disjoint methods, each solving a specific sub-problem:

- `sortMergeJoin` - read and join to SCollection
- `sortMergeTransform` - read, transform, and write back to SMB
- `sortMergeGroupByKey` - read a single source to SCollection
- `sortMergeCoGroup` - read multiple sources to SCollection

SMBCollection provides a single, composable API for all SMB workflows.

### 2. Familiar SCollection-like Ergonomics

Uses familiar functional operations (`map`, `filter`, `flatMap`) instead of imperative callbacks.

**Before (Traditional API):**

```scala
sc.sortMergeTransform(classOf[Integer], usersRead)
  .to(output)
  .via { case (key, users, outputCollector) =>
    users.foreach { user =>
      val transformed = transformUser(user)
      outputCollector.accept(transformed) // ❌ Imperative callback
    }
  }
```

**After (SMBCollection):**

```scala
SMBCollection.read(classOf[Integer], usersRead)
  .flatMap(users => users.map(transformUser)) // ✅ Functional style
  .saveAsSortedBucket(output)
```

### 3. Better Interoperability

Seamlessly convert between SMB and SCollection:

```scala
val base = SMBCollection.cogroup2(classOf[Integer], usersRead, accountsRead)
  .map { case (_, (users, accounts)) => expensiveJoin(users, accounts) }

// SMB outputs (stay bucketed)
base.map(_.summary).saveAsSortedBucket(summaryOutput)
base.map(_.details).saveAsSortedBucket(detailsOutput)

// SCollection output (for non-SMB operations)
val joined = base.toDeferredSCollection().get
joined.filter(_.needsProcessing).saveAsTextFile(textOutput)

sc.run() // All outputs execute in one pass!
```

### 4. Zero-Shuffle Multi-Output (Massive Performance Gains)

Create multiple SMB outputs from the same computation with zero shuffles.

**Before (Traditional - SCollection fanout):**

```scala
// Reads once, joins once, BUT shuffles 3 times
val joined = sc.sortMergeJoin(classOf[Integer], usersRead, accountsRead)
  .map { case (userId, (user, account)) =>
    expensiveJoin(user, account) // Runs once ✓
  }

// ❌ Each saveAsSortedBucket does a GroupByKey shuffle!
joined.map(_.summary).saveAsSortedBucket(summaryOutput)          // Shuffle 1
joined.map(_.details).saveAsSortedBucket(detailsOutput)          // Shuffle 2
joined.filter(_.isHighValue).saveAsSortedBucket(highValueOutput) // Shuffle 3
```

**After (SMBCollection - zero shuffles):**

```scala
// Reads once, joins once, zero shuffles!
val base = SMBCollection.cogroup2(classOf[Integer], usersRead, accountsRead)
  .map { case (_, (users, accounts)) =>
    expensiveJoin(users, accounts) // Runs ONCE
  }

// ✅ Fan out to multiple SMB outputs - data already bucketed!
base.map(_.summary).saveAsSortedBucket(summaryOutput)
base.map(_.details).saveAsSortedBucket(detailsOutput)
base.filter(_.isHighValue).saveAsSortedBucket(highValueOutput)

sc.run() // Single pass execution
```

**Performance Impact:**

| Scenario | Traditional (SCollection fanout) | SMBCollection Multi-Output | Cost Reduction |
|----------|----------------------------------|----------------------------|----------------|
| 1TB → 3 SMB outputs | 1TB read + ~3TB shuffle | 1TB read, 0 shuffle | **~4× savings** |
| 2TB join → 5 outputs | 2TB read + ~10TB shuffle | 2TB read, 0 shuffle | **~6× savings** |
| 500GB → 10 outputs | 500GB read + ~5TB shuffle | 500GB read, 0 shuffle | **~11× savings** |

(The cost-reduction ratios follow from a simple I/O model; a worked sketch appears after this commit message.)

## Complete Example

See `SortMergeBucketMultiOutputExample` in scio-examples for a full working example that creates multiple derived datasets (summary, details, high-value users) from a single expensive user-account join with zero shuffles.

## API Design

- Type signature: `SMBCollection[K1, K2, V]` - tracks keys for type safety; methods work with `V` directly
- `read()` returns `Iterable[V]` without a key wrapper
- `cogroup2()` returns `(K, (Iterable[L], Iterable[R]))`
- Standard transformations: `map`, `filter`, `flatMap` (not `mapValues`/`flatMapValues`)
- Side inputs: clean `(SideInputContext, V)` signature
- Auto-execution: outputs execute via the `sc.onClose()` hook

## Limitations

- Currently supports up to 4-way cogroups (`cogroup2`, `cogroup3`, `cogroup4`)
- For 5- to 22-way cogroups, use the traditional `sortMergeCoGroup`
- Note: this is not a systemic limitation; it is easily extensible by adding `cogroup5` through `cogroup22` methods

## Documentation

Updated documentation includes:

- Complete fluent API guide with multi-output examples
- API comparison table (fluent vs. traditional)
- Performance impact analysis
- Migration examples
- When to use which API

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <[email protected]>
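A note on the performance table: the cost-reduction column is consistent with a rough I/O model in which the traditional fanout pays one read of the input plus roughly one input-sized shuffle per output, while the multi-output path pays only the read. The model below is an editorial assumption used to sanity-check the ratios, not something stated in the commit itself:

```scala
// Rough I/O model (assumption): traditional ≈ read + outputs × shuffle, where each
// shuffle moves roughly the input size; SMBCollection multi-output ≈ read only.
def estimatedSavings(inputTB: Double, outputs: Int): Double =
  (inputTB + outputs * inputTB) / inputTB // simplifies to 1 + outputs

println(estimatedSavings(1.0, 3))  // 4.0  -> "~4× savings"  (1TB → 3 SMB outputs)
println(estimatedSavings(2.0, 5))  // 6.0  -> "~6× savings"  (2TB join → 5 outputs)
println(estimatedSavings(0.5, 10)) // 11.0 -> "~11× savings" (500GB → 10 outputs)
```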
1 parent c234ba6 commit 1983b61

26 files changed (+6816, −19 lines)

scio-examples/src/main/scala/com/spotify/scio/examples/extra/SortMergeBucketExample.scala

Lines changed: 101 additions & 0 deletions
```diff
@@ -280,3 +280,104 @@ object SortMergeBucketTransformExample {
     ()
   }
 }
+
+object SortMergeBucketMultiOutputExample {
+  import com.spotify.scio.smb._
+
+  case class AccountProjection(id: Int, amount: Double)
+  case class UserSummary(userId: Int, age: Int, totalAmount: Double, accountCount: Int)
+  case class UserDetails(userId: Int, age: Int, accounts: Seq[AccountProjection])
+  case class HighValueUser(userId: Int, age: Int, totalAmount: Double)
+
+  def pipeline(cmdLineArgs: Array[String]): ScioContext = {
+    val (sc, args) = ContextAndArgs(cmdLineArgs)
+    pipeline(sc, args)
+    sc
+  }
+
+  def pipeline(sc: ScioContext, args: Args): Unit = {
+    implicit val coder: Coder[GenericRecord] = avroGenericRecordCoder(
+      SortMergeBucketExample.UserDataSchema
+    )
+    implicit val scImpl: ScioContext = sc
+
+    // #SortMergeBucketExample_multi_output
+    // Create base collection with expensive shared computation
+    // This cogroup and map runs ONCE, results are shared across all outputs
+    val base = SMBCollection
+      .cogroup2(
+        classOf[Integer],
+        ParquetAvroSortedBucketIO
+          .read(new TupleTag[GenericRecord]("users"), SortMergeBucketExample.UserDataSchema)
+          .withFilterPredicate(FilterApi.lt(FilterApi.intColumn("age"), Int.box(50)))
+          .from(args("users")),
+        ParquetTypeSortedBucketIO
+          .read(new TupleTag[AccountProjection]("accounts"))
+          .from(args("accounts"))
+      )
+      .map { case (_, (users, accounts)) =>
+        // Expensive computation happens ONCE per key group
+        // Results are pushed to all three outputs below
+        val accountList = accounts.toSeq
+        val totalAmount = accountList.map(_.amount).sum
+        (users.toSeq, accountList, totalAmount)
+      }
+
+    // Output 1: Summary - just the aggregated metrics
+    base
+      .map { case (users, accounts, total) =>
+        UserSummary(
+          userId = users.head.get("userId").asInstanceOf[Int],
+          age = users.head.get("age").asInstanceOf[Int],
+          totalAmount = total,
+          accountCount = accounts.size
+        )
+      }
+      .saveAsSortedBucket(
+        ParquetTypeSortedBucketIO
+          .transformOutput[Integer, UserSummary]("userId")
+          .to(args("summaryOutput"))
+      )
+
+    // Output 2: Details - full account information
+    base
+      .map { case (users, accounts, _) =>
+        UserDetails(
+          userId = users.head.get("userId").asInstanceOf[Int],
+          age = users.head.get("age").asInstanceOf[Int],
+          accounts = accounts
+        )
+      }
+      .saveAsSortedBucket(
+        ParquetTypeSortedBucketIO
+          .transformOutput[Integer, UserDetails]("userId")
+          .to(args("detailsOutput"))
+      )
+
+    // Output 3: High value users only - filtered subset
+    base
+      .filter { case (_, _, total) => total > 1000.0 }
+      .map { case (users, _, total) =>
+        HighValueUser(
+          userId = users.head.get("userId").asInstanceOf[Int],
+          age = users.head.get("age").asInstanceOf[Int],
+          totalAmount = total
+        )
+      }
+      .saveAsSortedBucket(
+        ParquetTypeSortedBucketIO
+          .transformOutput[Integer, HighValueUser]("userId")
+          .to(args("highValueOutput"))
+      )
+
+    // All outputs execute automatically when sc.run() is called
+    // SMB data is read ONCE, expensive computation runs ONCE, zero shuffles!
+    // #SortMergeBucketExample_multi_output
+  }
+
+  def main(cmdLineArgs: Array[String]): Unit = {
+    val sc = pipeline(cmdLineArgs)
+    sc.run().waitUntilDone()
+    ()
+  }
+}
```
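The example above writes only SMB outputs. If a plain-text output were also needed from the same pass, the `toDeferredSCollection()` conversion described in the commit message suggests a fourth, non-SMB output could be added alongside the bucketed ones. The following is a sketch under that assumption; the `reportOutput` argument is hypothetical and not part of the example:

```scala
// Sketch only: assumes base.toDeferredSCollection().get as described in the
// commit message; args("reportOutput") is a hypothetical extra pipeline argument.
val asSCollection = base.toDeferredSCollection().get
asSCollection
  .filter { case (_, _, total) => total > 1000.0 }
  .map { case (users, _, total) => s"${users.head.get("userId")},$total" }
  .saveAsTextFile(args("reportOutput"))
```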

scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/AvroFileOperations.java

Lines changed: 25 additions & 5 deletions
```diff
@@ -18,6 +18,8 @@
 package org.apache.beam.sdk.extensions.smb;

 import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
 import java.io.Serializable;
 import java.nio.channels.Channels;
 import java.nio.channels.ReadableByteChannel;
@@ -44,7 +46,8 @@ public class AvroFileOperations<ValueT> extends FileOperations<ValueT> {

   private final AvroDatumFactory<ValueT> datumFactory;
   private final SerializableSchemaSupplier schemaSupplier;
-  private PatchedSerializableAvroCodecFactory codec;
+  private transient CodecFactory codec;
+  private PatchedSerializableAvroCodecFactory serializableCodec;
   private Map<String, Object> metadata;

   static CodecFactory defaultCodec() {
@@ -55,7 +58,7 @@ private AvroFileOperations(AvroDatumFactory<ValueT> datumFactory, Schema schema)
     super(Compression.UNCOMPRESSED, MimeTypes.BINARY); // Avro has its own compression via codec
     this.schemaSupplier = new SerializableSchemaSupplier(schema);
     this.datumFactory = datumFactory;
-    this.codec = new PatchedSerializableAvroCodecFactory(defaultCodec());
+    this.codec = defaultCodec();
   }

   public static <V extends IndexedRecord> AvroFileOperations<V> of(
@@ -64,7 +67,7 @@ public static <V extends IndexedRecord> AvroFileOperations<V> of(
   }

   public AvroFileOperations<ValueT> withCodec(CodecFactory codec) {
-    this.codec = new PatchedSerializableAvroCodecFactory(codec);
+    this.codec = codec;
     return this;
   }

@@ -76,7 +79,7 @@ public AvroFileOperations<ValueT> withMetadata(Map<String, Object> metadata) {
   @Override
   public void populateDisplayData(Builder builder) {
     super.populateDisplayData(builder);
-    builder.add(DisplayData.item("codecFactory", codec.getCodec().getClass()));
+    builder.add(DisplayData.item("codecFactory", codec.getClass()));
     builder.add(DisplayData.item("schema", schemaSupplier.schema.getFullName()));
   }

@@ -91,7 +94,7 @@ protected FileIO.Sink<ValueT> createSink() {
     final AvroIO.Sink<ValueT> sink =
         ((AvroIO.Sink<ValueT>) AvroIO.sink(getSchema()))
             .withDatumWriterFactory(datumFactory)
-            .withCodec(codec.getCodec());
+            .withCodec(codec);

     if (metadata != null) {
       return sink.withMetadata(metadata);
@@ -110,6 +113,23 @@ Schema getSchema() {
     return schemaSupplier.get();
   }

+  /**
+   * Custom serialization to handle non-serializable CodecFactory. Converts codec to
+   * PatchedSerializableAvroCodecFactory before serialization.
+   */
+  private void writeObject(ObjectOutputStream out) throws IOException {
+    // Convert CodecFactory to serializable form
+    serializableCodec = new PatchedSerializableAvroCodecFactory(codec);
+    out.defaultWriteObject();
+  }
+
+  /** Custom deserialization to restore CodecFactory from PatchedSerializableAvroCodecFactory. */
+  private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
+    in.defaultReadObject();
+    // Restore CodecFactory from serializable form
+    codec = serializableCodec.getCodec();
+  }
+
   private static class SerializableSchemaString implements Serializable {
     private final String schema;
```
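The diff above applies a standard Java serialization idiom: hold the non-serializable `CodecFactory` in a `transient` field and swap in a serializable stand-in inside `writeObject`/`readObject`. Below is a minimal, self-contained sketch of the same pattern in Scala; the `Heavy`/`Holder` names are hypothetical and this is not Scio code:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// A class we cannot serialize directly; only its description is serializable.
final class Heavy(val name: String) // deliberately NOT Serializable

class Holder(@transient private var heavy: Heavy) extends Serializable {
  private var heavyName: String = _ // serializable stand-in, set just before serialization

  def get: Heavy = heavy

  private def writeObject(out: ObjectOutputStream): Unit = {
    heavyName = heavy.name // convert to serializable form
    out.defaultWriteObject()
  }

  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    heavy = new Heavy(heavyName) // restore from serializable form
  }
}

object SerializationPatternSketch extends App {
  // Round trip: the transient field is rebuilt from the serializable stand-in.
  val bytes = new ByteArrayOutputStream()
  val oos   = new ObjectOutputStream(bytes)
  oos.writeObject(new Holder(new Heavy("deflate")))
  oos.close()
  val restored = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
    .readObject()
    .asInstanceOf[Holder]
  println(restored.get.name) // prints "deflate"
}
```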
