Skip to content

Commit

Permalink
[enhancement](Nereids) fast compute hash code of deep expression tree…
Browse files Browse the repository at this point in the history
… to reduce conflict (apache#38981)

The default Expression.hashCode() is getClass().hashCode(), which captures only
one level of information, so the many expressions of the same type
all return the same hash code and collide. The resulting deep comparisons in the
HashMap are inefficient and cause the table lock to be held for a long time.

This PR supports fast computation of the hash code from the bottom-level literals and slots,
which reduces the time spent comparing expressions caused by hash code conflicts.

In my test case, the SQL planning time is reduced from 20 minutes (not
finished) to 35 seconds.
  • Loading branch information
924060929 authored Aug 8, 2024
1 parent 3713ee5 commit d5049ec
Show file tree
Hide file tree
Showing 9 changed files with 58 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,15 @@ public Integer visit(Expression expr, Void context) {
if (expr.children().isEmpty()) {
return 0;
}
return collectCommonExpressionByDepth(expr.children().stream().map(child ->
child.accept(this, context)).reduce(Math::max).map(m -> m + 1).orElse(1), expr);
return collectCommonExpressionByDepth(
expr.children()
.stream()
.map(child -> child.accept(this, context))
.reduce(Math::max)
.map(m -> m + 1)
.orElse(1),
expr
);
}

private int collectCommonExpressionByDepth(int depth, Expression expr) {
Expand All @@ -53,7 +60,6 @@ private int collectCommonExpressionByDepth(int depth, Expression expr) {

public static Set<Expression> getExpressionsFromDepthMap(
int depth, Map<Integer, Set<Expression>> depthMap) {
depthMap.putIfAbsent(depth, new LinkedHashSet<>());
return depthMap.get(depth);
return depthMap.computeIfAbsent(depth, d -> new LinkedHashSet<>());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.util.ExpressionUtils;

import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

Expand All @@ -39,9 +39,9 @@
*/
public class PredicatesSplitter {

private final Set<Expression> equalPredicates = new HashSet<>();
private final Set<Expression> rangePredicates = new HashSet<>();
private final Set<Expression> residualPredicates = new HashSet<>();
private final Set<Expression> equalPredicates = new LinkedHashSet<>();
private final Set<Expression> rangePredicates = new LinkedHashSet<>();
private final Set<Expression> residualPredicates = new LinkedHashSet<>();
private final List<Expression> conjunctExpressions;

public PredicatesSplitter(Expression target) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ private List<Expression> removeGeneratedNotNull(Collection<Expression> exprs, Ca
// predicatesNotContainIsNotNull infer nonNullable slots: `id`
// slotsFromIsNotNull: `id`, `name`
// remove `name` (it's generated), remove `id` (because `id > 0` already contains it)
Set<Expression> predicatesNotContainIsNotNull = Sets.newHashSet();
Set<Expression> predicatesNotContainIsNotNull = Sets.newLinkedHashSet();
List<Slot> slotsFromIsNotNull = Lists.newArrayList();

for (Expression expr : exprs) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
private final boolean compareWidthAndDepth;
private final Supplier<Set<Slot>> inputSlots = Suppliers.memoize(
() -> collect(e -> e instanceof Slot && !(e instanceof ArrayItemSlot)));
private final int fastChildrenHashCode;

protected Expression(Expression... children) {
super(children);
Expand All @@ -80,12 +81,14 @@ protected Expression(Expression... children) {
this.depth = 1;
this.width = 1;
this.compareWidthAndDepth = supportCompareWidthAndDepth();
this.fastChildrenHashCode = 0;
break;
case 1:
Expression child = children[0];
this.depth = child.depth + 1;
this.width = child.width;
this.compareWidthAndDepth = child.compareWidthAndDepth && supportCompareWidthAndDepth();
this.fastChildrenHashCode = child.fastChildrenHashCode() + 1;
break;
case 2:
Expression left = children[0];
Expand All @@ -94,21 +97,25 @@ protected Expression(Expression... children) {
this.width = left.width + right.width;
this.compareWidthAndDepth =
left.compareWidthAndDepth && right.compareWidthAndDepth && supportCompareWidthAndDepth();
this.fastChildrenHashCode = left.fastChildrenHashCode() + right.fastChildrenHashCode() + 2;
break;
default:
int maxChildDepth = 0;
int sumChildWidth = 0;
boolean compareWidthAndDepth = true;
int fastChildrenHashCode = 0;
for (Expression expression : children) {
child = expression;
maxChildDepth = Math.max(child.depth, maxChildDepth);
sumChildWidth += child.width;
hasUnbound |= child.hasUnbound;
compareWidthAndDepth &= child.compareWidthAndDepth;
fastChildrenHashCode = fastChildrenHashCode + expression.fastChildrenHashCode() + 1;
}
this.depth = maxChildDepth + 1;
this.width = sumChildWidth;
this.compareWidthAndDepth = compareWidthAndDepth;
this.fastChildrenHashCode = fastChildrenHashCode;
}

checkLimit();
Expand All @@ -129,12 +136,14 @@ protected Expression(List<Expression> children, boolean inferred) {
this.depth = 1;
this.width = 1;
this.compareWidthAndDepth = supportCompareWidthAndDepth();
this.fastChildrenHashCode = 0;
break;
case 1:
Expression child = children.get(0);
this.depth = child.depth + 1;
this.width = child.width;
this.compareWidthAndDepth = child.compareWidthAndDepth && supportCompareWidthAndDepth();
this.fastChildrenHashCode = child.fastChildrenHashCode() + 1;
break;
case 2:
Expression left = children.get(0);
Expand All @@ -143,21 +152,25 @@ protected Expression(List<Expression> children, boolean inferred) {
this.width = left.width + right.width;
this.compareWidthAndDepth =
left.compareWidthAndDepth && right.compareWidthAndDepth && supportCompareWidthAndDepth();
this.fastChildrenHashCode = left.fastChildrenHashCode() + right.fastChildrenHashCode() + 2;
break;
default:
int maxChildDepth = 0;
int sumChildWidth = 0;
boolean compareWidthAndDepth = true;
int fastChildrenhashCode = 0;
for (Expression expression : children) {
child = expression;
maxChildDepth = Math.max(child.depth, maxChildDepth);
sumChildWidth += child.width;
hasUnbound |= child.hasUnbound;
compareWidthAndDepth &= child.compareWidthAndDepth;
fastChildrenhashCode = fastChildrenhashCode + expression.fastChildrenHashCode() + 1;
}
this.depth = maxChildDepth + 1;
this.width = sumChildWidth;
this.compareWidthAndDepth = compareWidthAndDepth && supportCompareWidthAndDepth();
this.fastChildrenHashCode = fastChildrenhashCode;
}

checkLimit();
Expand Down Expand Up @@ -211,6 +224,10 @@ public TypeCheckResult checkInputDataTypes() {
return checkInputDataTypesInternal();
}

/**
 * Returns the hash contribution stored in the {@code fastChildrenHashCode} field.
 *
 * <p>The constructors assign that field once: 0 when there are no children, otherwise the
 * sum of every child's {@code fastChildrenHashCode()} plus one per child. Because the value
 * is built from the children's own values, it carries information from deeper levels of
 * the tree without walking the tree on each call.
 *
 * @return the precomputed children hash contribution of this expression
 */
public int fastChildrenHashCode() {
return fastChildrenHashCode;
}

protected TypeCheckResult checkInputDataTypesInternal() {
return TypeCheckResult.SUCCESS;
}
Expand Down Expand Up @@ -406,7 +423,9 @@ public boolean equals(Object o) {
return false;
}
Expression that = (Expression) o;
if ((compareWidthAndDepth && (this.width != that.width || this.depth != that.depth))
if ((compareWidthAndDepth
&& (this.width != that.width || this.depth != that.depth
|| this.fastChildrenHashCode != that.fastChildrenHashCode))
|| arity() != that.arity() || !extraEquals(that)) {
return false;
}
Expand All @@ -430,7 +449,7 @@ protected boolean extraEquals(Expression that) {

@Override
public int hashCode() {
return getClass().hashCode();
return getClass().hashCode() + fastChildrenHashCode();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,21 @@ public boolean nullable() {
return true;
}

/**
 * Renders this node as {@code "$"} followed by the integer value of {@code placeholderId}.
 *
 * @return the textual form, e.g. {@code "$" + id}
 */
@Override
public String toString() {
return "$" + placeholderId.asInt();
}

/**
 * Returns the fixed SQL text {@code "?"} for this node, regardless of {@code placeholderId}.
 * NOTE(review): presumably the prepared-statement parameter marker; confirm against callers.
 *
 * @return the string {@code "?"}
 */
@Override
public String toSql() {
return "?";
}

/**
 * Uses the integer value of {@code placeholderId} as this node's fast hash contribution.
 * NOTE(review): presumably this replaces the children-derived value of the overridden
 * method so that different placeholder ids contribute different values; confirm.
 *
 * @return {@code placeholderId.asInt()}
 */
@Override
public int fastChildrenHashCode() {
return placeholderId.asInt();
}

@Override
public DataType getDataType() throws UnboundException {
return NullType.INSTANCE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,11 @@ public int hashCode() {
return exprId.asInt();
}

/**
 * Uses the integer value of {@code exprId} as this node's fast hash contribution,
 * i.e. the same expression this class returns from {@code hashCode()}.
 *
 * @return {@code exprId.asInt()}
 */
@Override
public int fastChildrenHashCode() {
return exprId.asInt();
}

@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitSlotReference(this, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,11 @@ public int hashCode() {
return Objects.hashCode(getValue());
}

/**
 * Uses the hash of {@code getValue()} as this node's fast hash contribution,
 * i.e. the same expression this class returns from {@code hashCode()}.
 * {@code Objects.hashCode} yields 0 when the value is {@code null}.
 *
 * @return {@code Objects.hashCode(getValue())}
 */
@Override
public int fastChildrenHashCode() {
return Objects.hashCode(getValue());
}

@Override
public String toString() {
return String.valueOf(getValue());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ private static void updateScanNodeConjuncts(OlapScanNode scanNode, List<Expr> co
} else if (binaryPredicate.getChild(1) instanceof LiteralExpr) {
binaryPredicate.setChild(1, conjunctVals.get(i));
} else {
Preconditions.checkState(false, "Should conatains literal in " + binaryPredicate.toSqlImpl());
Preconditions.checkState(false, "Should contains literal in " + binaryPredicate.toSqlImpl());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ private void assetEquals(String expression,
String expectedRangeExpr,
String expectedResidualExpr) {

Map<String, Slot> mem = Maps.newHashMap();
Map<String, Slot> mem = Maps.newLinkedHashMap();
Expression targetExpr = replaceUnboundSlot(PARSER.parseExpression(expression), mem);
SplitPredicate splitPredicate = Predicates.splitPredicates(targetExpr);

Expand Down

0 comments on commit d5049ec

Please sign in to comment.