Use bloom filter for evaluating dynamic filters on strings
BenchmarkDynamicPageFilter.filterPages
(filterSize)  (inputDataSet)  (inputNullChance)  (nonNullsSelectivity)  (nullsAllowed)   Mode  Cnt     Before Score       After Score  Units
         100  VARCHAR_RANDOM               0.05                    0.2           false  thrpt   20  145.858 ± 4.541  590.506 ± 28.510  ops/s
        1000  VARCHAR_RANDOM               0.05                    0.2           false  thrpt   20  136.995 ± 2.395  596.036 ± 22.694  ops/s
       10000  VARCHAR_RANDOM               0.05                    0.2           false  thrpt   20  136.990 ± 5.284  594.118 ± 15.764  ops/s
      100000  VARCHAR_RANDOM               0.05                    0.2           false  thrpt   20  114.591 ± 7.307  587.445 ±  9.818  ops/s
     1000000  VARCHAR_RANDOM               0.05                    0.2           false  thrpt   20   43.234 ± 1.621  578.800 ± 15.694  ops/s
     5000000  VARCHAR_RANDOM               0.05                    0.2           false  thrpt   20   40.018 ± 2.245  464.153 ± 20.914  ops/s
raunaqmorarka committed Dec 20, 2024
1 parent a99d96e commit 13b8ccd
Showing 5 changed files with 356 additions and 39 deletions.
@@ -0,0 +1,181 @@
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.sql.gen.columnar;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import io.airlift.slice.XxHash64;
import io.trino.operator.project.InputChannels;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.ValueBlock;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.predicate.Domain;
import io.trino.spi.type.CharType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarbinaryType;
import io.trino.spi.type.VarcharType;

import java.util.List;
import java.util.function.Supplier;

import static com.google.common.base.Verify.verify;
import static java.util.Objects.requireNonNull;

public class BloomFilter
{
    private BloomFilter() {}

    public static boolean canUseBloomFilter(Domain domain)
    {
        Type type = domain.getType();
        if (type instanceof VarcharType || type instanceof CharType || type instanceof VarbinaryType) {
            verify(type.getJavaType() == Slice.class, "Type is not backed by Slice");
            return !domain.isNone() && !domain.isAll() && domain.isNullableDiscreteSet();
        }
        return false;
    }

    public static Supplier<FilterEvaluator> createBloomFilterEvaluator(Domain domain, int inputChannel)
    {
        return () -> new ColumnarFilterEvaluator(
                new DictionaryAwareColumnarFilter(
                        new ColumnarBloomFilter(domain.getNullableDiscreteSet(), inputChannel, domain.getType())));
    }

    private static final class ColumnarBloomFilter
            implements ColumnarFilter
    {
        private final SliceBloomFilter filter;
        private final boolean isNullAllowed;
        private final InputChannels inputChannels;

        public ColumnarBloomFilter(Domain.DiscreteSet discreteSet, int inputChannel, Type type)
        {
            this.isNullAllowed = discreteSet.containsNull();
            this.filter = new SliceBloomFilter((List<Slice>) (List<?>) discreteSet.getNonNullValues(), type);
            this.inputChannels = new InputChannels(ImmutableList.of(inputChannel), ImmutableList.of(inputChannel));
        }

        @Override
        public int filterPositionsRange(ConnectorSession session, int[] outputPositions, int offset, int size, Page page)
        {
            ValueBlock block = (ValueBlock) page.getBlock(0);
            int selectedPositionsCount = 0;
            for (int position = offset; position < offset + size; position++) {
                boolean result = block.isNull(position) ? isNullAllowed : filter.test(block, position);
                outputPositions[selectedPositionsCount] = position;
                selectedPositionsCount += result ? 1 : 0;
            }
            return selectedPositionsCount;
        }

        @Override
        public int filterPositionsList(ConnectorSession session, int[] outputPositions, int[] activePositions, int offset, int size, Page page)
        {
            ValueBlock block = (ValueBlock) page.getBlock(0);
            int selectedPositionsCount = 0;
            for (int index = offset; index < offset + size; index++) {
                int position = activePositions[index];
                boolean result = block.isNull(position) ? isNullAllowed : filter.test(block, position);
                outputPositions[selectedPositionsCount] = position;
                selectedPositionsCount += result ? 1 : 0;
            }
            return selectedPositionsCount;
        }

        @Override
        public InputChannels getInputChannels()
        {
            return inputChannels;
        }
    }

    public static final class SliceBloomFilter
    {
        private final long[] bloom;
        private final int bloomSizeMask;
        private final Type type;

        /**
         * A Bloom filter for a set of Slice values.
         * This is approximately 2X faster than the Bloom filter implementations in ORC and Parquet
         * because it uses a single hash function and uses that to set 3 bits within a 64-bit word.
         * The memory footprint is up to (4 * values.size()) bytes, which is much smaller than maintaining a hash set of strings.
         *
         * @param values List of values used for filtering
         */
        public SliceBloomFilter(List<Slice> values, Type type)
        {
            this.type = requireNonNull(type, "type is null");
            int bloomSize = getBloomFilterSize(values.size());
            bloom = new long[bloomSize];
            bloomSizeMask = bloomSize - 1;
            for (Slice value : values) {
                long hashCode = XxHash64.hash(value);
                // Set 3 bits in a 64-bit word
                bloom[bloomIndex(hashCode)] |= bloomMask(hashCode);
            }
        }

        private static int getBloomFilterSize(int valuesCount)
        {
            // Linear hash table size is the highest power of two less than or equal to number of values * 4. This means that the
            // table is under half full, e.g. 127 elements get 256 slots.
            int hashTableSize = Integer.highestOneBit(valuesCount * 4);
            // We will allocate 8 bits in the bloom filter for every slot in a comparable hash table.
            // The bloomSize is a count of longs, hence / 8.
            return Math.max(1, hashTableSize / 8);
        }

        public boolean test(Block block, int position)
        {
            return contains(type.getSlice(block, position));
        }

        public boolean contains(Slice data)
        {
            long hashCode = XxHash64.hash(data);
            long mask = bloomMask(hashCode);
            return mask == (bloom[bloomIndex(hashCode)] & mask);
        }

        @VisibleForTesting
        public boolean contains(Slice data, int offset, int length)
        {
            long hashCode = XxHash64.hash(data, offset, length);
            long mask = bloomMask(hashCode);
            return mask == (bloom[bloomIndex(hashCode)] & mask);
        }

        private int bloomIndex(long hashCode)
        {
            // The lower 21 bits of the hash are not used by bloomMask, so they are free to pick the word index.
            // 21 bits are enough for the maximum size array that will be used here.
            return (int) (hashCode & bloomSizeMask);
        }

        private static long bloomMask(long hashCode)
        {
            // Returned mask sets 3 bits based on portions of the given hash.
            // Bit positions below count from the least significant bit.
            // Extract bits 21 to 26
            return (1L << ((hashCode >> 21) & 63))
                    // Extract bits 27 to 32
                    | (1L << ((hashCode >> 27) & 63))
                    // Extract bits 33 to 38
                    | (1L << ((hashCode >> 33) & 63));
        }
    }
}
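
For orientation, here is a minimal usage sketch of SliceBloomFilter. It is not part of this commit: the wrapper class name and the sample strings are invented, but the constructor and contains() calls match the code above, and the sizing comment works through getBloomFilterSize for a 3-element set.

import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.sql.gen.columnar.BloomFilter.SliceBloomFilter;

import java.util.List;

import static io.trino.spi.type.VarcharType.VARCHAR;

public class SliceBloomFilterExample
{
    public static void main(String[] args)
    {
        // 3 values: hashTableSize = highestOneBit(3 * 4) = 8 slots, so the filter is
        // max(1, 8 / 8) = 1 long, i.e. the whole filter fits in a single 64-bit word
        List<Slice> values = List.of(
                Slices.utf8Slice("apple"),
                Slices.utf8Slice("banana"),
                Slices.utf8Slice("cherry"));
        SliceBloomFilter filter = new SliceBloomFilter(values, VARCHAR);

        // No false negatives: every inserted value passes
        System.out.println(filter.contains(Slices.utf8Slice("banana"))); // true
        // Other values usually fail; a small false-positive rate is possible
        System.out.println(filter.contains(Slices.utf8Slice("mango"))); // almost certainly false
    }
}

A false positive only forfeits some pruning; results stay correct because the join that produced the dynamic filter still applies the exact condition downstream.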
@@ -41,6 +41,8 @@

import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
+import static io.trino.sql.gen.columnar.BloomFilter.canUseBloomFilter;
+import static io.trino.sql.gen.columnar.BloomFilter.createBloomFilterEvaluator;
import static io.trino.sql.gen.columnar.FilterEvaluator.createColumnarFilterEvaluator;
import static io.trino.sql.ir.optimizer.IrExpressionOptimizer.newOptimizer;
import static io.trino.sql.relational.SqlToRowExpressionTranslator.translate;
@@ -107,28 +109,33 @@ public synchronized Supplier<FilterEvaluator> createDynamicPageFilterEvaluator(C
            isBlocked = dynamicFilter.isBlocked();
            boolean isAwaitable = dynamicFilter.isAwaitable();
            TupleDomain<Symbol> currentPredicate = dynamicFilter.getCurrentPredicate().transformKeys(columnHandles::get);
-           List<Expression> expressionConjuncts = domainTranslator.toPredicateConjuncts(currentPredicate)
-                   .stream()
-                   // Run the expression derived from TupleDomain through IR optimizer to simplify predicates. E.g. SimplifyContinuousInValues
-                   .map(expression -> irExpressionOptimizer.process(expression, session, ImmutableMap.of()).orElse(expression))
-                   .collect(toImmutableList());
-           // We translate each conjunct into separate RowExpression to make it easy to profile selectivity
-           // of dynamic filter per column and drop them if they're ineffective
-           List<RowExpression> rowExpression = expressionConjuncts.stream()
-                   .map(expression -> translate(expression, sourceLayout, metadata, typeManager))
-                   .collect(toImmutableList());
-           compiledDynamicFilter = createDynamicFilterEvaluator(rowExpression, compiler, selectivityThreshold);
+           compiledDynamicFilter = createDynamicFilterEvaluator(compiler, currentPredicate);
            if (!isAwaitable) {
                isBlocked = null; // Dynamic filter will not narrow down anymore
            }
        }
        return compiledDynamicFilter;
    }

-   private static Supplier<FilterEvaluator> createDynamicFilterEvaluator(List<RowExpression> rowExpressions, ColumnarFilterCompiler compiler, double selectivityThreshold)
+   private Supplier<FilterEvaluator> createDynamicFilterEvaluator(ColumnarFilterCompiler compiler, TupleDomain<Symbol> currentPredicate)
    {
-       List<Supplier<FilterEvaluator>> subExpressionEvaluators = rowExpressions.stream()
-               .map(expression -> createColumnarFilterEvaluator(expression, compiler))
+       if (currentPredicate.isNone()) {
+           return SelectNoneEvaluator::new;
+       }
+       // We translate each conjunct into a separate FilterEvaluator to make it easy to profile the selectivity
+       // of the dynamic filter per column and drop conjuncts that prove ineffective
+       List<Supplier<FilterEvaluator>> subExpressionEvaluators = currentPredicate.getDomains().orElseThrow()
+               .entrySet().stream()
+               .map(entry -> {
+                   if (canUseBloomFilter(entry.getValue())) {
+                       return Optional.of(createBloomFilterEvaluator(entry.getValue(), sourceLayout.get(entry.getKey())));
+                   }
+                   Expression expression = domainTranslator.toPredicate(entry.getValue(), entry.getKey().toSymbolReference());
+                   // Run the expression derived from TupleDomain through the IR optimizer to simplify predicates, e.g. SimplifyContinuousInValues
+                   expression = irExpressionOptimizer.process(expression, session, ImmutableMap.of()).orElse(expression);
+                   RowExpression rowExpression = translate(expression, sourceLayout, metadata, typeManager);
+                   return createColumnarFilterEvaluator(rowExpression, compiler);
+               })
                .filter(Optional::isPresent)
                .map(Optional::get)
                .collect(toImmutableList());
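
The dispatch above takes the bloom filter path only when canUseBloomFilter accepts the conjunct's domain: a Slice-backed type (VARCHAR, CHAR, VARBINARY) whose domain is a nullable discrete set. A hypothetical check (class name and values invented; the Domain factories are standard trino-spi API):

import io.airlift.slice.Slices;
import io.trino.spi.predicate.Domain;

import java.util.List;

import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.sql.gen.columnar.BloomFilter.canUseBloomFilter;

public class CanUseBloomFilterExample
{
    public static void main(String[] args)
    {
        // IN ('x', 'y') on a VARCHAR column: a discrete set of Slice values, so the bloom filter applies
        Domain strings = Domain.multipleValues(VARCHAR, List.of(Slices.utf8Slice("x"), Slices.utf8Slice("y")));
        System.out.println(canUseBloomFilter(strings)); // true

        // A numeric discrete set stays on the existing compiled path (e.g. LongBitSetFilter)
        Domain numbers = Domain.multipleValues(BIGINT, List.of(1L, 2L, 3L));
        System.out.println(canUseBloomFilter(numbers)); // false

        // An unconstrained domain is rejected outright
        System.out.println(canUseBloomFilter(Domain.all(VARCHAR))); // false
    }
}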
@@ -119,7 +119,7 @@ public Expression toPredicate(TupleDomain<Symbol> tupleDomain)
        return IrUtils.combineConjuncts(toPredicateConjuncts(tupleDomain));
    }

-   public List<Expression> toPredicateConjuncts(TupleDomain<Symbol> tupleDomain)
+   private List<Expression> toPredicateConjuncts(TupleDomain<Symbol> tupleDomain)
    {
        if (tupleDomain.isNone()) {
            return ImmutableList.of(FALSE);
@@ -132,7 +132,7 @@ public List<Expression> toPredicateConjuncts(TupleDomain<Symbol> tupleDomain)
                .collect(toImmutableList());
    }

-   private Expression toPredicate(Domain domain, Reference reference)
+   public Expression toPredicate(Domain domain, Reference reference)
    {
        if (domain.getValues().isNone()) {
            return domain.isNullAllowed() ? new IsNull(reference) : FALSE;
@@ -15,6 +15,7 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
+import io.airlift.slice.Slices;
import io.trino.FullConnectorSession;
import io.trino.operator.project.SelectedPositions;
import io.trino.spi.Page;
@@ -43,15 +44,14 @@
import java.util.Random;
import java.util.concurrent.TimeUnit;

-import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.jmh.Benchmarks.benchmark;
import static io.trino.operator.project.SelectedPositions.positionsRange;
-import static io.trino.spi.predicate.Domain.DiscreteSet;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.RealType.REAL;
import static io.trino.spi.type.TypeUtils.readNativeValue;
import static io.trino.spi.type.TypeUtils.writeNativeValue;
+import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.testing.TestingSession.testSessionBuilder;
import static io.trino.util.DynamicFiltersTestUtil.createDynamicFilterEvaluator;
import static java.lang.Float.floatToIntBits;
@@ -65,18 +65,19 @@
@Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS)
public class BenchmarkDynamicPageFilter
{
-   private static final int MAX_ROWS = 200_000;
+   private static final int MAX_ROWS = 400_000;
    private static final FullConnectorSession FULL_CONNECTOR_SESSION = new FullConnectorSession(
            testSessionBuilder().build(),
            ConnectorIdentity.ofUser("test"));
+   private static final ColumnHandle COLUMN_HANDLE = new TestingColumnHandle("dummy");

    @Param("0.05")
    public double inputNullChance = 0.05;

    @Param("0.2")
    public double nonNullsSelectivity = 0.2;

-   @Param({"100", "1000", "5000"})
+   @Param({"100", "1000", "10000"})
    public int filterSize = 100;

    @Param("false")
@@ -87,6 +88,7 @@ public class BenchmarkDynamicPageFilter
"INT64_RANDOM",
"INT64_FIXED_32K", // LongBitSetFilter
"REAL_RANDOM",
"VARCHAR_RANDOM", // BloomFilter
})
public DataSet inputDataSet;

@@ -99,6 +101,11 @@ public enum DataSet
        INT64_RANDOM(BIGINT, (block, r) -> BIGINT.writeLong(block, r.nextLong())),
        INT64_FIXED_32K(BIGINT, (block, r) -> BIGINT.writeLong(block, r.nextLong() % 32768)),
        REAL_RANDOM(REAL, (block, r) -> REAL.writeLong(block, floatToIntBits(r.nextFloat()))),
+       VARCHAR_RANDOM(VARCHAR, (block, r) -> {
+           byte[] buffer = new byte[r.nextInt(10, 15)];
+           r.nextBytes(buffer);
+           VARCHAR.writeSlice(block, Slices.wrappedBuffer(buffer, 0, buffer.length));
+       }),
        /**/;

        private final Type type;
@@ -121,7 +128,7 @@ public TupleDomain<ColumnHandle> createFilterTupleDomain(int filterSize, boolean
                }
            }
            return TupleDomain.withColumnDomains(ImmutableMap.of(
-                   new TestingColumnHandle("dummy"),
+                   COLUMN_HANDLE,
                    Domain.create(ValueSet.copyOf(type, valuesBuilder.build()), nullsAllowed)));
        }

@@ -132,12 +139,9 @@ private List<Page> createInputTestData(
                long inputRows)
        {
            List<Object> nonNullValues = filter.getDomains().orElseThrow()
-                   .values().stream()
-                   .flatMap(domain -> {
-                       DiscreteSet nullableDiscreteSet = domain.getNullableDiscreteSet();
-                       return nullableDiscreteSet.getNonNullValues().stream();
-                   })
-                   .collect(toImmutableList());
+                   .get(COLUMN_HANDLE)
+                   .getNullableDiscreteSet()
+                   .getNonNullValues();

            // pick a random value from the filter
            return createSingleColumnData(
@@ -163,7 +167,7 @@ public void setup()
        inputData = inputDataSet.createInputTestData(filterPredicate, inputNullChance, nonNullsSelectivity, MAX_ROWS);
        filterEvaluator = createDynamicFilterEvaluator(
                filterPredicate,
-               ImmutableMap.of(new TestingColumnHandle("dummy"), 0),
+               ImmutableMap.of(COLUMN_HANDLE, 0),
                1);
    }

@@ -199,7 +203,7 @@ private static List<Page> createSingleColumnData(ValueWriter valueWriter, Type t
            if (blockBuilder.getPositionCount() >= batchSize) {
                Block block = blockBuilder.build();
                pages.add(new Page(new LazyBlock(block.getPositionCount(), () -> block)));
-               batchSize = Math.min(1024, batchSize * 2);
+               batchSize = Math.min(8192, batchSize * 2);
                blockBuilder = type.createBlockBuilder(null, batchSize);
            }
        }
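
For reference, Trino JMH benchmarks like this one are normally launched from a main method built on the Benchmarks.benchmark helper that this file already imports; a sketch of the assumed runner:

public static void main(String[] args) throws Exception
{
    // Runs all @Benchmark methods with the @Param combinations declared above
    benchmark(BenchmarkDynamicPageFilter.class).run();
}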