From 7b36bbc03e329c7aa7a1e2dec11202009d804323 Mon Sep 17 00:00:00 2001 From: Dan Hansen Date: Thu, 11 Apr 2024 10:17:43 -0700 Subject: [PATCH] Support `CONTAINS_SUBSTR` function (#45) * Support `CONTAINS_SUBSTR` function * lint --- driver.go | 5 +- internal/catalog.go | 21 ++++- internal/function_bind.go | 40 ++++++++- internal/function_string.go | 23 ++++- query_test.go | 169 ++++++++++++++++++------------------ 5 files changed, 167 insertions(+), 91 deletions(-) diff --git a/driver.go b/driver.go index 6a9dd33..eaf2ef0 100644 --- a/driver.go +++ b/driver.go @@ -48,7 +48,10 @@ func newDBAndCatalog(name string) (*sql.DB, *internal.Catalog, error) { if err != nil { return nil, nil, fmt.Errorf("failed to open database by %s: %w", name, err) } - catalog := internal.NewCatalog(db) + catalog, err := internal.NewCatalog(db) + if err != nil { + return nil, nil, fmt.Errorf("failed open database by %s: failed to initialize catalog: %w", name, err) + } nameToDBMap[name] = db nameToCatalogMap[name] = catalog return db, catalog, nil diff --git a/internal/catalog.go b/internal/catalog.go index c626dd1..752811b 100644 --- a/internal/catalog.go +++ b/internal/catalog.go @@ -71,13 +71,30 @@ func newSimpleCatalog(name string) *types.SimpleCatalog { return catalog } -func NewCatalog(db *sql.DB) *Catalog { - return &Catalog{ +func NewCatalog(db *sql.DB) (*Catalog, error) { + catalog := &Catalog{ db: db, catalog: newSimpleCatalog(catalogName), tableMap: map[string]*TableSpec{}, funcMap: map[string]*FunctionSpec{}, } + + // Add missing CONTAINS_SUBSTR function to the catalog + // https://github.com/google/zetasql/issues/135#issuecomment-1490908494 + if err := catalog.addFunctionSpec(&FunctionSpec{ + IsTemp: false, + NamePath: []string{"contains_substr"}, + Language: "SQL", + Args: []*NameWithType{ + {Name: "expression", Type: &Type{Name: "expression", Kind: types.STRING, SignatureKind: types.ArgTypeArbitrary}}, + {Name: "search_value_literal", Type: &Type{Name: "search_value_literal", Kind: types.STRING, SignatureKind: types.ArgTypeFixed}}, + }, + Return: &Type{Name: "BOOL", Kind: types.BOOL, SignatureKind: types.ArgTypeFixed}, + }); err != nil { + return nil, err + } + + return catalog, nil } func (c *Catalog) FullName() string { diff --git a/internal/function_bind.go b/internal/function_bind.go index 50da8c2..d1c4492 100644 --- a/internal/function_bind.go +++ b/internal/function_bind.go @@ -916,11 +916,49 @@ func bindContainsSubstr(args ...Value) (Value, error) { if existsNull(args) { return nil, nil } + + // Perform field-by-field match on structs; returning in order: + // true if one field contains substr + // null if at least one field is null + // otherwise false + if structValue, ok := args[0].(*StructValue); ok { + nullExists := false + for _, value := range structValue.values { + if value == nil { + nullExists = true + continue + } + var err error + result, err := bindContainsSubstr(value, args[1]) + + if err != nil { + return nil, err + } + contained, err := result.EQ(BoolValue(true)) + if err != nil { + return nil, err + } + if contained { + return BoolValue(true), nil + } + } + + if nullExists { + return nil, nil + } + + return BoolValue(false), nil + } + + value, err := args[0].ToString() + if err != nil { + return nil, err + } search, err := args[1].ToString() if err != nil { return nil, err } - return CONTAINS_SUBSTR(args[0], search) + return CONTAINS_SUBSTR(value, search) } func bindEndsWith(args ...Value) (Value, error) { diff --git a/internal/function_string.go b/internal/function_string.go index 6e48cc1..92b5abb 100644 --- a/internal/function_string.go +++ b/internal/function_string.go @@ -107,8 +107,27 @@ func CONCAT(args ...Value) (Value, error) { return nil, fmt.Errorf("CONCAT: argument type must be STRING or BYTES") } -func CONTAINS_SUBSTR(exprValue Value, search string) (Value, error) { - return nil, nil +func CONTAINS_SUBSTR(value string, search string) (Value, error) { + normalizedExprValue, err := NORMALIZE_AND_CASEFOLD(value, "NFKC") + if err != nil { + return nil, fmt.Errorf("CONTAINS_SUBSTR: could not normalize and casefold value: %w", err) + } + + normalizedValue, err := normalizedExprValue.ToString() + if err != nil { + return nil, fmt.Errorf("CONTAINS_SUBSTR: could not convert expression to string: %w", err) + } + + normalizedSearchValue, err := NORMALIZE_AND_CASEFOLD(search, "NFKC") + if err != nil { + return nil, fmt.Errorf("CONTAINS_SUBSTR: could not normalize and casefold value: %w", err) + } + + normalizedSearch, err := normalizedSearchValue.ToString() + if err != nil { + return nil, fmt.Errorf("CONTAINS_SUBSTR: could not convert expression to string: %w", err) + } + return BoolValue(strings.Contains(normalizedValue, normalizedSearch)), nil } func ENDS_WITH(value, ends Value) (Value, error) { diff --git a/query_test.go b/query_test.go index d62b70d..caf564d 100644 --- a/query_test.go +++ b/query_test.go @@ -3314,91 +3314,90 @@ SELECT characters, CHARACTER_LENGTH(characters) FROM example`, query: `SELECT CONCAT('T.P.', ' ', 'Bar'), CONCAT('Summer', ' ', 1923), CONCAT("abc"), CONCAT(1), CONCAT('A', NULL, 'C'), CONCAT(NULL)`, expectedRows: [][]interface{}{{"T.P. Bar", "Summer 1923", "abc", "1", nil, nil}}, }, - // TODO: currently unsupported CONTAINS_SUBSTR function because ZetaSQL library doesn't support it. - // { - // name: "contains_substr true", - // query: `SELECT CONTAINS_SUBSTR('the blue house', 'Blue house')`, - // expectedRows: [][]interface{}{{true}}, - // }, - // { - // name: "contains_substr false", - // query: `SELECT CONTAINS_SUBSTR('the red house', 'blue')`, - // expectedRows: [][]interface{}{{false}}, - // }, - // { - // name: "contains_substr normalize", - // query: `SELECT '\u2168 day' AS a, 'IX' AS b, CONTAINS_SUBSTR('\u2168', 'IX')`, - // expectedRows: [][]interface{}{{"Ⅸ day", "IX", true}}, - // }, - // { - // name: "contains_substr struct_field", - // query: `SELECT CONTAINS_SUBSTR((23, 35, 41), '35')`, - // expectedRows: [][]interface{}{{true}}, - // }, - // { - // name: "contains_substr recursive", - // query: `SELECT CONTAINS_SUBSTR(('abc', ['def', 'ghi', 'jkl'], 'mno'), 'jk')`, - // expectedRows: [][]interface{}{{true}}, - // }, - // { - // name: "contains_substr struct with null", - // query: `SELECT CONTAINS_SUBSTR((23, NULL, 41), '41')`, - // expectedRows: [][]interface{}{{true}}, - // }, - // { - // name: "contains_substr struct with null2", - // query: `SELECT CONTAINS_SUBSTR((23, NULL, 41), '35')`, - // expectedRows: [][]interface{}{{nil}}, - // }, - // { - // name: "contains_substr nil", - // query: `SELECT CONTAINS_SUBSTR('hello', NULL)`, - // expectedErr: true, - // }, - // { - // name: "contains_substr for table all rows", - // query: ` - // WITH Recipes AS ( - // SELECT 'Blueberry pancakes' as Breakfast, 'Egg salad sandwich' as Lunch, 'Potato dumplings' as Dinner UNION ALL - // SELECT 'Potato pancakes', 'Toasted cheese sandwich', 'Beef stroganoff' UNION ALL - // SELECT 'Ham scramble', 'Steak avocado salad', 'Tomato pasta' UNION ALL - // SELECT 'Avocado toast', 'Tomato soup', 'Blueberry salmon' UNION ALL - // SELECT 'Corned beef hash', 'Lentil potato soup', 'Glazed ham' - // ) SELECT * FROM Recipes WHERE CONTAINS_SUBSTR(Recipes, 'toast')`, - // expectedRows: [][]interface{}{ - // {"Potato pancakes", "Toasted cheese sandwich", "Beef stroganoff"}, - // {"Avocado toast", "Tomato soup", "Blueberry samon"}, - // }, - // }, - // { - // name: "contains_substr for table specified rows", - // query: ` - // WITH Recipes AS ( - // SELECT 'Blueberry pancakes' as Breakfast, 'Egg salad sandwich' as Lunch, 'Potato dumplings' as Dinner UNION ALL - // SELECT 'Potato pancakes', 'Toasted cheese sandwich', 'Beef stroganoff' UNION ALL - // SELECT 'Ham scramble', 'Steak avocado salad', 'Tomato pasta' UNION ALL - // SELECT 'Avocado toast', 'Tomato soup', 'Blueberry salmon' UNION ALL - // SELECT 'Corned beef hash', 'Lentil potato soup', 'Glazed ham' - // ) SELECT * FROM Recipes WHERE CONTAINS_SUBSTR((Lunch, Dinner), 'potato')`, - // expectedRows: [][]interface{}{ - // {"Bluberry pancakes", "Egg salad sandwich", "Potato dumplings"}, - // {"Corned beef hash", "Lentil potato soup", "Glazed ham"}, - // }, - // }, - // { - // name: "contains_substr for table except", - // query: ` - // WITH Recipes AS ( - // SELECT 'Blueberry pancakes' as Breakfast, 'Egg salad sandwich' as Lunch, 'Potato dumplings' as Dinner UNION ALL - // SELECT 'Potato pancakes', 'Toasted cheese sandwich', 'Beef stroganoff' UNION ALL - // SELECT 'Ham scramble', 'Steak avocado salad', 'Tomato pasta' UNION ALL - // SELECT 'Avocado toast', 'Tomato soup', 'Blueberry salmon' UNION ALL - // SELECT 'Corned beef hash', 'Lentil potato soup', 'Glazed ham' - // ) SELECT * FROM Recipes WHERE CONTAINS_SUBSTR((SELECT AS STRUCT Recipes.* EXCEPT (Lunch, Dinner)), 'potato')`, - // expectedRows: [][]interface{}{ - // {"Potato pancakes", "Toasted cheese sandwich", "Beef stroganoff"}, - // }, - // }, + { + name: "contains_substr true", + query: `SELECT CONTAINS_SUBSTR('the blue house', 'Blue house')`, + expectedRows: [][]interface{}{{true}}, + }, + { + name: "contains_substr false", + query: `SELECT CONTAINS_SUBSTR('the red house', 'blue')`, + expectedRows: [][]interface{}{{false}}, + }, + { + name: "contains_substr normalize", + query: `SELECT '\u2168 day' AS a, 'IX' AS b, CONTAINS_SUBSTR('\u2168', 'IX')`, + expectedRows: [][]interface{}{{"Ⅸ day", "IX", true}}, + }, + { + name: "contains_substr struct_field", + query: `SELECT CONTAINS_SUBSTR((23, 35, 41), '35')`, + expectedRows: [][]interface{}{{true}}, + }, + { + name: "contains_substr recursive", + query: `SELECT CONTAINS_SUBSTR(('abc', ['def', 'ghi', 'jkl'], 'mno'), 'jk')`, + expectedRows: [][]interface{}{{true}}, + }, + { + name: "contains_substr struct with null", + query: `SELECT CONTAINS_SUBSTR((23, NULL, 41), '41')`, + expectedRows: [][]interface{}{{true}}, + }, + { + name: "contains_substr struct with null2", + query: `SELECT CONTAINS_SUBSTR((23, NULL, 41), '35')`, + expectedRows: [][]interface{}{{nil}}, + }, + { + name: "contains_substr nil", + query: `SELECT CONTAINS_SUBSTR('hello', NULL)`, + expectedErr: "CONTAINS_SUBSTR: search literal must be not null", + }, + { + name: "contains_substr for table all rows", + query: ` + WITH Recipes AS ( + SELECT 'Blueberry pancakes' as Breakfast, 'Egg salad sandwich' as Lunch, 'Potato dumplings' as Dinner UNION ALL + SELECT 'Potato pancakes', 'Toasted cheese sandwich', 'Beef stroganoff' UNION ALL + SELECT 'Ham scramble', 'Steak avocado salad', 'Tomato pasta' UNION ALL + SELECT 'Avocado toast', 'Tomato soup', 'Blueberry salmon' UNION ALL + SELECT 'Corned beef hash', 'Lentil potato soup', 'Glazed ham' + ) SELECT * FROM Recipes WHERE CONTAINS_SUBSTR(Recipes, 'toast')`, + expectedRows: [][]interface{}{ + {"Potato pancakes", "Toasted cheese sandwich", "Beef stroganoff"}, + {"Avocado toast", "Tomato soup", "Blueberry salmon"}, + }, + }, + { + name: "contains_substr for table specified rows", + query: ` + WITH Recipes AS ( + SELECT 'Blueberry pancakes' as Breakfast, 'Egg salad sandwich' as Lunch, 'Potato dumplings' as Dinner UNION ALL + SELECT 'Potato pancakes', 'Toasted cheese sandwich', 'Beef stroganoff' UNION ALL + SELECT 'Ham scramble', 'Steak avocado salad', 'Tomato pasta' UNION ALL + SELECT 'Avocado toast', 'Tomato soup', 'Blueberry salmon' UNION ALL + SELECT 'Corned beef hash', 'Lentil potato soup', 'Glazed ham' + ) SELECT * FROM Recipes WHERE CONTAINS_SUBSTR((Lunch, Dinner), 'potato')`, + expectedRows: [][]interface{}{ + {"Blueberry pancakes", "Egg salad sandwich", "Potato dumplings"}, + {"Corned beef hash", "Lentil potato soup", "Glazed ham"}, + }, + }, + { + name: "contains_substr for table except", + query: ` + WITH Recipes AS ( + SELECT 'Blueberry pancakes' as Breakfast, 'Egg salad sandwich' as Lunch, 'Potato dumplings' as Dinner UNION ALL + SELECT 'Potato pancakes', 'Toasted cheese sandwich', 'Beef stroganoff' UNION ALL + SELECT 'Ham scramble', 'Steak avocado salad', 'Tomato pasta' UNION ALL + SELECT 'Avocado toast', 'Tomato soup', 'Blueberry salmon' UNION ALL + SELECT 'Corned beef hash', 'Lentil potato soup', 'Glazed ham' + ) SELECT * FROM Recipes WHERE CONTAINS_SUBSTR((SELECT AS STRUCT Recipes.* EXCEPT (Lunch, Dinner)), 'potato')`, + expectedRows: [][]interface{}{ + {"Potato pancakes", "Toasted cheese sandwich", "Beef stroganoff"}, + }, + }, { name: "ends_with", query: `SELECT ENDS_WITH('apple', 'e'), ENDS_WITH('banana', 'e'), ENDS_WITH('orange', 'e'), ENDS_WITH('foo', NULL), ENDS_WITH(NULL, 'foo')`,