From ffdebcc5265ed1207f27066319dd0a1e666efbff Mon Sep 17 00:00:00 2001 From: Raunaq Morarka Date: Wed, 11 Dec 2024 18:21:19 +0530 Subject: [PATCH] Use bloom filter for evaluating dynamic filters on strings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit BenchmarkDynamicPageFilter.filterPages (filterSize) (inputDataSet) (inputNullChance) (nonNullsSelectivity) (nullsAllowed) Mode Cnt Before Score After Score Units 2 VARCHAR_RANDOM 0.01 0.2 false thrpt 10 80.908 ± 1.927 172.244 ± 1.067 ops/s 5 VARCHAR_RANDOM 0.01 0.2 false thrpt 10 81.052 ± 2.569 175.619 ± 1.225 ops/s 10 VARCHAR_RANDOM 0.01 0.2 false thrpt 10 76.787 ± 1.561 176.371 ± 0.559 ops/s 100 VARCHAR_RANDOM 0.01 0.2 false thrpt 10 75.631 ± 1.372 174.288 ± 1.024 ops/s 1000 VARCHAR_RANDOM 0.01 0.2 false thrpt 10 69.615 ± 0.721 173.340 ± 0.867 ops/s 10000 VARCHAR_RANDOM 0.01 0.2 false thrpt 10 75.401 ± 1.233 173.285 ± 1.752 ops/s 100000 VARCHAR_RANDOM 0.01 0.2 false thrpt 10 64.335 ± 2.936 170.087 ± 1.370 ops/s 1000000 VARCHAR_RANDOM 0.01 0.2 false thrpt 10 16.808 ± 3.205 170.403 ± 1.471 ops/s 5000000 VARCHAR_RANDOM 0.01 0.2 false thrpt 10 15.766 ± 0.820 150.588 ± 4.034 ops/s --- .../trino/sql/gen/columnar/BloomFilter.java | 344 ++++++++++++++++++ .../sql/gen/columnar/DynamicPageFilter.java | 35 +- .../trino/sql/planner/DomainTranslator.java | 4 +- .../sql/gen/BenchmarkDynamicPageFilter.java | 30 +- .../trino/sql/gen/TestDynamicPageFilter.java | 154 +++++++- 5 files changed, 528 insertions(+), 39 deletions(-) create mode 100644 core/trino-main/src/main/java/io/trino/sql/gen/columnar/BloomFilter.java diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/BloomFilter.java b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/BloomFilter.java new file mode 100644 index 000000000000..257ae4304763 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/BloomFilter.java @@ -0,0 +1,344 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.gen.columnar; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Throwables; +import com.google.common.cache.CacheBuilder; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.airlift.bytecode.BytecodeBlock; +import io.airlift.bytecode.ClassDefinition; +import io.airlift.bytecode.FieldDefinition; +import io.airlift.bytecode.MethodDefinition; +import io.airlift.bytecode.Parameter; +import io.airlift.bytecode.Scope; +import io.airlift.bytecode.Variable; +import io.airlift.bytecode.control.ForLoop; +import io.airlift.bytecode.expression.BytecodeExpression; +import io.airlift.slice.Slice; +import io.airlift.slice.XxHash64; +import io.trino.annotation.UsedByGeneratedCode; +import io.trino.cache.NonEvictableCache; +import io.trino.operator.project.InputChannels; +import io.trino.spi.Page; +import io.trino.spi.TrinoException; +import io.trino.spi.block.Block; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.predicate.Domain; +import io.trino.spi.type.CharType; +import io.trino.spi.type.Type; +import io.trino.spi.type.VarbinaryType; +import io.trino.spi.type.VarcharType; +import org.objectweb.asm.MethodTooLargeException; + +import java.lang.invoke.MethodHandle; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +import static com.google.common.base.Verify.verify; +import static io.airlift.bytecode.Access.FINAL; +import static io.airlift.bytecode.Access.PRIVATE; +import static io.airlift.bytecode.Access.PUBLIC; +import static io.airlift.bytecode.Access.a; +import static io.airlift.bytecode.Parameter.arg; +import static io.airlift.bytecode.ParameterizedType.type; +import static io.airlift.bytecode.expression.BytecodeExpressions.add; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt; +import static io.airlift.bytecode.expression.BytecodeExpressions.inlineIf; +import static io.airlift.bytecode.expression.BytecodeExpressions.lessThan; +import static io.trino.cache.CacheUtils.uncheckedCacheGet; +import static io.trino.cache.SafeCaches.buildNonEvictableCache; +import static io.trino.spi.StandardErrorCode.COMPILER_ERROR; +import static io.trino.spi.StandardErrorCode.QUERY_EXCEEDED_COMPILER_LIMIT; +import static io.trino.sql.gen.columnar.ColumnarFilterCompiler.updateOutputPositions; +import static io.trino.util.CompilerUtils.defineClass; +import static io.trino.util.CompilerUtils.makeClassName; +import static io.trino.util.Reflection.constructorMethodHandle; +import static java.util.Objects.requireNonNull; + +public class BloomFilter +{ + // Generate a ColumnarBloomFilter class per Type to avoid mega-morphic call site when reading a position from input block + private static final NonEvictableCache COLUMNAR_BLOOM_FILTER_CACHE = buildNonEvictableCache( + CacheBuilder.newBuilder() + .maximumSize(100) + .expireAfterWrite(2, TimeUnit.HOURS)); + + private BloomFilter() {} + + public static boolean canUseBloomFilter(Domain domain) + { + Type type = domain.getType(); + if (type instanceof VarcharType || type instanceof CharType || type instanceof VarbinaryType) { + verify(type.getJavaType() == Slice.class, "Type is not backed by Slice"); + return !domain.isNone() + && !domain.isAll() + && domain.isNullableDiscreteSet() + && domain.getValues().getRanges().getRangeCount() > 1; // Bloom filter is not faster to evaluate for single value + } + return false; + } + + public static Supplier createBloomFilterEvaluator(Domain domain, int inputChannel) + { + return () -> new ColumnarFilterEvaluator( + new DictionaryAwareColumnarFilter( + createColumnarBloomFilter(domain.getType(), inputChannel, domain.getNullableDiscreteSet()).get())); + } + + private static Supplier createColumnarBloomFilter(Type type, int inputChannel, Domain.DiscreteSet discreteSet) + { + MethodHandle filterConstructor = uncheckedCacheGet( + COLUMNAR_BLOOM_FILTER_CACHE, + type, + () -> generateColumnarBloomFilterClass(type)); + + return () -> { + try { + SliceBloomFilter filter = new SliceBloomFilter((List) (List) discreteSet.getNonNullValues(), discreteSet.containsNull(), type); + InputChannels inputChannels = new InputChannels(ImmutableList.of(inputChannel), ImmutableList.of(inputChannel)); + return (ColumnarFilter) filterConstructor.invoke(filter, inputChannels); + } + catch (Throwable e) { + throw new RuntimeException(e); + } + }; + } + + private static MethodHandle generateColumnarBloomFilterClass(Type type) + { + ClassDefinition classDefinition = new ClassDefinition( + a(PUBLIC, FINAL), + makeClassName(ColumnarFilter.class.getSimpleName() + "_bloom_filter_" + type, Optional.empty()), + type(Object.class), + type(ColumnarFilter.class)); + + FieldDefinition filterField = classDefinition.declareField(a(PRIVATE, FINAL), "filter", SliceBloomFilter.class); + FieldDefinition inputChannelsField = classDefinition.declareField(a(PRIVATE, FINAL), "inputChannels", InputChannels.class); + Parameter filterParameter = arg("filter", SliceBloomFilter.class); + Parameter inputChannelsParameter = arg("inputChannels", InputChannels.class); + MethodDefinition constructorDefinition = classDefinition.declareConstructor(a(PUBLIC), filterParameter, inputChannelsParameter); + BytecodeBlock body = constructorDefinition.getBody(); + Variable thisVariable = constructorDefinition.getThis(); + body.comment("super();") + .append(thisVariable) + .invokeConstructor(Object.class) + .append(thisVariable.setField(filterField, filterParameter)) + .append(thisVariable.setField(inputChannelsField, inputChannelsParameter)) + .ret(); + + // getInputChannels + MethodDefinition method = classDefinition.declareMethod(a(PUBLIC), "getInputChannels", type(InputChannels.class)); + Scope scope = method.getScope(); + method.getBody().append(scope.getThis().getField(inputChannelsField).ret()); + + generateFilterRangeMethod(classDefinition); + generateFilterListMethod(classDefinition); + + Class filterClass; + try { + filterClass = defineClass(classDefinition, ColumnarFilter.class, ImmutableMap.of(), ColumnarFilterCompiler.class.getClassLoader()); + } + catch (Exception e) { + if (Throwables.getRootCause(e) instanceof MethodTooLargeException) { + throw new TrinoException(QUERY_EXCEEDED_COMPILER_LIMIT, + "Query exceeded maximum filters. Please reduce the number of filters referenced and re-run the query.", e); + } + throw new TrinoException(COMPILER_ERROR, e.getCause()); + } + return constructorMethodHandle(filterClass, SliceBloomFilter.class, InputChannels.class); + } + + private static void generateFilterRangeMethod(ClassDefinition classDefinition) + { + Parameter session = arg("session", ConnectorSession.class); + Parameter outputPositions = arg("outputPositions", int[].class); + Parameter offset = arg("offset", int.class); + Parameter size = arg("size", int.class); + Parameter page = arg("page", Page.class); + + MethodDefinition method = classDefinition.declareMethod( + a(PUBLIC), + "filterPositionsRange", + type(int.class), + ImmutableList.of(session, outputPositions, offset, size, page)); + Scope scope = method.getScope(); + BytecodeBlock body = method.getBody(); + + Variable block = declareBlockVariable(page, scope, body); + Variable outputPositionsCount = scope.declareVariable("outputPositionsCount", body, constantInt(0)); + Variable position = scope.declareVariable(int.class, "position"); + Variable result = scope.declareVariable(boolean.class, "result"); + + /* for(int position = offset; position < offset + size; ++position) { + * boolean result = block.isNull(position) ? this.filter.containsNull() : this.filter.test(block, position); + * outputPositions[outputPositionsCount] = position; + * outputPositionsCount += result ? 1 : 0; + * } + */ + body.append(new ForLoop("nullable range based loop") + .initialize(position.set(offset)) + .condition(lessThan(position, add(offset, size))) + .update(position.increment()) + .body(new BytecodeBlock() + .append(generateBloomFilterTest(scope, block, position, result)) + .append(updateOutputPositions(result, position, outputPositions, outputPositionsCount)))); + + body.append(outputPositionsCount.ret()); + } + + private static void generateFilterListMethod(ClassDefinition classDefinition) + { + Parameter session = arg("session", ConnectorSession.class); + Parameter outputPositions = arg("outputPositions", int[].class); + Parameter activePositions = arg("activePositions", int[].class); + Parameter offset = arg("offset", int.class); + Parameter size = arg("size", int.class); + Parameter page = arg("page", Page.class); + + MethodDefinition method = classDefinition.declareMethod( + a(PUBLIC), + "filterPositionsList", + type(int.class), + ImmutableList.of(session, outputPositions, activePositions, offset, size, page)); + Scope scope = method.getScope(); + BytecodeBlock body = method.getBody(); + + Variable block = declareBlockVariable(page, scope, body); + Variable outputPositionsCount = scope.declareVariable("outputPositionsCount", body, constantInt(0)); + Variable index = scope.declareVariable(int.class, "index"); + Variable position = scope.declareVariable(int.class, "position"); + Variable result = scope.declareVariable(boolean.class, "result"); + + /* for(int index = offset; index < offset + size; ++index) { + * int position = activePositions[index]; + * boolean result = block.isNull(position) ? this.filter.containsNull() : this.filter.test(block, position); + * outputPositions[outputPositionsCount] = position; + * outputPositionsCount += result ? 1 : 0; + * } + */ + body.append(new ForLoop("nullable range based loop") + .initialize(index.set(offset)) + .condition(lessThan(index, add(offset, size))) + .update(index.increment()) + .body(new BytecodeBlock() + .append(position.set(activePositions.getElement(index))) + .append(generateBloomFilterTest(scope, block, position, result)) + .append(updateOutputPositions(result, position, outputPositions, outputPositionsCount)))); + + body.append(outputPositionsCount.ret()); + } + + private static Variable declareBlockVariable(Parameter page, Scope scope, BytecodeBlock body) + { + return scope.declareVariable( + "block", + body, + page.invoke("getBlock", Block.class, constantInt(0))); + } + + private static BytecodeBlock generateBloomFilterTest(Scope scope, Variable block, Variable position, Variable result) + { + BytecodeExpression filter = scope.getThis().getField("filter", SliceBloomFilter.class); + // boolean result = block.isNull(position) ? this.filter.containsNull() : this.filter.test(block, position) + return new BytecodeBlock() + .append(result.set(inlineIf( + block.invoke("isNull", boolean.class, position), + filter.invoke("containsNull", boolean.class), + filter.invoke("test", boolean.class, block, position)))); + } + + public static final class SliceBloomFilter + { + private final long[] bloom; + private final int bloomSizeMask; + private final Type type; + private final boolean containsNull; + + /** + * A Bloom filter for a set of Slice values. + * This is approx 2X faster than the Bloom filter implementations in ORC and parquet because + * it uses single hash function and uses that to set 3 bits within a 64 bit word. + * The memory footprint is up to (4 * values.size()) bytes, which is much smaller than maintaining a hash set of strings. + * + * @param values values used for filtering + * @param containsNull whether null values are contained by the filter + * @param type type of the values + */ + public SliceBloomFilter(List values, boolean containsNull, Type type) + { + this.containsNull = containsNull; + this.type = requireNonNull(type, "type is null"); + int bloomSize = getBloomFilterSize(values.size()); + bloom = new long[bloomSize]; + bloomSizeMask = bloomSize - 1; + for (Slice value : values) { + long hashCode = XxHash64.hash(value); + // Set 3 bits in a 64 bit word + bloom[bloomIndex(hashCode)] |= bloomMask(hashCode); + } + } + + @UsedByGeneratedCode + public boolean containsNull() + { + return containsNull; + } + + @UsedByGeneratedCode + public boolean test(Block block, int position) + { + return contains(type.getSlice(block, position)); + } + + @VisibleForTesting + public boolean contains(Slice data) + { + long hashCode = XxHash64.hash(data); + long mask = bloomMask(hashCode); + return mask == (bloom[bloomIndex(hashCode)] & mask); + } + + private int bloomIndex(long hashCode) + { + // Lower 21 bits are not used by bloomMask + // These are enough for the maximum size array that will be used here + return (int) (hashCode & bloomSizeMask); + } + + private static long bloomMask(long hashCode) + { + // returned mask sets 3 bits based on portions of given hash + // Extract 38th to 43rd bits + return (1L << ((hashCode >> 21) & 63)) + // Extract 32nd to 37th bits + | (1L << ((hashCode >> 27) & 63)) + // Extract 26th to 31st bits + | (1L << ((hashCode >> 33) & 63)); + } + + private static int getBloomFilterSize(int valuesCount) + { + // Linear hash table size is the highest power of two less than or equal to number of values * 4. This means that the + // table is under half full, e.g. 127 elements gets 256 slots. + int hashTableSize = Integer.highestOneBit(valuesCount * 4); + // We will allocate 8 bits in the bloom filter for every slot in a comparable hash table. + // The bloomSize is a count of longs, hence / 8. + return Math.max(1, hashTableSize / 8); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/DynamicPageFilter.java b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/DynamicPageFilter.java index 7e3814426ba2..91172127080a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/DynamicPageFilter.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/DynamicPageFilter.java @@ -41,6 +41,8 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.trino.sql.gen.columnar.BloomFilter.canUseBloomFilter; +import static io.trino.sql.gen.columnar.BloomFilter.createBloomFilterEvaluator; import static io.trino.sql.gen.columnar.FilterEvaluator.createColumnarFilterEvaluator; import static io.trino.sql.ir.optimizer.IrExpressionOptimizer.newOptimizer; import static io.trino.sql.relational.SqlToRowExpressionTranslator.translate; @@ -107,17 +109,7 @@ public synchronized Supplier createDynamicPageFilterEvaluator(C isBlocked = dynamicFilter.isBlocked(); boolean isAwaitable = dynamicFilter.isAwaitable(); TupleDomain currentPredicate = dynamicFilter.getCurrentPredicate().transformKeys(columnHandles::get); - List expressionConjuncts = domainTranslator.toPredicateConjuncts(currentPredicate) - .stream() - // Run the expression derived from TupleDomain through IR optimizer to simplify predicates. E.g. SimplifyContinuousInValues - .map(expression -> irExpressionOptimizer.process(expression, session, ImmutableMap.of()).orElse(expression)) - .collect(toImmutableList()); - // We translate each conjunct into separate RowExpression to make it easy to profile selectivity - // of dynamic filter per column and drop them if they're ineffective - List rowExpression = expressionConjuncts.stream() - .map(expression -> translate(expression, sourceLayout, metadata, typeManager)) - .collect(toImmutableList()); - compiledDynamicFilter = createDynamicFilterEvaluator(rowExpression, compiler, selectivityThreshold); + compiledDynamicFilter = createDynamicFilterEvaluator(compiler, currentPredicate); if (!isAwaitable) { isBlocked = null; // Dynamic filter will not narrow down anymore } @@ -125,10 +117,25 @@ public synchronized Supplier createDynamicPageFilterEvaluator(C return compiledDynamicFilter; } - private static Supplier createDynamicFilterEvaluator(List rowExpressions, ColumnarFilterCompiler compiler, double selectivityThreshold) + private Supplier createDynamicFilterEvaluator(ColumnarFilterCompiler compiler, TupleDomain currentPredicate) { - List> subExpressionEvaluators = rowExpressions.stream() - .map(expression -> createColumnarFilterEvaluator(expression, compiler)) + if (currentPredicate.isNone()) { + return SelectNoneEvaluator::new; + } + // We translate each conjunct into separate FilterEvaluator to make it easy to profile selectivity + // of dynamic filter per column and drop them if they're ineffective + List> subExpressionEvaluators = currentPredicate.getDomains().orElseThrow() + .entrySet().stream() + .map(entry -> { + if (canUseBloomFilter(entry.getValue())) { + return Optional.of(createBloomFilterEvaluator(entry.getValue(), sourceLayout.get(entry.getKey()))); + } + Expression expression = domainTranslator.toPredicate(entry.getValue(), entry.getKey().toSymbolReference()); + // Run the expression derived from TupleDomain through IR optimizer to simplify predicates. E.g. SimplifyContinuousInValues + expression = irExpressionOptimizer.process(expression, session, ImmutableMap.of()).orElse(expression); + RowExpression rowExpression = translate(expression, sourceLayout, metadata, typeManager); + return createColumnarFilterEvaluator(rowExpression, compiler); + }) .filter(Optional::isPresent) .map(Optional::get) .collect(toImmutableList()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java index c2936990783e..bebface2e623 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java @@ -119,7 +119,7 @@ public Expression toPredicate(TupleDomain tupleDomain) return IrUtils.combineConjuncts(toPredicateConjuncts(tupleDomain)); } - public List toPredicateConjuncts(TupleDomain tupleDomain) + private List toPredicateConjuncts(TupleDomain tupleDomain) { if (tupleDomain.isNone()) { return ImmutableList.of(FALSE); @@ -132,7 +132,7 @@ public List toPredicateConjuncts(TupleDomain tupleDomain) .collect(toImmutableList()); } - private Expression toPredicate(Domain domain, Reference reference) + public Expression toPredicate(Domain domain, Reference reference) { if (domain.getValues().isNone()) { return domain.isNullAllowed() ? new IsNull(reference) : FALSE; diff --git a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkDynamicPageFilter.java b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkDynamicPageFilter.java index dfebef659217..21d114874ea4 100644 --- a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkDynamicPageFilter.java +++ b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkDynamicPageFilter.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.airlift.slice.Slices; import io.trino.FullConnectorSession; import io.trino.operator.project.SelectedPositions; import io.trino.spi.Page; @@ -43,15 +44,14 @@ import java.util.Random; import java.util.concurrent.TimeUnit; -import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.jmh.Benchmarks.benchmark; import static io.trino.operator.project.SelectedPositions.positionsRange; -import static io.trino.spi.predicate.Domain.DiscreteSet; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.RealType.REAL; import static io.trino.spi.type.TypeUtils.readNativeValue; import static io.trino.spi.type.TypeUtils.writeNativeValue; +import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.testing.TestingSession.testSessionBuilder; import static io.trino.util.DynamicFiltersTestUtil.createDynamicFilterEvaluator; import static java.lang.Float.floatToIntBits; @@ -65,10 +65,11 @@ @Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS) public class BenchmarkDynamicPageFilter { - private static final int MAX_ROWS = 200_000; + private static final int MAX_ROWS = 400_000; private static final FullConnectorSession FULL_CONNECTOR_SESSION = new FullConnectorSession( testSessionBuilder().build(), ConnectorIdentity.ofUser("test")); + private static final ColumnHandle COLUMN_HANDLE = new TestingColumnHandle("dummy"); @Param("0.05") public double inputNullChance = 0.05; @@ -76,7 +77,7 @@ public class BenchmarkDynamicPageFilter @Param("0.2") public double nonNullsSelectivity = 0.2; - @Param({"100", "1000", "5000"}) + @Param({"100", "1000", "10000"}) public int filterSize = 100; @Param("false") @@ -87,6 +88,7 @@ public class BenchmarkDynamicPageFilter "INT64_RANDOM", "INT64_FIXED_32K", // LongBitSetFilter "REAL_RANDOM", + "VARCHAR_RANDOM", // BloomFilter }) public DataSet inputDataSet; @@ -99,6 +101,11 @@ public enum DataSet INT64_RANDOM(BIGINT, (block, r) -> BIGINT.writeLong(block, r.nextLong())), INT64_FIXED_32K(BIGINT, (block, r) -> BIGINT.writeLong(block, r.nextLong() % 32768)), REAL_RANDOM(REAL, (block, r) -> REAL.writeLong(block, floatToIntBits(r.nextFloat()))), + VARCHAR_RANDOM(VARCHAR, (block, r) -> { + byte[] buffer = new byte[r.nextInt(10, 15)]; + r.nextBytes(buffer); + VARCHAR.writeSlice(block, Slices.wrappedBuffer(buffer, 0, buffer.length)); + }), /**/; private final Type type; @@ -121,7 +128,7 @@ public TupleDomain createFilterTupleDomain(int filterSize, boolean } } return TupleDomain.withColumnDomains(ImmutableMap.of( - new TestingColumnHandle("dummy"), + COLUMN_HANDLE, Domain.create(ValueSet.copyOf(type, valuesBuilder.build()), nullsAllowed))); } @@ -132,12 +139,9 @@ private List createInputTestData( long inputRows) { List nonNullValues = filter.getDomains().orElseThrow() - .values().stream() - .flatMap(domain -> { - DiscreteSet nullableDiscreteSet = domain.getNullableDiscreteSet(); - return nullableDiscreteSet.getNonNullValues().stream(); - }) - .collect(toImmutableList()); + .get(COLUMN_HANDLE) + .getNullableDiscreteSet() + .getNonNullValues(); // pick a random value from the filter return createSingleColumnData( @@ -163,7 +167,7 @@ public void setup() inputData = inputDataSet.createInputTestData(filterPredicate, inputNullChance, nonNullsSelectivity, MAX_ROWS); filterEvaluator = createDynamicFilterEvaluator( filterPredicate, - ImmutableMap.of(new TestingColumnHandle("dummy"), 0), + ImmutableMap.of(COLUMN_HANDLE, 0), 1); } @@ -199,7 +203,7 @@ private static List createSingleColumnData(ValueWriter valueWriter, Type t if (blockBuilder.getPositionCount() >= batchSize) { Block block = blockBuilder.build(); pages.add(new Page(new LazyBlock(block.getPositionCount(), () -> block))); - batchSize = Math.min(1024, batchSize * 2); + batchSize = Math.min(8192, batchSize * 2); blockBuilder = type.createBlockBuilder(null, batchSize); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/gen/TestDynamicPageFilter.java b/core/trino-main/src/test/java/io/trino/sql/gen/TestDynamicPageFilter.java index f164723e9329..baf76c88296d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/gen/TestDynamicPageFilter.java +++ b/core/trino-main/src/test/java/io/trino/sql/gen/TestDynamicPageFilter.java @@ -15,6 +15,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; import io.trino.FullConnectorSession; import io.trino.Session; import io.trino.operator.project.SelectedPositions; @@ -30,8 +32,10 @@ import io.trino.spi.predicate.TupleDomain; import io.trino.spi.predicate.ValueSet; import io.trino.spi.security.ConnectorIdentity; +import io.trino.spi.type.CharType; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; +import io.trino.sql.gen.columnar.BloomFilter; import io.trino.sql.gen.columnar.ColumnarFilterCompiler; import io.trino.sql.gen.columnar.DynamicPageFilter; import io.trino.sql.gen.columnar.FilterEvaluator; @@ -39,6 +43,7 @@ import io.trino.sql.planner.Symbol; import org.junit.jupiter.api.Test; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Optional; @@ -52,11 +57,13 @@ import static io.trino.block.BlockAssertions.createLongSequenceBlock; import static io.trino.block.BlockAssertions.createLongsBlock; import static io.trino.block.BlockAssertions.createRowBlock; +import static io.trino.block.BlockAssertions.createSlicesBlock; import static io.trino.block.BlockAssertions.createStringsBlock; import static io.trino.block.BlockAssertions.createTypedLongsBlock; import static io.trino.metadata.FunctionManager.createTestingFunctionManager; import static io.trino.operator.project.SelectedPositions.positionsRange; import static io.trino.spi.predicate.Domain.multipleValues; +import static io.trino.spi.predicate.Domain.notNull; import static io.trino.spi.predicate.Domain.onlyNull; import static io.trino.spi.predicate.Domain.singleValue; import static io.trino.spi.type.BigintType.BIGINT; @@ -64,7 +71,9 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.RealType.REAL; import static io.trino.spi.type.RowType.rowType; +import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.sql.gen.columnar.BloomFilter.canUseBloomFilter; import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT; import static io.trino.testing.TestingSession.testSessionBuilder; import static io.trino.util.DynamicFiltersTestUtil.TestingDynamicFilter; @@ -101,28 +110,86 @@ public void testNonePageFilter() } @Test - public void testStringFilter() + void testVarcharFilter() { - ColumnHandle column = new TestingColumnHandle("column"); + ColumnHandle columnA = new TestingColumnHandle("columnA"); + ColumnHandle columnB = new TestingColumnHandle("columnB"); FilterEvaluator filterEvaluator = createDynamicFilterEvaluator( - TupleDomain.withColumnDomains(ImmutableMap.of(column, onlyNull(VARCHAR))), - ImmutableMap.of(column, 0)); - Page page = new Page( + TupleDomain.withColumnDomains(ImmutableMap.of(columnA, onlyNull(VARCHAR))), + ImmutableMap.of(columnA, 0)); + + Page pageWithNull = new Page( createStringsBlock("ab", "bc", null, "cd", null), createStringsBlock(null, "de", "ef", null, "fg")); - verifySelectedPositions(filterPage(page, filterEvaluator), new int[] {2, 4}); + Page pageWithoutNull = new Page( + createStringsBlock("ab", "bc", "cd", "de", "fg"), + createStringsBlock("ab", "bc", "cd", "de", "fg")); + + verifySelectedPositions(filterPage(pageWithNull, filterEvaluator), new int[] {2, 4}); + verifySelectedPositions(filterPage(pageWithoutNull, filterEvaluator), new int[] {}); + filterEvaluator = createDynamicFilterEvaluator( + TupleDomain.withColumnDomains(ImmutableMap.of(columnA, notNull(VARCHAR))), + ImmutableMap.of(columnA, 1)); + verifySelectedPositions(filterPage(pageWithNull, filterEvaluator), new int[] {1, 2, 4}); + verifySelectedPositions(filterPage(pageWithoutNull, filterEvaluator), 5); filterEvaluator = createDynamicFilterEvaluator( TupleDomain.withColumnDomains(ImmutableMap.of( - column, + columnA, multipleValues(VARCHAR, ImmutableList.of("bc", "cd")))), - ImmutableMap.of(column, 0)); - verifySelectedPositions(filterPage(page, filterEvaluator), new int[] {1, 3}); + ImmutableMap.of(columnA, 0)); + verifySelectedPositions(filterPage(pageWithNull, filterEvaluator), new int[] {1, 3}); + verifySelectedPositions(filterPage(pageWithoutNull, filterEvaluator), new int[] {1, 2}); + + filterEvaluator = createDynamicFilterEvaluator( + TupleDomain.withColumnDomains(ImmutableMap.of( + columnA, + multipleValues(VARCHAR, ImmutableList.of("a", "ab"), true))), + ImmutableMap.of(columnA, 0)); + verifySelectedPositions(filterPage(pageWithNull, filterEvaluator), new int[] {0, 2, 4}); + verifySelectedPositions(filterPage(pageWithoutNull, filterEvaluator), new int[] {0}); filterEvaluator = createDynamicFilterEvaluator( + TupleDomain.withColumnDomains(ImmutableMap.of( + columnA, + multipleValues(VARCHAR, ImmutableList.of("ab", "cd"), true), + columnB, + multipleValues(VARCHAR, ImmutableList.of("a", "de"), true))), + ImmutableMap.of(columnA, 0, columnB, 1)); + verifySelectedPositions(filterPage(pageWithNull, filterEvaluator), new int[] {0, 3}); + verifySelectedPositions(filterPage(pageWithoutNull, filterEvaluator), new int[] {}); + } + + @Test + void testVarbinaryFilter() + { + ColumnHandle column = new TestingColumnHandle("columnA"); + + Page page = new Page( + createSlicesBlock(utf8Slice("ab"), utf8Slice("bc"), null, utf8Slice("cd"), null), + createSlicesBlock(null, utf8Slice("de"), utf8Slice("ef"), null, utf8Slice("fg"))); + + FilterEvaluator filterEvaluator = createDynamicFilterEvaluator( + TupleDomain.withColumnDomains(ImmutableMap.of( + column, + multipleValues(VARBINARY, ImmutableList.of("a", "ab"), true))), + ImmutableMap.of(column, 0)); + verifySelectedPositions(filterPage(page, filterEvaluator), new int[] {0, 2, 4}); + } + + @Test + void testCharFilter() + { + ColumnHandle column = new TestingColumnHandle("columnA"); + + Page page = new Page( + createSlicesBlock(utf8Slice("ab"), utf8Slice("bc"), null, utf8Slice("cd"), null), + createSlicesBlock(null, utf8Slice("de"), utf8Slice("ef"), null, utf8Slice("fg"))); + + FilterEvaluator filterEvaluator = createDynamicFilterEvaluator( TupleDomain.withColumnDomains(ImmutableMap.of( column, - Domain.create(ValueSet.of(VARCHAR, utf8Slice("ab")), true))), + multipleValues(CharType.createCharType(2), ImmutableList.of("a", "ab"), true))), ImmutableMap.of(column, 0)); verifySelectedPositions(filterPage(page, filterEvaluator), new int[] {0, 2, 4}); } @@ -461,6 +528,73 @@ columnD, getRangePredicate(-50, 90))), } } + @Test + void testCanUseBloomFilter() + { + assertThat(canUseBloomFilter(Domain.multipleValues(BIGINT, ImmutableList.of(1L, 2L, 3L)))).isFalse(); + assertThat(canUseBloomFilter(Domain.singleValue(VARCHAR, utf8Slice("A")))).isFalse(); + assertThat(canUseBloomFilter(Domain.notNull(VARCHAR))).isFalse(); + assertThat(canUseBloomFilter(Domain.onlyNull(VARCHAR))).isFalse(); + assertThat(canUseBloomFilter(Domain.none(VARCHAR))).isFalse(); + assertThat(canUseBloomFilter(Domain.all(VARCHAR))).isFalse(); + assertThat(canUseBloomFilter(Domain.create(ValueSet.of(VARCHAR, utf8Slice("a"), utf8Slice("b")), false))).isTrue(); + assertThat(canUseBloomFilter(Domain.create(ValueSet.of(VARCHAR, utf8Slice("a"), utf8Slice("b"), utf8Slice("c")), true))).isTrue(); + } + + @Test + void testSliceBloomFilter() + { + BloomFilter.SliceBloomFilter filter = new BloomFilter.SliceBloomFilter( + ImmutableList.of( + utf8Slice("Igne"), + utf8Slice("natura"), + utf8Slice("renovitur"), + utf8Slice("integra.")), + false, + VARCHAR); + assertThat(filter.contains(utf8Slice("Igne"))).isTrue(); + assertThat(filter.contains(utf8Slice("natura"))).isTrue(); + assertThat(filter.contains(utf8Slice("renovitur"))).isTrue(); + assertThat(filter.contains(utf8Slice("integra."))).isTrue(); + + assertThat(filter.contains(utf8Slice("natur"))).isFalse(); + assertThat(filter.contains(utf8Slice("apple"))).isFalse(); + + int valuesCount = 10000; + List testValues = new ArrayList<>(valuesCount); + List filterValues = new ArrayList<>(); + byte base = 0; + for (int i = 0; i < valuesCount; i++) { + Slice value = sequentialBytes(base, i); + testValues.add(value); + base = (byte) (base + i); + if (i % 9 == 0) { + filterValues.add(value); + } + } + + filter = new BloomFilter.SliceBloomFilter(filterValues, false, VARCHAR); + int hits = 0; + for (int i = 0; i < valuesCount; i++) { + boolean contains = filter.contains(testValues.get(i)); + if (i % 9 == 0) { + // No false negatives + assertThat(contains).isTrue(); + } + hits += contains ? 1 : 0; + } + assertThat((double) hits / valuesCount).isBetween(0.1, 0.115); + } + + private static Slice sequentialBytes(byte base, int length) + { + byte[] bytes = new byte[length]; + for (int i = 0; i < length; i++) { + bytes[i] = (byte) (base + i); + } + return Slices.wrappedBuffer(bytes); + } + private static SelectedPositions filterPage(Page page, FilterEvaluator filterEvaluator) { FilterEvaluator.SelectionResult result = filterEvaluator.evaluate(FULL_CONNECTOR_SESSION, positionsRange(0, page.getPositionCount()), page);