diff --git a/internal/formatter.go b/internal/formatter.go index cb6f323..25a0df8 100644 --- a/internal/formatter.go +++ b/internal/formatter.go @@ -800,12 +800,6 @@ func (n *FilterScanNode) FormatSQL(ctx context.Context) (string, error) { if err != nil { return "", err } - nodeMap := nodeMapFromContext(ctx) - for _, node := range nodeMap.FindNodeFromResolvedNode(n.node) { - if _, ok := node.(*parsed_ast.HavingNode); ok { - return fmt.Sprintf("%s HAVING %s", input, filter), nil - } - } currentQuery := string(removeExpressions.ReplaceAllString(input, "")) // Qualify the statement if the input is not wrapped in parens @@ -980,15 +974,31 @@ func (n *SetOperationScanNode) FormatSQL(ctx context.Context) (string, error) { case ast.SetOperationTypeExceptDistinct: opType = "EXCEPT" default: - opType = "UNKONWN" + opType = "UNKNOWN" } var queries []string for _, item := range n.node.InputItemList() { + var outputColumns []string + for _, outputColumn := range item.OutputColumnList() { + outputColumns = append(outputColumns, fmt.Sprintf("`%s`", uniqueColumnName(ctx, outputColumn))) + } query, err := newNode(item).FormatSQL(ctx) if err != nil { return "", err } - queries = append(queries, query) + + formattedInput, err := formatInput(query) + if err != nil { + return "", err + } + + queries = append( + queries, + fmt.Sprintf("SELECT %s %s", + strings.Join(outputColumns, ", "), + formattedInput, + ), + ) } columnMaps := []string{} if len(n.node.InputItemList()) != 0 { @@ -1354,11 +1364,35 @@ func (n *ExplainStmtNode) FormatSQL(ctx context.Context) (string, error) { return "", nil } +// FormatSQL formats the outermost query statement that runs and produces rows of output, like a SELECT. +// The node's `OutputColumnList()` gives user-visible column names that should be returned. 
There may be duplicate names, +// and multiple output columns may reference the same column from `Query()`; each output column is selected from the parenthesized `Query()` SQL as its unique column name aliased to its output name. +// https://github.com/google/zetasql/blob/master/docs/resolved_ast.md#ResolvedQueryStmt func (n *QueryStmtNode) FormatSQL(ctx context.Context) (string, error) { if n.node == nil { return "", nil } - return newNode(n.node.Query()).FormatSQL(ctx) + input, err := newNode(n.node.Query()).FormatSQL(ctx) + if err != nil { + return "", err + } + + var columns []string + for _, outputColumnNode := range n.node.OutputColumnList() { + columns = append( + columns, + fmt.Sprintf("`%s` AS `%s`", + uniqueColumnName(ctx, outputColumnNode.Column()), + outputColumnNode.Name(), + ), + ) + } + + return fmt.Sprintf( + "SELECT %s FROM (%s)", + strings.Join(columns, ", "), + input, + ), nil } func (n *CreateDatabaseStmtNode) FormatSQL(ctx context.Context) (string, error) { diff --git a/query_test.go b/query_test.go index 4905f7c..929bb9d 100644 --- a/query_test.go +++ b/query_test.go @@ -39,6 +39,27 @@ func TestQuery(t *testing.T) { expectedRows [][]interface{} expectedErr string }{ + // Regression test for https://github.com/goccy/go-zetasqlite/issues/191 + { + name: "distinct union", + query: `WITH toks AS (SELECT true AS x, 1 AS y) + SELECT DISTINCT x, x as y FROM toks`, + expectedRows: [][]interface{}{{true, true}}, + }, + { + name: "with scan union all", + query: `(WITH toks AS (SELECT 1 AS x) SELECT x FROM toks) +UNION ALL +(WITH toks2 AS (SELECT 2 AS x) SELECT x FROM toks2)`, + expectedRows: [][]interface{}{{int64(1)}, {int64(2)}}, + }, + { + name: "having with union all", + query: `(WITH toks AS (SELECT 1 AS x) SELECT COUNT(x) AS total_rows FROM toks WHERE x > 0 HAVING total_rows >= 0) +UNION ALL +(WITH toks2 AS (SELECT 2 AS x) SELECT COUNT(x) AS total_rows FROM toks2 WHERE x > 0 HAVING total_rows >= 0)`, + expectedRows: [][]interface{}{{int64(1)}, {int64(1)}}, + }, // priority 2 operator { name: "unary plus operator",