From c3b29479663a7f789be06353127fc08a8f68b9db Mon Sep 17 00:00:00 2001 From: Ryo Kitagawa Date: Sat, 14 Dec 2024 15:53:56 +0900 Subject: [PATCH] feat: support QUALIFY statement --- langserver/internal/source/file/analyze.go | 103 +++++++------------ langserver/internal/source/file/file_test.go | 18 ++++ 2 files changed, 54 insertions(+), 67 deletions(-) diff --git a/langserver/internal/source/file/analyze.go b/langserver/internal/source/file/analyze.go index 7c478d2..223485b 100644 --- a/langserver/internal/source/file/analyze.go +++ b/langserver/internal/source/file/analyze.go @@ -8,7 +8,6 @@ import ( bq "cloud.google.com/go/bigquery" "github.com/goccy/go-zetasql" "github.com/goccy/go-zetasql/ast" - rast "github.com/goccy/go-zetasql/resolved_ast" "github.com/goccy/go-zetasql/types" "github.com/kitagry/bqls/langserver/internal/bigquery" "github.com/kitagry/bqls/langserver/internal/lsp" @@ -32,69 +31,29 @@ func NewAnalyzer(logger *logrus.Logger, bqClient bigquery.Client) *Analyzer { } } -func (a *Analyzer) AnalyzeStatement(rawText string, stmt ast.StatementNode, catalog types.Catalog) (*zetasql.AnalyzerOutput, error) { +func (a *Analyzer) langOpt() (*zetasql.LanguageOptions, error) { langOpt := zetasql.NewLanguageOptions() langOpt.SetNameResolutionMode(zetasql.NameResolutionDefault) langOpt.SetProductMode(types.ProductExternal) - langOpt.SetEnabledLanguageFeatures([]zetasql.LanguageFeature{ - zetasql.FeatureAnalyticFunctions, - zetasql.FeatureNamedArguments, - zetasql.FeatureNumericType, - zetasql.FeatureBignumericType, - zetasql.FeatureV13DecimalAlias, - zetasql.FeatureCreateTableNotNull, - zetasql.FeatureParameterizedTypes, - zetasql.FeatureTablesample, - zetasql.FeatureTimestampNanos, - zetasql.FeatureV11HavingInAggregate, - zetasql.FeatureV11NullHandlingModifierInAggregate, - zetasql.FeatureV11NullHandlingModifierInAnalytic, - zetasql.FeatureV11OrderByCollate, - zetasql.FeatureV11SelectStarExceptReplace, - zetasql.FeatureV12SafeFunctionCall, - zetasql.FeatureJsonType, - zetasql.FeatureJsonArrayFunctions, - zetasql.FeatureJsonStrictNumberParsing, - zetasql.FeatureV13IsDistinct, - zetasql.FeatureV13FormatInCast, - zetasql.FeatureV13DateArithmetics, - zetasql.FeatureV11OrderByInAggregate, - zetasql.FeatureV11LimitInAggregate, - zetasql.FeatureV13DateTimeConstructors, - zetasql.FeatureV13ExtendedDateTimeSignatures, - zetasql.FeatureV12CivilTime, - zetasql.FeatureV12WeekWithWeekday, - zetasql.FeatureIntervalType, - zetasql.FeatureGroupByRollup, - zetasql.FeatureV13NullsFirstLastInOrderBy, - zetasql.FeatureV13Qualify, - zetasql.FeatureV13AllowDashesInTableName, - zetasql.FeatureGeography, - zetasql.FeatureV13ExtendedGeographyParsers, - zetasql.FeatureTemplateFunctions, - zetasql.FeatureV11WithOnSubquery, - zetasql.FeatureV13Pivot, - zetasql.FeatureV13Unpivot, - zetasql.FeatureV13WithRecursive, - zetasql.FeatureV13Qualify, - }) - langOpt.SetSupportedStatementKinds([]rast.Kind{ - rast.BeginStmt, - rast.CommitStmt, - rast.MergeStmt, - rast.QueryStmt, - rast.InsertStmt, - rast.UpdateStmt, - rast.DeleteStmt, - rast.DropStmt, - rast.TruncateStmt, - rast.CreateTableStmt, - rast.CreateTableAsSelectStmt, - rast.CreateProcedureStmt, - rast.CreateFunctionStmt, - rast.CreateTableFunctionStmt, - rast.CreateViewStmt, - }) + langOpt.EnableMaximumLanguageFeatures() + langOpt.EnableLanguageFeature(zetasql.FeatureV13AllowDashesInTableName) + langOpt.EnableLanguageFeature(zetasql.FeatureV13Qualify) + langOpt.EnableLanguageFeature(zetasql.FeatureV13ScriptLabel) + langOpt.EnableLanguageFeature(zetasql.FeatureAnalyticFunctions) + langOpt.SetSupportsAllStatementKinds() + langOpt.EnableAllReservableKeywords(true) + err := langOpt.EnableReservableKeyword("QUALIFY", true) + if err != nil { + return nil, err + } + return langOpt, nil +} + +func (a *Analyzer) AnalyzeStatement(rawText string, stmt ast.StatementNode, catalog types.Catalog) (*zetasql.AnalyzerOutput, error) { + langOpt, err := a.langOpt() + if err != nil { + return nil, err + } opts := zetasql.NewAnalyzerOptions() opts.SetLanguage(langOpt) opts.SetAllowUndeclaredParameters(true) @@ -103,7 +62,17 @@ func (a *Analyzer) AnalyzeStatement(rawText string, stmt ast.StatementNode, cata return zetasql.AnalyzeStatementFromParserAST(rawText, stmt, catalog, opts) } -func (p *Analyzer) ParseFile(uri string, src string) ParsedFile { +func (a *Analyzer) parseScript(src string) (ast.ScriptNode, error) { + langOpt, err := a.langOpt() + if err != nil { + return nil, err + } + opts := zetasql.NewParserOptions() + opts.SetLanguageOptions(langOpt) + return zetasql.ParseScript(src, opts, zetasql.ErrorMessageOneLine) +} + +func (a *Analyzer) ParseFile(uri string, src string) ParsedFile { fixedSrc, errs, fixOffsets := fixDot(src) var node ast.ScriptNode @@ -111,7 +80,7 @@ func (p *Analyzer) ParseFile(uri string, src string) ParsedFile { for _retry := 0; _retry < 10; _retry++ { var err error var fo []FixOffset - node, err = zetasql.ParseScript(fixedSrc, zetasql.NewParserOptions(), zetasql.ErrorMessageOneLine) + node, err = a.parseScript(fixedSrc) if err != nil { pErr := parseZetaSQLError(err) if strings.Contains(pErr.Msg, "SELECT list must not be empty") { @@ -139,14 +108,14 @@ func (p *Analyzer) ParseFile(uri string, src string) ParsedFile { return nil }) - catalog := p.catalog.Clone() + catalog := a.catalog.Clone() declarationMap := make(map[string]string) for _, s := range stmts { if s.Kind() == ast.VariableDeclaration { node := s.(*ast.VariableDeclarationNode) dummyValue, err := getDummyValueForDeclarationNode(node) if err != nil { - p.logger.Debug("failed to get default value for declaration", err) + a.logger.Debug("failed to get default value for declaration", err) } list := node.VariableList().IdentifierList() for _, l := range list { @@ -157,7 +126,7 @@ func (p *Analyzer) ParseFile(uri string, src string) ParsedFile { if s.Kind() == ast.CreateFunctionStatement { node := s.(*ast.CreateFunctionStatementNode) - newFunc, err := p.createFunctionTypes(node, fixedSrc) + newFunc, err := a.createFunctionTypes(node, fixedSrc) if err != nil { errs = append(errs, *err) continue @@ -170,7 +139,7 @@ func (p *Analyzer) ParseFile(uri string, src string) ParsedFile { catalog.AddFunctionWithName(name, newFunc) } - output, err := p.AnalyzeStatement(fixedSrc, s, catalog) + output, err := a.AnalyzeStatement(fixedSrc, s, catalog) if err == nil { rnode = append(rnode, output) continue diff --git a/langserver/internal/source/file/file_test.go b/langserver/internal/source/file/file_test.go index d16e4c3..f7abb3d 100644 --- a/langserver/internal/source/file/file_test.go +++ b/langserver/internal/source/file/file_test.go @@ -566,6 +566,24 @@ func TestProject_ParseFile(t *testing.T) { }, }, }, + "parse QUALIFY statement": { + file: "SELECT * FROM `project.dataset.table` t1\nQUALIFY RANK() OVER (PARTITION BY city ORDER BY temperature) = 1", + bqTableMetadataMap: map[string]*bq.TableMetadata{ + "project.dataset.table": { + Schema: bq.Schema{ + { + Name: "city", + Type: bq.StringFieldType, + }, + { + Name: "temperature", + Type: bq.IntegerFieldType, + }, + }, + }, + }, + expectedErrs: []file.Error{}, + }, } for n, tt := range tests {