From 0edc826d28a7143ed87e6e114bf28b903bf06b96 Mon Sep 17 00:00:00 2001 From: Raunaq Morarka Date: Wed, 11 Dec 2024 18:21:19 +0530 Subject: [PATCH] Use bloom filter for evaluating dynamic filters on strings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Benchmark (filterSize) (inputDataSet) (inputNullChance) (nonNullsSelectivity) (nullsAllowed) Mode Cnt Before Score After Score Units BenchmarkDynamicPageFilter.filterPages 100 VARCHAR_RANDOM 0.05 0.2 false thrpt 20 145.858 ± 4.541 590.506 ± 28.510 ops/s BenchmarkDynamicPageFilter.filterPages 1000 VARCHAR_RANDOM 0.05 0.2 false thrpt 20 136.995 ± 2.395 596.036 ± 22.694 ops/s BenchmarkDynamicPageFilter.filterPages 10000 VARCHAR_RANDOM 0.05 0.2 false thrpt 20 136.990 ± 5.284 594.118 ± 15.764 ops/s BenchmarkDynamicPageFilter.filterPages 100000 VARCHAR_RANDOM 0.05 0.2 false thrpt 20 114.591 ± 7.307 587.445 ± 9.818 ops/s BenchmarkDynamicPageFilter.filterPages 1000000 VARCHAR_RANDOM 0.05 0.2 false thrpt 20 43.234 ± 1.621 578.800 ± 15.694 ops/s BenchmarkDynamicPageFilter.filterPages 5000000 VARCHAR_RANDOM 0.05 0.2 false thrpt 20 40.018 ± 2.245 464.153 ± 20.914 ops/s --- .../trino/sql/gen/columnar/BloomFilter.java | 192 ++++++++++++++++++ .../sql/gen/columnar/DynamicPageFilter.java | 35 ++-- .../trino/sql/planner/DomainTranslator.java | 4 +- .../sql/gen/BenchmarkDynamicPageFilter.java | 8 + .../trino/sql/gen/TestDynamicPageFilter.java | 103 +++++++++- 5 files changed, 315 insertions(+), 27 deletions(-) create mode 100644 core/trino-main/src/main/java/io/trino/sql/gen/columnar/BloomFilter.java diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/BloomFilter.java b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/BloomFilter.java new file mode 100644 index 000000000000..b4deb3a5c0c9 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/BloomFilter.java @@ -0,0 +1,192 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + 
* you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.gen.columnar; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; +import io.airlift.slice.XxHash64; +import io.trino.operator.project.InputChannels; +import io.trino.spi.Page; +import io.trino.spi.block.VariableWidthBlock; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.predicate.Domain; +import io.trino.spi.type.CharType; +import io.trino.spi.type.Type; +import io.trino.spi.type.VarbinaryType; +import io.trino.spi.type.VarcharType; + +import java.util.List; +import java.util.function.Supplier; + +public class BloomFilter +{ + private BloomFilter() {} + + public static boolean canUseBloomFilter(Domain domain) + { + Type type = domain.getType(); + if (type.getJavaType() != Slice.class) { + return false; + } + if (!(type instanceof VarcharType || type instanceof CharType || type instanceof VarbinaryType)) { + return false; + } + return !domain.isNone() && !domain.isAll() && domain.isNullableDiscreteSet(); + } + + public static Supplier createBloomFilterEvaluator(Domain discreteDomain, int inputChannel) + { + return () -> new ColumnarFilterEvaluator( + new DictionaryAwareColumnarFilter( + new ColumnarBloomFilter(discreteDomain.getNullableDiscreteSet(), inputChannel))); + } + + private static final class ColumnarBloomFilter + implements ColumnarFilter + { + private final SliceBloomFilter filter; + private final boolean isNullAllowed; + private final 
InputChannels inputChannels; + + public ColumnarBloomFilter(Domain.DiscreteSet discreteSet, int inputChannel) + { + this.isNullAllowed = discreteSet.containsNull(); + this.filter = new SliceBloomFilter((List) (List) discreteSet.getNonNullValues()); + this.inputChannels = new InputChannels(ImmutableList.of(inputChannel), ImmutableList.of(inputChannel)); + } + + @Override + public int filterPositionsRange(ConnectorSession session, int[] outputPositions, int offset, int size, Page page) + { + VariableWidthBlock block = (VariableWidthBlock) page.getBlock(0); + int selectedPositionsCount = 0; + if (block.mayHaveNull()) { + for (int position = offset; position < offset + size; position++) { + boolean result = block.isNull(position) ? isNullAllowed : filter.test(block, position); + outputPositions[selectedPositionsCount] = position; + selectedPositionsCount += result ? 1 : 0; + } + return selectedPositionsCount; + } + + for (int position = offset; position < offset + size; position++) { + outputPositions[selectedPositionsCount] = position; + selectedPositionsCount += filter.test(block, position) ? 1 : 0; + } + return selectedPositionsCount; + } + + @Override + public int filterPositionsList(ConnectorSession session, int[] outputPositions, int[] activePositions, int offset, int size, Page page) + { + VariableWidthBlock block = (VariableWidthBlock) page.getBlock(0); + int selectedPositionsCount = 0; + if (block.mayHaveNull()) { + for (int index = offset; index < offset + size; index++) { + int position = activePositions[index]; + boolean result = block.isNull(position) ? isNullAllowed : filter.test(block, position); + outputPositions[selectedPositionsCount] = position; + selectedPositionsCount += result ? 1 : 0; + } + return selectedPositionsCount; + } + + for (int index = offset; index < offset + size; index++) { + int position = activePositions[index]; + outputPositions[selectedPositionsCount] = position; + selectedPositionsCount += filter.test(block, position) ? 
1 : 0; + } + return selectedPositionsCount; + } + + @Override + public InputChannels getInputChannels() + { + return inputChannels; + } + } + + public static final class SliceBloomFilter + { + private final long[] bloom; + private final int bloomSizeMask; + + /** + * A Bloom filter for a set of Slice values. + * This is approx 2X faster than the Bloom filter implementations in ORC and Parquet because + * it uses a single hash function and uses that hash to set 3 bits within a 64-bit word. + * The memory footprint is up to (4 * values.size()) bytes, which is much smaller than maintaining a hash set of strings. + * + * @param values List of values used for filtering + */ + public SliceBloomFilter(List values) + { + int bloomSize = getBloomFilterSize(values.size()); + bloom = new long[bloomSize]; + bloomSizeMask = bloomSize - 1; + for (Slice value : values) { + long hashCode = XxHash64.hash(value); + // Set 3 bits in a 64-bit word + bloom[bloomIndex(hashCode)] |= bloomMask(hashCode); + } + } + + private static int getBloomFilterSize(int valuesCount) + { + // Linear hash table size is the highest power of two less than or equal to number of values * 4. This means that the + // table is under half full, e.g. 127 elements get 256 slots. + int hashTableSize = Integer.highestOneBit(valuesCount * 4); + // We will allocate 8 bits in the bloom filter for every slot in a comparable hash table. + // The bloomSize is a count of longs, hence / 8. 
+ return Math.max(1, hashTableSize / 8); + } + + public boolean test(VariableWidthBlock block, int position) + { + return contains(block.getRawSlice(), block.getRawSliceOffset(position), block.getSliceLength(position)); + } + + @VisibleForTesting + public boolean contains(Slice data) + { + return contains(data, 0, data.length()); + } + + private boolean contains(Slice data, int offset, int length) + { + long hashCode = XxHash64.hash(data, offset, length); + long mask = bloomMask(hashCode); + return mask == (bloom[bloomIndex(hashCode)] & mask); + } + + private int bloomIndex(long hashCode) + { + // Lower 21 bits are not used by bloomMask + // These are enough for the maximum size array that will be used here + return (int) (hashCode & bloomSizeMask); + } + + private static long bloomMask(long hashCode) + { + // Returned mask sets up to 3 bits (fewer if the extracted fields coincide) based on portions of the given hash + // Extract bits 21-26 from the LSB (38th to 43rd bits counting from the MSB) + return (1L << ((hashCode >> 21) & 63)) + // Extract bits 27-32 from the LSB (32nd to 37th bits counting from the MSB) + | (1L << ((hashCode >> 27) & 63)) + // Extract bits 33-38 from the LSB (26th to 31st bits counting from the MSB) + | (1L << ((hashCode >> 33) & 63)); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/DynamicPageFilter.java b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/DynamicPageFilter.java index 7e3814426ba2..91172127080a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/DynamicPageFilter.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/DynamicPageFilter.java @@ -41,6 +41,8 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.trino.sql.gen.columnar.BloomFilter.canUseBloomFilter; +import static io.trino.sql.gen.columnar.BloomFilter.createBloomFilterEvaluator; import static io.trino.sql.gen.columnar.FilterEvaluator.createColumnarFilterEvaluator; import static io.trino.sql.ir.optimizer.IrExpressionOptimizer.newOptimizer; import static 
io.trino.sql.relational.SqlToRowExpressionTranslator.translate; @@ -107,17 +109,7 @@ public synchronized Supplier createDynamicPageFilterEvaluator(C isBlocked = dynamicFilter.isBlocked(); boolean isAwaitable = dynamicFilter.isAwaitable(); TupleDomain currentPredicate = dynamicFilter.getCurrentPredicate().transformKeys(columnHandles::get); - List expressionConjuncts = domainTranslator.toPredicateConjuncts(currentPredicate) - .stream() - // Run the expression derived from TupleDomain through IR optimizer to simplify predicates. E.g. SimplifyContinuousInValues - .map(expression -> irExpressionOptimizer.process(expression, session, ImmutableMap.of()).orElse(expression)) - .collect(toImmutableList()); - // We translate each conjunct into separate RowExpression to make it easy to profile selectivity - // of dynamic filter per column and drop them if they're ineffective - List rowExpression = expressionConjuncts.stream() - .map(expression -> translate(expression, sourceLayout, metadata, typeManager)) - .collect(toImmutableList()); - compiledDynamicFilter = createDynamicFilterEvaluator(rowExpression, compiler, selectivityThreshold); + compiledDynamicFilter = createDynamicFilterEvaluator(compiler, currentPredicate); if (!isAwaitable) { isBlocked = null; // Dynamic filter will not narrow down anymore } @@ -125,10 +117,25 @@ public synchronized Supplier createDynamicPageFilterEvaluator(C return compiledDynamicFilter; } - private static Supplier createDynamicFilterEvaluator(List rowExpressions, ColumnarFilterCompiler compiler, double selectivityThreshold) + private Supplier createDynamicFilterEvaluator(ColumnarFilterCompiler compiler, TupleDomain currentPredicate) { - List> subExpressionEvaluators = rowExpressions.stream() - .map(expression -> createColumnarFilterEvaluator(expression, compiler)) + if (currentPredicate.isNone()) { + return SelectNoneEvaluator::new; + } + // We translate each conjunct into separate FilterEvaluator to make it easy to profile selectivity + // of 
dynamic filter per column and drop them if they're ineffective + List> subExpressionEvaluators = currentPredicate.getDomains().orElseThrow() + .entrySet().stream() + .map(entry -> { + if (canUseBloomFilter(entry.getValue())) { + return Optional.of(createBloomFilterEvaluator(entry.getValue(), sourceLayout.get(entry.getKey()))); + } + Expression expression = domainTranslator.toPredicate(entry.getValue(), entry.getKey().toSymbolReference()); + // Run the expression derived from TupleDomain through IR optimizer to simplify predicates. E.g. SimplifyContinuousInValues + expression = irExpressionOptimizer.process(expression, session, ImmutableMap.of()).orElse(expression); + RowExpression rowExpression = translate(expression, sourceLayout, metadata, typeManager); + return createColumnarFilterEvaluator(rowExpression, compiler); + }) .filter(Optional::isPresent) .map(Optional::get) .collect(toImmutableList()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java index c2936990783e..bebface2e623 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java @@ -119,7 +119,7 @@ public Expression toPredicate(TupleDomain tupleDomain) return IrUtils.combineConjuncts(toPredicateConjuncts(tupleDomain)); } - public List toPredicateConjuncts(TupleDomain tupleDomain) + private List toPredicateConjuncts(TupleDomain tupleDomain) { if (tupleDomain.isNone()) { return ImmutableList.of(FALSE); @@ -132,7 +132,7 @@ public List toPredicateConjuncts(TupleDomain tupleDomain) .collect(toImmutableList()); } - private Expression toPredicate(Domain domain, Reference reference) + public Expression toPredicate(Domain domain, Reference reference) { if (domain.getValues().isNone()) { return domain.isNullAllowed() ? 
new IsNull(reference) : FALSE; diff --git a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkDynamicPageFilter.java b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkDynamicPageFilter.java index dfebef659217..1563d22f8702 100644 --- a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkDynamicPageFilter.java +++ b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkDynamicPageFilter.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.airlift.slice.Slices; import io.trino.FullConnectorSession; import io.trino.operator.project.SelectedPositions; import io.trino.spi.Page; @@ -52,6 +53,7 @@ import static io.trino.spi.type.RealType.REAL; import static io.trino.spi.type.TypeUtils.readNativeValue; import static io.trino.spi.type.TypeUtils.writeNativeValue; +import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.testing.TestingSession.testSessionBuilder; import static io.trino.util.DynamicFiltersTestUtil.createDynamicFilterEvaluator; import static java.lang.Float.floatToIntBits; @@ -87,6 +89,7 @@ public class BenchmarkDynamicPageFilter "INT64_RANDOM", "INT64_FIXED_32K", // LongBitSetFilter "REAL_RANDOM", + "VARCHAR_RANDOM", // BloomFilter }) public DataSet inputDataSet; @@ -99,6 +102,11 @@ public enum DataSet INT64_RANDOM(BIGINT, (block, r) -> BIGINT.writeLong(block, r.nextLong())), INT64_FIXED_32K(BIGINT, (block, r) -> BIGINT.writeLong(block, r.nextLong() % 32768)), REAL_RANDOM(REAL, (block, r) -> REAL.writeLong(block, floatToIntBits(r.nextFloat()))), + VARCHAR_RANDOM(VARCHAR, (block, r) -> { + byte[] buffer = new byte[25]; + r.nextBytes(buffer); + VARCHAR.writeSlice(block, Slices.wrappedBuffer(buffer, 0, buffer.length)); + }), /**/; private final Type type; diff --git a/core/trino-main/src/test/java/io/trino/sql/gen/TestDynamicPageFilter.java b/core/trino-main/src/test/java/io/trino/sql/gen/TestDynamicPageFilter.java index f164723e9329..f1497f836ba9 
100644 --- a/core/trino-main/src/test/java/io/trino/sql/gen/TestDynamicPageFilter.java +++ b/core/trino-main/src/test/java/io/trino/sql/gen/TestDynamicPageFilter.java @@ -15,6 +15,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; import io.trino.FullConnectorSession; import io.trino.Session; import io.trino.operator.project.SelectedPositions; @@ -32,6 +34,7 @@ import io.trino.spi.security.ConnectorIdentity; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; +import io.trino.sql.gen.columnar.BloomFilter; import io.trino.sql.gen.columnar.ColumnarFilterCompiler; import io.trino.sql.gen.columnar.DynamicPageFilter; import io.trino.sql.gen.columnar.FilterEvaluator; @@ -39,6 +42,7 @@ import io.trino.sql.planner.Symbol; import org.junit.jupiter.api.Test; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Optional; @@ -57,6 +61,7 @@ import static io.trino.metadata.FunctionManager.createTestingFunctionManager; import static io.trino.operator.project.SelectedPositions.positionsRange; import static io.trino.spi.predicate.Domain.multipleValues; +import static io.trino.spi.predicate.Domain.notNull; import static io.trino.spi.predicate.Domain.onlyNull; import static io.trino.spi.predicate.Domain.singleValue; import static io.trino.spi.type.BigintType.BIGINT; @@ -103,28 +108,52 @@ public void testNonePageFilter() @Test public void testStringFilter() { - ColumnHandle column = new TestingColumnHandle("column"); + ColumnHandle columnA = new TestingColumnHandle("columnA"); + ColumnHandle columnB = new TestingColumnHandle("columnB"); FilterEvaluator filterEvaluator = createDynamicFilterEvaluator( - TupleDomain.withColumnDomains(ImmutableMap.of(column, onlyNull(VARCHAR))), - ImmutableMap.of(column, 0)); - Page page = new Page( + TupleDomain.withColumnDomains(ImmutableMap.of(columnA, onlyNull(VARCHAR))), + 
ImmutableMap.of(columnA, 0)); + + Page pageWithNull = new Page( createStringsBlock("ab", "bc", null, "cd", null), createStringsBlock(null, "de", "ef", null, "fg")); - verifySelectedPositions(filterPage(page, filterEvaluator), new int[] {2, 4}); + Page pageWithoutNull = new Page( + createStringsBlock("ab", "bc", "cd", "de", "fg"), + createStringsBlock("ab", "bc", "cd", "de", "fg")); + + verifySelectedPositions(filterPage(pageWithNull, filterEvaluator), new int[] {2, 4}); + verifySelectedPositions(filterPage(pageWithoutNull, filterEvaluator), new int[] {}); + filterEvaluator = createDynamicFilterEvaluator( + TupleDomain.withColumnDomains(ImmutableMap.of(columnA, notNull(VARCHAR))), + ImmutableMap.of(columnA, 1)); + verifySelectedPositions(filterPage(pageWithNull, filterEvaluator), new int[] {1, 2, 4}); + verifySelectedPositions(filterPage(pageWithoutNull, filterEvaluator), 5); filterEvaluator = createDynamicFilterEvaluator( TupleDomain.withColumnDomains(ImmutableMap.of( - column, + columnA, multipleValues(VARCHAR, ImmutableList.of("bc", "cd")))), - ImmutableMap.of(column, 0)); - verifySelectedPositions(filterPage(page, filterEvaluator), new int[] {1, 3}); + ImmutableMap.of(columnA, 0)); + verifySelectedPositions(filterPage(pageWithNull, filterEvaluator), new int[] {1, 3}); + verifySelectedPositions(filterPage(pageWithoutNull, filterEvaluator), new int[] {1, 2}); filterEvaluator = createDynamicFilterEvaluator( TupleDomain.withColumnDomains(ImmutableMap.of( - column, + columnA, Domain.create(ValueSet.of(VARCHAR, utf8Slice("ab")), true))), - ImmutableMap.of(column, 0)); - verifySelectedPositions(filterPage(page, filterEvaluator), new int[] {0, 2, 4}); + ImmutableMap.of(columnA, 0)); + verifySelectedPositions(filterPage(pageWithNull, filterEvaluator), new int[] {0, 2, 4}); + verifySelectedPositions(filterPage(pageWithoutNull, filterEvaluator), new int[] {0}); + + filterEvaluator = createDynamicFilterEvaluator( + TupleDomain.withColumnDomains(ImmutableMap.of( + columnA, + 
Domain.create(ValueSet.of(VARCHAR, utf8Slice("ab"), utf8Slice("cd")), true), + columnB, + Domain.create(ValueSet.of(VARCHAR, utf8Slice("de")), true))), + ImmutableMap.of(columnA, 0, columnB, 1)); + verifySelectedPositions(filterPage(pageWithNull, filterEvaluator), new int[] {0, 3}); + verifySelectedPositions(filterPage(pageWithoutNull, filterEvaluator), new int[] {}); } @Test @@ -461,6 +490,58 @@ columnD, getRangePredicate(-50, 90))), } } + @Test + void testSliceBloomFilter() + { + BloomFilter.SliceBloomFilter filter = new BloomFilter.SliceBloomFilter( + ImmutableList.of( + utf8Slice("Igne"), + utf8Slice("natura"), + utf8Slice("renovitur"), + utf8Slice("integra."))); + assertThat(filter.contains(utf8Slice("Igne"))).isTrue(); + assertThat(filter.contains(utf8Slice("natura"))).isTrue(); + assertThat(filter.contains(utf8Slice("renovitur"))).isTrue(); + assertThat(filter.contains(utf8Slice("integra."))).isTrue(); + + assertThat(filter.contains(utf8Slice("natur"))).isFalse(); + assertThat(filter.contains(utf8Slice("apple"))).isFalse(); + + int valuesCount = 10000; + List testValues = new ArrayList<>(valuesCount); + List filterValues = new ArrayList<>(); + byte base = 0; + for (int i = 0; i < valuesCount; i++) { + Slice value = sequentialBytes(base, i); + testValues.add(value); + base = (byte) (base + i); + if (i % 9 == 0) { + filterValues.add(value); + } + } + + filter = new BloomFilter.SliceBloomFilter(filterValues); + int hits = 0; + for (int i = 0; i < valuesCount; i++) { + boolean contains = filter.contains(testValues.get(i)); + if (i % 9 == 0) { + // No false negatives + assertThat(contains).isTrue(); + } + hits += contains ? 
1 : 0; + } + assertThat((double) hits / valuesCount).isBetween(0.1, 0.115); + } + + private static Slice sequentialBytes(byte base, int length) + { + byte[] bytes = new byte[length]; + for (int i = 0; i < length; i++) { + bytes[i] = (byte) (base + i); + } + return Slices.wrappedBuffer(bytes); + } + private static SelectedPositions filterPage(Page page, FilterEvaluator filterEvaluator) { FilterEvaluator.SelectionResult result = filterEvaluator.evaluate(FULL_CONNECTOR_SESSION, positionsRange(0, page.getPositionCount()), page);