diff --git a/sql/analyzer/prune_columns.go b/sql/analyzer/prune_columns.go index b485dec6c..9274541f5 100644 --- a/sql/analyzer/prune_columns.go +++ b/sql/analyzer/prune_columns.go @@ -29,12 +29,7 @@ func pruneColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { findUsedColumns(columns, n) - n, err := addSubqueryBarriers(n) - if err != nil { - return nil, err - } - - n, err = pruneUnusedColumns(n, columns) + n, err := pruneUnusedColumns(n, columns) if err != nil { return nil, err } @@ -81,12 +76,7 @@ func pruneSubqueryColumns( findUsedColumns(columns, n.Child) - node, err := addSubqueryBarriers(n.Child) - if err != nil { - return nil, err - } - - node, err = pruneUnusedColumns(node, columns) + node, err := pruneUnusedColumns(n.Child, columns) if err != nil { return nil, err } @@ -126,17 +116,6 @@ func findUsedColumns(columns usedColumns, n sql.Node) { }) } -func addSubqueryBarriers(n sql.Node) (sql.Node, error) { - return n.TransformUp(func(n sql.Node) (sql.Node, error) { - sq, ok := n.(*plan.SubqueryAlias) - if !ok { - return n, nil - } - - return &subqueryBarrier{sq}, nil - }) -} - func pruneSubqueries( ctx *sql.Context, a *Analyzer, @@ -144,12 +123,12 @@ func pruneSubqueries( parentColumns usedColumns, ) (sql.Node, error) { return n.TransformUp(func(n sql.Node) (sql.Node, error) { - barrier, ok := n.(*subqueryBarrier) + subq, ok := n.(*plan.SubqueryAlias) if !ok { return n, nil } - return pruneSubqueryColumns(ctx, a, barrier.SubqueryAlias, parentColumns) + return pruneSubqueryColumns(ctx, a, subq, parentColumns) }) } @@ -173,39 +152,53 @@ type tableColumnPair struct { func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) { return n.TransformUp(func(n sql.Node) (sql.Node, error) { - exp, ok := n.(sql.Expressioner) - if !ok { - return n, nil - } - - var schema sql.Schema - for _, c := range n.Children() { - schema = append(schema, c.Schema()...) - } + switch n := n.(type) { + case *plan.SubqueryAlias: + child, err := fixRemainingFieldsIndexes(n.Child) + if err != nil { + return nil, err + } - if len(schema) == 0 { - return n, nil - } + return plan.NewSubqueryAlias(n.Name(), child), nil + default: + exp, ok := n.(sql.Expressioner) + if !ok { + return n, nil + } - indexes := make(map[tableColumnPair]int) - for i, col := range schema { - indexes[tableColumnPair{col.Source, col.Name}] = i - } + var schema sql.Schema + for _, c := range n.Children() { + schema = append(schema, c.Schema()...) + } - return exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) { - gf, ok := e.(*expression.GetField) - if !ok { - return e, nil + if len(schema) == 0 { + return n, nil } - idx, ok := indexes[tableColumnPair{gf.Table(), gf.Name()}] - if !ok { - return nil, fmt.Errorf("unable to find column %q of table %q", gf.Name(), gf.Table()) + indexes := make(map[tableColumnPair]int) + for i, col := range schema { + indexes[tableColumnPair{col.Source, col.Name}] = i } - ngf := *gf - return ngf.WithIndex(idx), nil - }) + return exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) { + gf, ok := e.(*expression.GetField) + if !ok { + return e, nil + } + + idx, ok := indexes[tableColumnPair{gf.Table(), gf.Name()}] + if !ok { + return nil, fmt.Errorf("unable to find column %q of table %q", gf.Name(), gf.Table()) + } + + if idx == gf.Index() { + return gf, nil + } + + ngf := *gf + return ngf.WithIndex(idx), nil + }) + } }) } @@ -290,11 +283,3 @@ func shouldPruneExpr(e sql.Expression, cols usedColumns) bool { return true } - -type subqueryBarrier struct { - *plan.SubqueryAlias -} - -func (b *subqueryBarrier) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { - return f(b) -} diff --git a/sql/analyzer/prune_columns_test.go b/sql/analyzer/prune_columns_test.go index 882486871..fb98bb2bd 100644 --- a/sql/analyzer/prune_columns_test.go +++ b/sql/analyzer/prune_columns_test.go @@ -263,7 +263,7 @@ func TestPruneColumns(t *testing.T) { ), expression.NewEquals( gf(0, "t1", "foo"), - gf(3, "t2", "foo"), + gf(1, "t2", "foo"), ), ), ),