diff --git a/enginetest/join_stats_tests.go b/enginetest/join_stats_tests.go index 7df713e6fd..32e43eb7cd 100644 --- a/enginetest/join_stats_tests.go +++ b/enginetest/join_stats_tests.go @@ -360,12 +360,12 @@ func (t TestProvider) Function(ctx *sql.Context, name string) (sql.Function, boo return nil, false } -func (t TestProvider) TableFunction(_ *sql.Context, name string) (sql.TableFunction, error) { +func (t TestProvider) TableFunction(_ *sql.Context, name string) (sql.TableFunction, bool) { if tf, ok := t.tableFunctions[strings.ToLower(name)]; ok { - return tf, nil + return tf, true } - return nil, sql.ErrTableFunctionNotFound.New(name) + return nil, false } func (t TestProvider) WithTableFunctions(fns ...sql.TableFunction) (sql.TableFunctionProvider, error) { diff --git a/memory/provider.go b/memory/provider.go index 3023fc6e65..624e66b549 100644 --- a/memory/provider.go +++ b/memory/provider.go @@ -194,10 +194,10 @@ func (pro *DbProvider) ExternalStoredProcedures(_ *sql.Context, name string) ([] } // TableFunction implements sql.TableFunctionProvider -func (pro *DbProvider) TableFunction(_ *sql.Context, name string) (sql.TableFunction, error) { +func (pro *DbProvider) TableFunction(_ *sql.Context, name string) (sql.TableFunction, bool) { if tableFunction, ok := pro.tableFunctions[name]; ok { - return tableFunction, nil + return tableFunction, true } - return nil, sql.ErrTableFunctionNotFound.New(name) + return nil, false } diff --git a/sql/analyzer/catalog.go b/sql/analyzer/catalog.go index fedddae31a..167778f7dc 100644 --- a/sql/analyzer/catalog.go +++ b/sql/analyzer/catalog.go @@ -384,17 +384,14 @@ func (c *Catalog) ExternalStoredProcedures(ctx *sql.Context, name string) ([]sql } // TableFunction implements the TableFunctionProvider interface -func (c *Catalog) TableFunction(ctx *sql.Context, name string) (sql.TableFunction, error) { +func (c *Catalog) TableFunction(ctx *sql.Context, name string) (sql.TableFunction, bool) { if fp, ok := c.DbProvider.(sql.TableFunctionProvider); ok { - tf, err := fp.TableFunction(ctx, name) - if err != nil { - return nil, err - } else if tf != nil { - return tf, nil + tf, found := fp.TableFunction(ctx, name) + if found && tf != nil { + return tf, true } } - - return nil, sql.ErrTableFunctionNotFound.New(name) + return nil, false } func (c *Catalog) RefreshTableStats(ctx *sql.Context, table sql.Table, db string) error { diff --git a/sql/catalog_map.go b/sql/catalog_map.go index 3f23b03a6b..ecbf0f9567 100644 --- a/sql/catalog_map.go +++ b/sql/catalog_map.go @@ -25,11 +25,11 @@ func (t MapCatalog) Function(ctx *Context, name string) (Function, bool) { return nil, false } -func (t MapCatalog) TableFunction(ctx *Context, name string) (TableFunction, error) { +func (t MapCatalog) TableFunction(ctx *Context, name string) (TableFunction, bool) { if f, ok := t.tabFuncs[name]; ok { - return f, nil + return f, true } - return nil, fmt.Errorf("table func not found") + return nil, false } func (t MapCatalog) ExternalStoredProcedure(ctx *Context, name string, numOfParams int) (*ExternalStoredProcedureDetails, error) { diff --git a/sql/databases.go b/sql/databases.go index f7dfbe2f49..e1ec634ace 100644 --- a/sql/databases.go +++ b/sql/databases.go @@ -57,8 +57,9 @@ type CollatedDatabaseProvider interface { // TableFunctionProvider is an interface that allows custom table functions to be provided. It's usually (but not // always) implemented by a DatabaseProvider. type TableFunctionProvider interface { - // TableFunction returns the table function with the name provided, case-insensitive - TableFunction(ctx *Context, name string) (TableFunction, error) + // TableFunction returns the table function with the name provided, case-insensitive. + // It also returns boolean param for whether the table function was found. + TableFunction(ctx *Context, name string) (TableFunction, bool) // WithTableFunctions returns a new provider with (only) the list of table functions arguments WithTableFunctions(fns ...TableFunction) (TableFunctionProvider, error) } diff --git a/sql/expression/tablefunction/table_function.go b/sql/expression/tablefunction/table_function.go new file mode 100644 index 0000000000..4491b1f8d9 --- /dev/null +++ b/sql/expression/tablefunction/table_function.go @@ -0,0 +1,139 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dtablefunctions + +import ( + "fmt" + "strings" + + "github.com/dolthub/go-mysql-server/sql" +) + +var _ sql.TableFunction = &TableFunctionWrapper{} +var _ sql.ExecSourceRel = &TableFunctionWrapper{} + +// TableFunctionWrapper represents a table function with underlying +// regular function. It allows using regular function as table function. +type TableFunctionWrapper struct { + underlyingFunc sql.Function + + args []sql.Expression + database sql.Database + funcExpr sql.Expression +} + +// NewTableFunctionWrapper creates new TableFunction +// with given Function as underlying function. +func NewTableFunctionWrapper(f sql.Function) sql.TableFunction { + return &TableFunctionWrapper{ + underlyingFunc: f, + } +} + +func (t *TableFunctionWrapper) NewInstance(ctx *sql.Context, db sql.Database, args []sql.Expression) (sql.Node, error) { + nt := *t + nt.database = db + nt.args = args + f, err := nt.underlyingFunc.NewInstance(args) + if err != nil { + return nil, err + } + nt.funcExpr = f + return &nt, nil +} + +func (t *TableFunctionWrapper) Children() []sql.Node { + return nil +} + +func (t *TableFunctionWrapper) Database() sql.Database { + return t.database +} + +func (t *TableFunctionWrapper) Expressions() []sql.Expression { + if t.funcExpr == nil { + return nil + } + return t.funcExpr.Children() +} + +func (t *TableFunctionWrapper) IsReadOnly() bool { + return true +} + +func (t *TableFunctionWrapper) Name() string { + return t.underlyingFunc.FunctionName() +} + +func (t *TableFunctionWrapper) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) { + v, err := t.funcExpr.Eval(ctx, r) + if err != nil { + return nil, err + } + return sql.RowsToRowIter(sql.Row{v}), nil +} + +func (t *TableFunctionWrapper) Resolved() bool { + for _, expr := range t.args { + if !expr.Resolved() { + return false + } + } + return true +} + +func (t *TableFunctionWrapper) Schema() sql.Schema { + return sql.Schema{&sql.Column{Name: t.underlyingFunc.FunctionName(), Type: t.funcExpr.Type()}} +} + +func (t *TableFunctionWrapper) String() string { + var args []string + for _, expr := range t.args { + args = append(args, expr.String()) + } + return fmt.Sprintf("%s(%s)", t.underlyingFunc.FunctionName(), strings.Join(args, ", ")) +} + +func (t *TableFunctionWrapper) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 0) + } + return t, nil +} + +func (t *TableFunctionWrapper) WithDatabase(database sql.Database) (sql.Node, error) { + nt := *t + nt.database = database + return &nt, nil +} + +func (t *TableFunctionWrapper) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if t.funcExpr == nil { + if len(exprs) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(exprs), 0) + } + } + l := len(t.funcExpr.Children()) + if len(exprs) != l { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(exprs), l) + } + nt := *t + nf, err := nt.funcExpr.WithChildren(exprs...) + if err != nil { + return nil, err + } + nt.funcExpr = nf + return &nt, nil +} diff --git a/sql/planbuilder/from.go b/sql/planbuilder/from.go index 51283365b3..c7070fe69f 100644 --- a/sql/planbuilder/from.go +++ b/sql/planbuilder/from.go @@ -22,6 +22,7 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" + dtablefunctions "github.com/dolthub/go-mysql-server/sql/expression/tablefunction" "github.com/dolthub/go-mysql-server/sql/mysql_db" "github.com/dolthub/go-mysql-server/sql/plan" "github.com/dolthub/go-mysql-server/sql/transform" @@ -447,20 +448,11 @@ func (b *Builder) resolveTable(tab, db string, asOf interface{}) *plan.ResolvedT func (b *Builder) buildTableFunc(inScope *scope, t *ast.TableFuncExpr) (outScope *scope) { //TODO what are valid mysql table arguments args := make([]sql.Expression, 0, len(t.Exprs)) - for _, e := range t.Exprs { - switch e := e.(type) { + for _, expr := range t.Exprs { + switch e := expr.(type) { case *ast.AliasedExpr: - expr := b.buildScalar(inScope, e.Expr) - - if !e.As.IsEmpty() { - b.handleErr(sql.ErrUnsupportedSyntax.New(ast.String(e))) - } - - if selectExprNeedsAlias(e, expr) { - b.handleErr(sql.ErrUnsupportedSyntax.New(ast.String(e))) - } - - args = append(args, expr) + scalarExpr := b.buildScalar(inScope, e.Expr) + args = append(args, scalarExpr) default: b.handleErr(sql.ErrUnsupportedSyntax.New(ast.String(e))) } @@ -468,9 +460,14 @@ func (b *Builder) buildTableFunc(inScope *scope, t *ast.TableFuncExpr) (outScope utf := expression.NewUnresolvedTableFunction(t.Name, args) - tableFunction, err := b.cat.TableFunction(b.ctx, utf.Name()) - if err != nil { - b.handleErr(err) + tableFunction, found := b.cat.TableFunction(b.ctx, utf.Name()) + if !found { + // try getting regular function + f, funcFound := b.cat.Function(b.ctx, utf.Name()) + if !funcFound { + b.handleErr(sql.ErrTableFunctionNotFound.New(utf.Name())) + } + tableFunction = dtablefunctions.NewTableFunctionWrapper(f) } database := b.currentDb() diff --git a/test/test_catalog.go b/test/test_catalog.go index 1f94f439f1..9500271e04 100644 --- a/test/test_catalog.go +++ b/test/test_catalog.go @@ -159,7 +159,7 @@ func (c *Catalog) UnlockTables(ctx *sql.Context, id uint32) error { return nil } -func (c *Catalog) TableFunction(ctx *sql.Context, name string) (sql.TableFunction, error) { +func (c *Catalog) TableFunction(ctx *sql.Context, name string) (sql.TableFunction, bool) { //TODO implement me panic("implement me") }