Skip to content

Commit

Permalink
Support CONTAINS_SUBSTR function (#45)
Browse files Browse the repository at this point in the history
* Support `CONTAINS_SUBSTR` function

* lint
  • Loading branch information
ohaibbq committed Jun 29, 2024
1 parent 2e98c2c commit 7b36bbc
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 91 deletions.
5 changes: 4 additions & 1 deletion driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ func newDBAndCatalog(name string) (*sql.DB, *internal.Catalog, error) {
if err != nil {
return nil, nil, fmt.Errorf("failed to open database by %s: %w", name, err)
}
catalog := internal.NewCatalog(db)
catalog, err := internal.NewCatalog(db)
if err != nil {
return nil, nil, fmt.Errorf("failed open database by %s: failed to initialize catalog: %w", name, err)
}
nameToDBMap[name] = db
nameToCatalogMap[name] = catalog
return db, catalog, nil
Expand Down
21 changes: 19 additions & 2 deletions internal/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,30 @@ func newSimpleCatalog(name string) *types.SimpleCatalog {
return catalog
}

func NewCatalog(db *sql.DB) *Catalog {
return &Catalog{
func NewCatalog(db *sql.DB) (*Catalog, error) {
catalog := &Catalog{
db: db,
catalog: newSimpleCatalog(catalogName),
tableMap: map[string]*TableSpec{},
funcMap: map[string]*FunctionSpec{},
}

// Add missing CONTAINS_SUBSTR function to the catalog
// https://github.com/google/zetasql/issues/135#issuecomment-1490908494
if err := catalog.addFunctionSpec(&FunctionSpec{
IsTemp: false,
NamePath: []string{"contains_substr"},
Language: "SQL",
Args: []*NameWithType{
{Name: "expression", Type: &Type{Name: "expression", Kind: types.STRING, SignatureKind: types.ArgTypeArbitrary}},
{Name: "search_value_literal", Type: &Type{Name: "search_value_literal", Kind: types.STRING, SignatureKind: types.ArgTypeFixed}},
},
Return: &Type{Name: "BOOL", Kind: types.BOOL, SignatureKind: types.ArgTypeFixed},
}); err != nil {
return nil, err
}

return catalog, nil
}

func (c *Catalog) FullName() string {
Expand Down
40 changes: 39 additions & 1 deletion internal/function_bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -916,11 +916,49 @@ func bindContainsSubstr(args ...Value) (Value, error) {
if existsNull(args) {
return nil, nil
}

// Perform field-by-field match on structs; returning in order:
// true if one field contains substr
// null if at least one field is null
// otherwise false
if structValue, ok := args[0].(*StructValue); ok {
nullExists := false
for _, value := range structValue.values {
if value == nil {
nullExists = true
continue
}
var err error
result, err := bindContainsSubstr(value, args[1])

if err != nil {
return nil, err
}
contained, err := result.EQ(BoolValue(true))
if err != nil {
return nil, err
}
if contained {
return BoolValue(true), nil
}
}

if nullExists {
return nil, nil
}

return BoolValue(false), nil
}

value, err := args[0].ToString()
if err != nil {
return nil, err
}
search, err := args[1].ToString()
if err != nil {
return nil, err
}
return CONTAINS_SUBSTR(args[0], search)
return CONTAINS_SUBSTR(value, search)
}

func bindEndsWith(args ...Value) (Value, error) {
Expand Down
23 changes: 21 additions & 2 deletions internal/function_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,27 @@ func CONCAT(args ...Value) (Value, error) {
return nil, fmt.Errorf("CONCAT: argument type must be STRING or BYTES")
}

func CONTAINS_SUBSTR(exprValue Value, search string) (Value, error) {
return nil, nil
func CONTAINS_SUBSTR(value string, search string) (Value, error) {
normalizedExprValue, err := NORMALIZE_AND_CASEFOLD(value, "NFKC")
if err != nil {
return nil, fmt.Errorf("CONTAINS_SUBSTR: could not normalize and casefold value: %w", err)
}

normalizedValue, err := normalizedExprValue.ToString()
if err != nil {
return nil, fmt.Errorf("CONTAINS_SUBSTR: could not convert expression to string: %w", err)
}

normalizedSearchValue, err := NORMALIZE_AND_CASEFOLD(search, "NFKC")
if err != nil {
return nil, fmt.Errorf("CONTAINS_SUBSTR: could not normalize and casefold value: %w", err)
}

normalizedSearch, err := normalizedSearchValue.ToString()
if err != nil {
return nil, fmt.Errorf("CONTAINS_SUBSTR: could not convert expression to string: %w", err)
}
return BoolValue(strings.Contains(normalizedValue, normalizedSearch)), nil
}

func ENDS_WITH(value, ends Value) (Value, error) {
Expand Down
169 changes: 84 additions & 85 deletions query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3314,91 +3314,90 @@ SELECT characters, CHARACTER_LENGTH(characters) FROM example`,
query: `SELECT CONCAT('T.P.', ' ', 'Bar'), CONCAT('Summer', ' ', 1923), CONCAT("abc"), CONCAT(1), CONCAT('A', NULL, 'C'), CONCAT(NULL)`,
expectedRows: [][]interface{}{{"T.P. Bar", "Summer 1923", "abc", "1", nil, nil}},
},
// TODO: currently unsupported CONTAINS_SUBSTR function because ZetaSQL library doesn't support it.
// {
// name: "contains_substr true",
// query: `SELECT CONTAINS_SUBSTR('the blue house', 'Blue house')`,
// expectedRows: [][]interface{}{{true}},
// },
// {
// name: "contains_substr false",
// query: `SELECT CONTAINS_SUBSTR('the red house', 'blue')`,
// expectedRows: [][]interface{}{{false}},
// },
// {
// name: "contains_substr normalize",
// query: `SELECT '\u2168 day' AS a, 'IX' AS b, CONTAINS_SUBSTR('\u2168', 'IX')`,
// expectedRows: [][]interface{}{{"Ⅸ day", "IX", true}},
// },
// {
// name: "contains_substr struct_field",
// query: `SELECT CONTAINS_SUBSTR((23, 35, 41), '35')`,
// expectedRows: [][]interface{}{{true}},
// },
// {
// name: "contains_substr recursive",
// query: `SELECT CONTAINS_SUBSTR(('abc', ['def', 'ghi', 'jkl'], 'mno'), 'jk')`,
// expectedRows: [][]interface{}{{true}},
// },
// {
// name: "contains_substr struct with null",
// query: `SELECT CONTAINS_SUBSTR((23, NULL, 41), '41')`,
// expectedRows: [][]interface{}{{true}},
// },
// {
// name: "contains_substr struct with null2",
// query: `SELECT CONTAINS_SUBSTR((23, NULL, 41), '35')`,
// expectedRows: [][]interface{}{{nil}},
// },
// {
// name: "contains_substr nil",
// query: `SELECT CONTAINS_SUBSTR('hello', NULL)`,
// expectedErr: true,
// },
// {
// name: "contains_substr for table all rows",
// query: `
// WITH Recipes AS (
// SELECT 'Blueberry pancakes' as Breakfast, 'Egg salad sandwich' as Lunch, 'Potato dumplings' as Dinner UNION ALL
// SELECT 'Potato pancakes', 'Toasted cheese sandwich', 'Beef stroganoff' UNION ALL
// SELECT 'Ham scramble', 'Steak avocado salad', 'Tomato pasta' UNION ALL
// SELECT 'Avocado toast', 'Tomato soup', 'Blueberry salmon' UNION ALL
// SELECT 'Corned beef hash', 'Lentil potato soup', 'Glazed ham'
// ) SELECT * FROM Recipes WHERE CONTAINS_SUBSTR(Recipes, 'toast')`,
// expectedRows: [][]interface{}{
// {"Potato pancakes", "Toasted cheese sandwich", "Beef stroganoff"},
// {"Avocado toast", "Tomato soup", "Blueberry samon"},
// },
// },
// {
// name: "contains_substr for table specified rows",
// query: `
// WITH Recipes AS (
// SELECT 'Blueberry pancakes' as Breakfast, 'Egg salad sandwich' as Lunch, 'Potato dumplings' as Dinner UNION ALL
// SELECT 'Potato pancakes', 'Toasted cheese sandwich', 'Beef stroganoff' UNION ALL
// SELECT 'Ham scramble', 'Steak avocado salad', 'Tomato pasta' UNION ALL
// SELECT 'Avocado toast', 'Tomato soup', 'Blueberry salmon' UNION ALL
// SELECT 'Corned beef hash', 'Lentil potato soup', 'Glazed ham'
// ) SELECT * FROM Recipes WHERE CONTAINS_SUBSTR((Lunch, Dinner), 'potato')`,
// expectedRows: [][]interface{}{
// {"Bluberry pancakes", "Egg salad sandwich", "Potato dumplings"},
// {"Corned beef hash", "Lentil potato soup", "Glazed ham"},
// },
// },
// {
// name: "contains_substr for table except",
// query: `
// WITH Recipes AS (
// SELECT 'Blueberry pancakes' as Breakfast, 'Egg salad sandwich' as Lunch, 'Potato dumplings' as Dinner UNION ALL
// SELECT 'Potato pancakes', 'Toasted cheese sandwich', 'Beef stroganoff' UNION ALL
// SELECT 'Ham scramble', 'Steak avocado salad', 'Tomato pasta' UNION ALL
// SELECT 'Avocado toast', 'Tomato soup', 'Blueberry salmon' UNION ALL
// SELECT 'Corned beef hash', 'Lentil potato soup', 'Glazed ham'
// ) SELECT * FROM Recipes WHERE CONTAINS_SUBSTR((SELECT AS STRUCT Recipes.* EXCEPT (Lunch, Dinner)), 'potato')`,
// expectedRows: [][]interface{}{
// {"Potato pancakes", "Toasted cheese sandwich", "Beef stroganoff"},
// },
// },
{
name: "contains_substr true",
query: `SELECT CONTAINS_SUBSTR('the blue house', 'Blue house')`,
expectedRows: [][]interface{}{{true}},
},
{
name: "contains_substr false",
query: `SELECT CONTAINS_SUBSTR('the red house', 'blue')`,
expectedRows: [][]interface{}{{false}},
},
{
name: "contains_substr normalize",
query: `SELECT '\u2168 day' AS a, 'IX' AS b, CONTAINS_SUBSTR('\u2168', 'IX')`,
expectedRows: [][]interface{}{{"Ⅸ day", "IX", true}},
},
{
name: "contains_substr struct_field",
query: `SELECT CONTAINS_SUBSTR((23, 35, 41), '35')`,
expectedRows: [][]interface{}{{true}},
},
{
name: "contains_substr recursive",
query: `SELECT CONTAINS_SUBSTR(('abc', ['def', 'ghi', 'jkl'], 'mno'), 'jk')`,
expectedRows: [][]interface{}{{true}},
},
{
name: "contains_substr struct with null",
query: `SELECT CONTAINS_SUBSTR((23, NULL, 41), '41')`,
expectedRows: [][]interface{}{{true}},
},
{
name: "contains_substr struct with null2",
query: `SELECT CONTAINS_SUBSTR((23, NULL, 41), '35')`,
expectedRows: [][]interface{}{{nil}},
},
{
name: "contains_substr nil",
query: `SELECT CONTAINS_SUBSTR('hello', NULL)`,
expectedErr: "CONTAINS_SUBSTR: search literal must be not null",
},
{
name: "contains_substr for table all rows",
query: `
WITH Recipes AS (
SELECT 'Blueberry pancakes' as Breakfast, 'Egg salad sandwich' as Lunch, 'Potato dumplings' as Dinner UNION ALL
SELECT 'Potato pancakes', 'Toasted cheese sandwich', 'Beef stroganoff' UNION ALL
SELECT 'Ham scramble', 'Steak avocado salad', 'Tomato pasta' UNION ALL
SELECT 'Avocado toast', 'Tomato soup', 'Blueberry salmon' UNION ALL
SELECT 'Corned beef hash', 'Lentil potato soup', 'Glazed ham'
) SELECT * FROM Recipes WHERE CONTAINS_SUBSTR(Recipes, 'toast')`,
expectedRows: [][]interface{}{
{"Potato pancakes", "Toasted cheese sandwich", "Beef stroganoff"},
{"Avocado toast", "Tomato soup", "Blueberry salmon"},
},
},
{
name: "contains_substr for table specified rows",
query: `
WITH Recipes AS (
SELECT 'Blueberry pancakes' as Breakfast, 'Egg salad sandwich' as Lunch, 'Potato dumplings' as Dinner UNION ALL
SELECT 'Potato pancakes', 'Toasted cheese sandwich', 'Beef stroganoff' UNION ALL
SELECT 'Ham scramble', 'Steak avocado salad', 'Tomato pasta' UNION ALL
SELECT 'Avocado toast', 'Tomato soup', 'Blueberry salmon' UNION ALL
SELECT 'Corned beef hash', 'Lentil potato soup', 'Glazed ham'
) SELECT * FROM Recipes WHERE CONTAINS_SUBSTR((Lunch, Dinner), 'potato')`,
expectedRows: [][]interface{}{
{"Blueberry pancakes", "Egg salad sandwich", "Potato dumplings"},
{"Corned beef hash", "Lentil potato soup", "Glazed ham"},
},
},
{
name: "contains_substr for table except",
query: `
WITH Recipes AS (
SELECT 'Blueberry pancakes' as Breakfast, 'Egg salad sandwich' as Lunch, 'Potato dumplings' as Dinner UNION ALL
SELECT 'Potato pancakes', 'Toasted cheese sandwich', 'Beef stroganoff' UNION ALL
SELECT 'Ham scramble', 'Steak avocado salad', 'Tomato pasta' UNION ALL
SELECT 'Avocado toast', 'Tomato soup', 'Blueberry salmon' UNION ALL
SELECT 'Corned beef hash', 'Lentil potato soup', 'Glazed ham'
) SELECT * FROM Recipes WHERE CONTAINS_SUBSTR((SELECT AS STRUCT Recipes.* EXCEPT (Lunch, Dinner)), 'potato')`,
expectedRows: [][]interface{}{
{"Potato pancakes", "Toasted cheese sandwich", "Beef stroganoff"},
},
},
{
name: "ends_with",
query: `SELECT ENDS_WITH('apple', 'e'), ENDS_WITH('banana', 'e'), ENDS_WITH('orange', 'e'), ENDS_WITH('foo', NULL), ENDS_WITH(NULL, 'foo')`,
Expand Down

0 comments on commit 7b36bbc

Please sign in to comment.