Skip to content

Commit

Permalink
Merge pull request #93 from kitagry/support-tmp-function
Browse files Browse the repository at this point in the history
Support user defined function
  • Loading branch information
kitagry authored Dec 28, 2023
2 parents f96c942 + 015a6e2 commit f4efa8b
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 9 deletions.
13 changes: 12 additions & 1 deletion langserver/diagnostic.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,19 @@ func convertErrorsToDiagnostics(errs []file.Error) []lsp.Diagnostic {
Start: err.Position,
End: endPosition,
},
Message: err.Msg,
Message: err.Msg,
Severity: or(err.Severity, lsp.Error),
}
}
return result
}

func or[T comparable](list ...T) T {
var zero T
for _, t := range list {
if t != zero {
return t
}
}
return zero
}
121 changes: 116 additions & 5 deletions langserver/internal/source/file/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ import (
rast "github.com/goccy/go-zetasql/resolved_ast"
"github.com/goccy/go-zetasql/types"
"github.com/kitagry/bqls/langserver/internal/bigquery"
"github.com/kitagry/bqls/langserver/internal/lsp"
"github.com/kitagry/bqls/langserver/internal/source/helper"
"github.com/sirupsen/logrus"
)

type Analyzer struct {
logger *logrus.Logger
bqClient bigquery.Client
catalog types.Catalog
catalog *Catalog
}

func NewAnalyzer(logger *logrus.Logger, bqClient bigquery.Client) *Analyzer {
Expand All @@ -30,10 +32,10 @@ func NewAnalyzer(logger *logrus.Logger, bqClient bigquery.Client) *Analyzer {
}
}

func (a *Analyzer) AnalyzeStatement(rawText string, stmt ast.StatementNode) (*zetasql.AnalyzerOutput, error) {
func (a *Analyzer) AnalyzeStatement(rawText string, stmt ast.StatementNode, catalog types.Catalog) (*zetasql.AnalyzerOutput, error) {
langOpt := zetasql.NewLanguageOptions()
langOpt.SetNameResolutionMode(zetasql.NameResolutionDefault)
langOpt.SetProductMode(types.ProductInternal)
langOpt.SetProductMode(types.ProductExternal)
langOpt.SetEnabledLanguageFeatures([]zetasql.LanguageFeature{
zetasql.FeatureAnalyticFunctions,
zetasql.FeatureNamedArguments,
Expand Down Expand Up @@ -96,7 +98,7 @@ func (a *Analyzer) AnalyzeStatement(rawText string, stmt ast.StatementNode) (*ze
opts.SetAllowUndeclaredParameters(true)
opts.SetErrorMessageMode(zetasql.ErrorMessageOneLine)
opts.SetParseLocationRecordType(zetasql.ParseLocationRecordCodeSearch)
return zetasql.AnalyzeStatementFromParserAST(rawText, stmt, a.catalog, opts)
return zetasql.AnalyzeStatementFromParserAST(rawText, stmt, catalog, opts)
}

func (p *Analyzer) ParseFile(uri string, src string) ParsedFile {
Expand Down Expand Up @@ -135,6 +137,7 @@ func (p *Analyzer) ParseFile(uri string, src string) ParsedFile {
return nil
})

catalog := p.catalog.Clone()
declarationMap := make(map[string]string)
for _, s := range stmts {
if s.Kind() == ast.VariableDeclaration {
Expand All @@ -149,7 +152,23 @@ func (p *Analyzer) ParseFile(uri string, src string) ParsedFile {
}
continue
}
output, err := p.AnalyzeStatement(fixedSrc, s)

if s.Kind() == ast.CreateFunctionStatement {
node := s.(*ast.CreateFunctionStatementNode)
newFunc, err := p.createFunctionTypes(node, fixedSrc)
if err != nil {
errs = append(errs, *err)
continue
}

name := ""
for _, n := range node.FunctionDeclaration().Name().Names() {
name += n.Name()
}
catalog.AddFunctionWithName(name, newFunc)
}

output, err := p.AnalyzeStatement(fixedSrc, s, catalog)
if err == nil {
rnode = append(rnode, output)
continue
Expand Down Expand Up @@ -241,6 +260,54 @@ func (a *Analyzer) GetTableMetadataFromPath(ctx context.Context, path string) (*
}
}

func (p *Analyzer) createFunctionTypes(node *ast.CreateFunctionStatementNode, sourceFile string) (*types.Function, *Error) {
argTypes := []*types.FunctionArgumentType{}
for _, parameter := range node.FunctionDeclaration().Parameters().ParameterEntries() {
typ, err := getTypeFromTypeNode(parameter.Type())
if err != nil {
p.logger.Debug("failed to get type from parameter ", err)
return nil, nil
}
opt := types.NewFunctionArgumentTypeOptions(types.RequiredArgumentCardinality)
opt.SetArgumentName(parameter.Name().Name())
args := types.NewFunctionArgumentType(typ, opt)
argTypes = append(argTypes, args)
}

returnType := node.ReturnType()
var typ types.Type
if returnType != nil {
var err error
typ, err = getTypeFromTypeNode(node.ReturnType())
if err != nil {
p.logger.Debug("failed to get type from return type ", err)
return nil, nil
}
} else {
err := Error{
Msg: "Currently, bqls does not support function without return type.",
Severity: lsp.Warning,
}
loc := node.ParseLocationRange()
if loc != nil {
err.Position = helper.IndexToPosition(sourceFile, loc.Start().ByteOffset())
err.TermLength = loc.End().ByteOffset() - loc.Start().ByteOffset()
}
return nil, &err
}
opt := types.NewFunctionArgumentTypeOptions(types.RequiredArgumentCardinality)
retType := types.NewFunctionArgumentType(typ, opt)

sig := types.NewFunctionSignature(retType, argTypes)

name := ""
for _, n := range node.FunctionDeclaration().Name().Names() {
name += n.Name()
}
newFunc := types.NewFunction([]string{name}, "", types.ScalarMode, []*types.FunctionSignature{sig})
return newFunc, nil
}

func getDummyValueForDeclarationNode(node *ast.VariableDeclarationNode) (string, error) {
switch n := node.Type().(type) {
case *ast.ArrayTypeNode:
Expand Down Expand Up @@ -313,3 +380,47 @@ func getDummyValueForDefaultValueNode(node ast.ExpressionNode) (string, error) {
return "", fmt.Errorf("not implemented: %T", node)
}
}

func getTypeFromTypeNode(node ast.TypeNode) (types.Type, error) {
if stn, ok := node.(*ast.SimpleTypeNode); ok {
names := stn.TypeName().Names()
typeName := ""
for _, n := range names {
typeName += n.Name()
}

switch typeName {
case "INT64":
return types.Int64Type(), nil
case "FLOAT64":
return types.FloatType(), nil
case "BOOL":
return types.BoolType(), nil
case "STRING":
return types.StringType(), nil
case "BYTES":
return types.BytesType(), nil
case "DATE":
return types.DateType(), nil
case "DATETIME":
return types.DatetimeType(), nil
case "TIME":
return types.TimeType(), nil
case "TIMESTAMP":
return types.TimestampType(), nil
case "NUMERIC":
return types.NumericType(), nil
case "BIGNUMERIC":
return types.BigNumericType(), nil
case "GEOGRAPHY":
return types.GeographyType(), nil
case "INTERVAL":
return types.IntervalType(), nil
case "JSON":
return types.JsonType(), nil
default:
return nil, fmt.Errorf("not implemented: %s", typeName)
}
}
return nil, fmt.Errorf("not implemented: %T", node)
}
19 changes: 18 additions & 1 deletion langserver/internal/source/file/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ type Catalog struct {
mu *sync.Mutex
}

func NewCatalog(bqClient bigquery.Client) types.Catalog {
var _ types.Catalog = (*Catalog)(nil)

func NewCatalog(bqClient bigquery.Client) *Catalog {
catalog := types.NewSimpleCatalog(catalogName)
catalog.AddZetaSQLBuiltinFunctions(nil)
return &Catalog{
Expand All @@ -34,6 +36,17 @@ func NewCatalog(bqClient bigquery.Client) types.Catalog {
}
}

func (c *Catalog) Clone() *Catalog {
catalog := types.NewSimpleCatalog(catalogName)
catalog.AddZetaSQLBuiltinFunctions(nil)
return &Catalog{
catalog: catalog,
bqClient: c.bqClient,
tableMetaMap: make(map[string]*bq.TableMetadata),
mu: &sync.Mutex{},
}
}

func (c *Catalog) FullName() string {
return c.catalog.FullName()
}
Expand Down Expand Up @@ -177,6 +190,10 @@ func (c *Catalog) FindConnection(path []string) (types.Connection, error) {
return c.catalog.FindConnection(path)
}

func (c *Catalog) AddFunctionWithName(name string, fn *types.Function) {
c.catalog.AddFunctionWithName(name, fn)
}

func (c *Catalog) FindFunction(path []string) (*types.Function, error) {
return c.catalog.FindFunction(path)
}
Expand Down
1 change: 1 addition & 0 deletions langserver/internal/source/file/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ type Error struct {
Position lsp.Position
TermLength int
IncompleteColumnName string
Severity lsp.DiagnosticSeverity
}

func (e Error) Error() string {
Expand Down
29 changes: 29 additions & 0 deletions langserver/internal/source/file/file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,35 @@ func TestAnalyzer_ParseFileWithDeclareStatement(t *testing.T) {
},
expectedErrs: []file.Error{},
},
"Parse create tmp function statement with returns": {
file: "CREATE TEMP FUNCTION target_func(x INT64) RETURNS INT64 AS (x * 10);\n" +
"SELECT target_func(10);",
bqTableMetadataMap: map[string]*bq.TableMetadata{},
expectedErrs: []file.Error{},
},
"Parse create tmp function statement without returns": {
file: "CREATE TEMP FUNCTION target_func(x INT64) AS (x * 10);\n" +
"SELECT target_func(10);",
bqTableMetadataMap: map[string]*bq.TableMetadata{},
expectedErrs: []file.Error{
{
Msg: "Currently, bqls does not support function without return type.",
Position: lsp.Position{
Line: 0,
Character: 0,
},
TermLength: 53,
Severity: lsp.Warning,
},
{
Msg: "INVALID_ARGUMENT: Function not found: target_func",
Position: lsp.Position{
Line: 1,
Character: 7,
},
},
},
},
}

for n, tt := range tests {
Expand Down
4 changes: 2 additions & 2 deletions langserver/internal/source/helper/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func GetLspPosition(files map[string]string) (formattedFiles map[string]string,
if ind := strings.Index(file, "|"); ind != -1 {
file = strings.Replace(file, "|", "", 1)
path = filePath
position = indexToPosition(file, ind)
position = IndexToPosition(file, ind)
}
formattedFiles[path] = file
}
Expand All @@ -27,7 +27,7 @@ func GetLspPosition(files map[string]string) (formattedFiles map[string]string,
return
}

func indexToPosition(file string, index int) lsp.Position {
func IndexToPosition(file string, index int) lsp.Position {
col, row := 0, 0
lines := strings.Split(file, "\n")
for _, line := range lines {
Expand Down

0 comments on commit f4efa8b

Please sign in to comment.