Skip to content

Commit

Permalink
use the new faster Merge method
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay committed Jan 17, 2025
1 parent dddf936 commit b16f792
Show file tree
Hide file tree
Showing 13 changed files with 71 additions and 48 deletions.
7 changes: 4 additions & 3 deletions go/vt/vtgate/planbuilder/operators/ast_to_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,18 @@ func cloneASTAndSemState[T sqlparser.SQLNode](ctx *plancontext.PlanningContext,
}

// findTablesContained returns the TableSet of all the contained
func findTablesContained(ctx *plancontext.PlanningContext, node sqlparser.SQLNode) (result semantics.TableSet) {
func findTablesContained(ctx *plancontext.PlanningContext, node sqlparser.SQLNode) semantics.TableSet {
var tables []semantics.TableSet
_ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) {
t, ok := node.(*sqlparser.AliasedTableExpr)
if !ok {
return true, nil
}
ts := ctx.SemTable.TableSetFor(t)
result = result.Merge(ts)
tables = append(tables, ts)
return true, nil
}, node)
return
return semantics.MergeTableSets(tables...)
}

// joinPredicateCollector is used to inspect the predicates inside the subquery, looking for any
Expand Down
7 changes: 4 additions & 3 deletions go/vt/vtgate/planbuilder/operators/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,15 @@ type tableIDIntroducer interface {
introducesTableID() semantics.TableSet
}

func TableID(op Operator) (result semantics.TableSet) {
func TableID(op Operator) semantics.TableSet {
var tables []semantics.TableSet
_ = Visit(op, func(this Operator) error {
if tbl, ok := this.(tableIDIntroducer); ok {
result = result.Merge(tbl.introducesTableID())
tables = append(tables, tbl.introducesTableID())
}
return nil
})
return
return semantics.MergeTableSets(tables...)
}

// TableUser is used to signal that this operator directly interacts with one or more tables
Expand Down
6 changes: 3 additions & 3 deletions go/vt/vtgate/planbuilder/operators/querygraph.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ var _ Operator = (*QueryGraph)(nil)

// Introduces implements the tableIDIntroducer interface
func (qg *QueryGraph) introducesTableID() semantics.TableSet {
var ts semantics.TableSet
var ts []semantics.TableSet
for _, table := range qg.Tables {
ts = ts.Merge(table.ID)
ts = append(ts, table.ID)
}
return ts
return semantics.MergeTableSets(ts...)
}

// GetPredicates returns the predicates that are applicable for the two given TableSets
Expand Down
6 changes: 3 additions & 3 deletions go/vt/vtgate/planbuilder/operators/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -806,9 +806,9 @@ func (r *Route) getTruncateColumnCount() int {
}

func (r *Route) introducesTableID() semantics.TableSet {
id := semantics.EmptyTableSet()
var ts []semantics.TableSet
for _, route := range r.MergedWith {
id = id.Merge(TableID(route))
ts = append(ts, TableID(route))
}
return id
return semantics.MergeTableSets(ts...)
}
6 changes: 3 additions & 3 deletions go/vt/vtgate/planbuilder/operators/subquery_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,11 @@ func tryPushSubQueryInJoin(
// we want to push the subquery as close to its needs
// as possible, so that we can potentially merge them together
// TODO: we need to check dependencies and break apart all expressions in the subquery, not just the merge predicates
deps := semantics.EmptyTableSet()
var ts []semantics.TableSet
for _, predicate := range inner.GetMergePredicates() {
deps = deps.Merge(ctx.SemTable.RecursiveDeps(predicate))
ts = append(ts, ctx.SemTable.RecursiveDeps(predicate))
}
deps = deps.Remove(innerID)
deps := semantics.MergeTableSets(ts...).Remove(innerID)

// in general, we don't want to push down uncorrelated subqueries into the RHS of a join,
// since this side is executed once per row from the LHS, so we would unnecessarily execute
Expand Down
21 changes: 13 additions & 8 deletions go/vt/vtgate/semantics/binder.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,15 @@ func (b *binder) bindTableNames(cursor *sqlparser.Cursor, tables sqlparser.Table
return nil
}
current := b.scoper.currentScope()
var targets []TableSet
for _, target := range tables {
finalDep, err := b.findDependentTableSet(current, target)
if err != nil {
return err
}
b.targets = b.targets.Merge(finalDep.direct)
targets = append(targets, finalDep.direct)
}
b.targets = MergeTableSets(targets...)
return nil
}

Expand Down Expand Up @@ -184,19 +186,21 @@ func (b *binder) findDependentTableSet(current *scope, target sqlparser.TableNam

func (b *binder) bindCountStar(node *sqlparser.CountStar) error {
scope := b.scoper.currentScope()
var ts TableSet
var deps []TableSet
for _, tbl := range scope.tables {
switch tbl := tbl.(type) {
case *vTableInfo:
for _, col := range tbl.cols {
if sqlparser.Equals.Expr(node, col) {
ts = ts.Merge(b.recursive[col])
deps = append(deps, b.recursive[col])
}
}
default:
ts = ts.Merge(tbl.getTableSet(b.org))
deps = append(deps, tbl.getTableSet(b.org))
}
}

ts := MergeTableSets(deps...)
b.recursive[node] = ts
b.direct[node] = ts
return nil
Expand Down Expand Up @@ -246,17 +250,18 @@ func (b *binder) setSubQueryDependencies(subq *sqlparser.Subquery) error {
subqRecursiveDeps := b.recursive.dependencies(subq)
subqDirectDeps := b.direct.dependencies(subq)

tablesToKeep := EmptyTableSet()
var tablesToKeep []TableSet
sco := currScope
for sco != nil {
for _, table := range sco.tables {
tablesToKeep = tablesToKeep.Merge(table.getTableSet(b.org))
tablesToKeep = append(tablesToKeep, table.getTableSet(b.org))
}
sco = sco.parent
}

b.recursive[subq] = subqRecursiveDeps.KeepOnly(tablesToKeep)
b.direct[subq] = subqDirectDeps.KeepOnly(tablesToKeep)
keep := MergeTableSets(tablesToKeep...)
b.recursive[subq] = subqRecursiveDeps.KeepOnly(keep)
b.direct[subq] = subqDirectDeps.KeepOnly(keep)
return nil
}

Expand Down
8 changes: 4 additions & 4 deletions go/vt/vtgate/semantics/cte_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,20 +163,20 @@ type CTE struct {
Merged bool
}

func (cte *CTE) recursive(org originable) (id TableSet) {
func (cte *CTE) recursive(org originable) TableSet {
if cte.recursiveDeps != nil {
return *cte.recursiveDeps
}

var tables []TableSet
// We need to find the recursive dependencies of the CTE
// We'll do this by walking the inner query and finding all the tables
_ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
ate, ok := node.(*sqlparser.AliasedTableExpr)
if !ok {
return true, nil
}
id = id.Merge(org.tableSetFor(ate))
tables = append(tables, org.tableSetFor(ate))
return true, nil
}, cte.Query)
return
return MergeTableSets(tables...)
}
5 changes: 4 additions & 1 deletion go/vt/vtgate/semantics/derived_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,15 @@ func handleAliasedExpr(vTbl *DerivedTable, expr *sqlparser.AliasedExpr, cols sql
}

func handleUnexpandedStarExpression(tables []TableInfo, vTbl *DerivedTable, org originable) {
var tableSets []TableSet
for _, table := range tables {
vTbl.tables = vTbl.tables.Merge(table.getTableSet(org))
ts := table.getTableSet(org)
tableSets = append(tableSets, ts)
if !table.authoritative() {
vTbl.isAuthoritative = false
}
}
vTbl.tables = MergeTableSets(tableSets...)
}

// dependencies implements the TableInfo interface
Expand Down
7 changes: 3 additions & 4 deletions go/vt/vtgate/semantics/scoper.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,12 +228,11 @@ func (s *scoper) up(cursor *sqlparser.Cursor) error {
s.popScope()
}
case *sqlparser.Select, *sqlparser.GroupBy, *sqlparser.Update, *sqlparser.Insert, *sqlparser.Union, *sqlparser.Delete:
id := EmptyTableSet()
var tables []TableSet
for _, tableInfo := range s.currentScope().tables {
set := tableInfo.getTableSet(s.org)
id = id.Merge(set)
tables = append(tables, tableInfo.getTableSet(s.org))
}
s.statementIDs[s.currentScope().stmt] = id
s.statementIDs[s.currentScope().stmt] = MergeTableSets(tables...)
s.popScope()
case *sqlparser.Where:
if node.Type != sqlparser.HavingClause {
Expand Down
11 changes: 9 additions & 2 deletions go/vt/vtgate/semantics/semantic_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -697,11 +697,14 @@ func (d ExprDependencies) dependencies(expr sqlparser.Expr) (deps TableSet) {
if found {
return deps
}

// If we did not find the expression in the cache, we'll add it after calculating it
defer func() {
d[expr] = deps
}()
}

var tables []TableSet
// During the original semantic analysis, all ColNames were found and bound to the corresponding tables
// Here, we'll walk the expression tree and look to see if we can find any sub-expressions
// that have already set dependencies.
Expand All @@ -714,13 +717,17 @@ func (d ExprDependencies) dependencies(expr sqlparser.Expr) (deps TableSet) {
}

set, found := d[expr]
deps = deps.Merge(set)
if found {
tables = append(tables, set)
}

// if we found a cached value, there is no need to continue down to visit children
return !found, nil
}, expr)

return deps
deps = MergeTableSets(tables...)

return
}

// RewriteDerivedTableExpression rewrites all the ColName instances in the supplied expression with
Expand Down
10 changes: 6 additions & 4 deletions go/vt/vtgate/semantics/table_collector.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func (tc *tableCollector) visitUnion(union *sqlparser.Union) error {
}

size := firstSelect.GetColumnCount()
info.recursive = make([]TableSet, size)
recursiveDeps := make([][]TableSet, size)
typers := make([]evalengine.TypeAggregator, size)
collations := tc.org.collationEnv()

Expand All @@ -179,8 +179,8 @@ func (tc *tableCollector) visitUnion(union *sqlparser.Union) error {
if !ok {
continue
}
_, recursiveDeps, qt := tc.org.depsForExpr(ae.Expr)
info.recursive[i] = info.recursive[i].Merge(recursiveDeps)
_, deps, qt := tc.org.depsForExpr(ae.Expr)
recursiveDeps[i] = append(recursiveDeps[i], deps)
if err := typers[i].Add(qt, collations); err != nil {
return err
}
Expand All @@ -191,8 +191,10 @@ func (tc *tableCollector) visitUnion(union *sqlparser.Union) error {
return err
}

for _, ts := range typers {
info.recursive = make([]TableSet, size)
for i, ts := range typers {
info.types = append(info.types, ts.Type())
info.recursive[i] = MergeTableSets(recursiveDeps[i]...)
}
tc.unionInfo[union] = info
return nil
Expand Down
21 changes: 12 additions & 9 deletions go/vt/vtgate/semantics/table_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package semantics

import (
"fmt"
"unsafe"

"vitess.io/vitess/go/vt/vtgate/semantics/bitset"
)
Expand Down Expand Up @@ -115,15 +116,17 @@ func EmptyTableSet() TableSet {
}

// MergeTableSets merges all the given TableSet into a single one
func MergeTableSets(tss ...TableSet) TableSet {
var result bitset.Bitset
for _, t := range tss {
result = result.Or(bitset.Bitset(t))
func MergeTableSets(tableSets ...TableSet) TableSet {
if len(tableSets) == 0 {
return ""
}
return TableSet(result)
}

// TableSetFromIds returns TableSet for all the id passed in argument.
func TableSetFromIds(tids ...int) (ts TableSet) {
return TableSet(bitset.Build(tids...))
// Trick: re-interpret slice header of tableSets as []Bitset
// This is safe because the memory layout of []TableSet and []Bitset is the same
// The alternative would be to loop over all TableSets and convert them to Bitsets, but that would be slower
bs := *(*[]bitset.Bitset)(unsafe.Pointer(&tableSets))

// Now pass it to MergeMany
merged := bitset.Merge(bs...)
return TableSet(merged)
}
4 changes: 3 additions & 1 deletion go/vt/vtgate/semantics/vtable.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,14 @@ func selectExprsToInfos(
colNames = append(colNames, expr.As.String())
}
case *sqlparser.StarExpr:
var tableSets []TableSet
for _, table := range tables {
ts = ts.Merge(table.getTableSet(org))
tableSets = append(tableSets, table.getTableSet(org))
if !table.authoritative() {
isAuthoritative = false
}
}
ts = MergeTableSets(tableSets...)
}
}
return
Expand Down

0 comments on commit b16f792

Please sign in to comment.