Skip to content

Commit

Permalink
Fix QueryStmtNode output; wrap unoin statements in parentheses (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
ohaibbq authored Apr 7, 2024
1 parent debde73 commit 7e669c6
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 9 deletions.
52 changes: 43 additions & 9 deletions internal/formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -800,12 +800,6 @@ func (n *FilterScanNode) FormatSQL(ctx context.Context) (string, error) {
if err != nil {
return "", err
}
nodeMap := nodeMapFromContext(ctx)
for _, node := range nodeMap.FindNodeFromResolvedNode(n.node) {
if _, ok := node.(*parsed_ast.HavingNode); ok {
return fmt.Sprintf("%s HAVING %s", input, filter), nil
}
}
currentQuery := string(removeExpressions.ReplaceAllString(input, ""))

// Qualify the statement if the input is not wrapped in parens
Expand Down Expand Up @@ -980,15 +974,31 @@ func (n *SetOperationScanNode) FormatSQL(ctx context.Context) (string, error) {
case ast.SetOperationTypeExceptDistinct:
opType = "EXCEPT"
default:
opType = "UNKONWN"
opType = "UNKNOWN"
}
var queries []string
for _, item := range n.node.InputItemList() {
var outputColumns []string
for _, outputColumn := range item.OutputColumnList() {
outputColumns = append(outputColumns, fmt.Sprintf("`%s`", uniqueColumnName(ctx, outputColumn)))
}
query, err := newNode(item).FormatSQL(ctx)
if err != nil {
return "", err
}
queries = append(queries, query)

formattedInput, err := formatInput(query)
if err != nil {
return "", err
}

queries = append(
queries,
fmt.Sprintf("SELECT %s %s",
strings.Join(outputColumns, ", "),
formattedInput,
),
)
}
columnMaps := []string{}
if len(n.node.InputItemList()) != 0 {
Expand Down Expand Up @@ -1354,11 +1364,35 @@ func (n *ExplainStmtNode) FormatSQL(ctx context.Context) (string, error) {
return "", nil
}

// FormatSQL Formats the outermost query statement that runs and produces rows of output, like a SELECT
// The node's `OutputColumnList()` gives user-visible column names that should be returned. There may be duplicate names,
// and multiple output columns may reference the same column from `Query()`
// https://github.com/google/zetasql/blob/master/docs/resolved_ast.md#ResolvedQueryStmt
func (n *QueryStmtNode) FormatSQL(ctx context.Context) (string, error) {
if n.node == nil {
return "", nil
}
return newNode(n.node.Query()).FormatSQL(ctx)
input, err := newNode(n.node.Query()).FormatSQL(ctx)
if err != nil {
return "", err
}

var columns []string
for _, outputColumnNode := range n.node.OutputColumnList() {
columns = append(
columns,
fmt.Sprintf("`%s` AS `%s`",
uniqueColumnName(ctx, outputColumnNode.Column()),
outputColumnNode.Name(),
),
)
}

return fmt.Sprintf(
"SELECT %s FROM (%s)",
strings.Join(columns, ", "),
input,
), nil
}

func (n *CreateDatabaseStmtNode) FormatSQL(ctx context.Context) (string, error) {
Expand Down
21 changes: 21 additions & 0 deletions query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,27 @@ func TestQuery(t *testing.T) {
expectedRows [][]interface{}
expectedErr string
}{
// Regression test for https://github.com/goccy/go-zetasqlite/issues/191
{
name: "distinct union",
query: `WITH toks AS (SELECT true AS x, 1 AS y)
SELECT DISTINCT x, x as y FROM toks`,
expectedRows: [][]interface{}{{true, true}},
},
{
name: "with scan union all",
query: `(WITH toks AS (SELECT 1 AS x) SELECT x FROM toks)
UNION ALL
(WITH toks2 AS (SELECT 2 AS x) SELECT x FROM toks2)`,
expectedRows: [][]interface{}{{int64(1)}, {int64(2)}},
},
{
name: "having with union all",
query: `(WITH toks AS (SELECT 1 AS x) SELECT COUNT(x) AS total_rows FROM toks WHERE x > 0 HAVING total_rows >= 0)
UNION ALL
(WITH toks2 AS (SELECT 2 AS x) SELECT COUNT(x) AS total_rows FROM toks2 WHERE x > 0 HAVING total_rows >= 0)`,
expectedRows: [][]interface{}{{int64(1)}, {int64(1)}},
},
// priority 2 operator
{
name: "unary plus operator",
Expand Down

0 comments on commit 7e669c6

Please sign in to comment.