diff --git a/engine.go b/engine.go index a622f2b24..336be4420 100644 --- a/engine.go +++ b/engine.go @@ -120,7 +120,7 @@ func (e *Engine) Query( case *plan.CreateIndex: typ = sql.CreateIndexProcess perm = auth.ReadPerm | auth.WritePerm - case *plan.InsertInto, *plan.DeleteFrom, *plan.Update, *plan.DropIndex, *plan.UnlockTables, *plan.LockTables, *plan.CreateView: + case *plan.InsertInto, *plan.DeleteFrom, *plan.Update, *plan.DropIndex, *plan.UnlockTables, *plan.LockTables, *plan.CreateView, *plan.DropView: perm = auth.ReadPerm | auth.WritePerm } diff --git a/engine_test.go b/engine_test.go index 8846d0ebd..be229f91c 100644 --- a/engine_test.go +++ b/engine_test.go @@ -3149,6 +3149,7 @@ func TestReadOnly(t *testing.T) { `DROP INDEX foo ON mytable`, `INSERT INTO mytable (i, s) VALUES(42, 'yolo')`, `CREATE VIEW myview AS SELECT i FROM mytable`, + `DROP VIEW myview`, } for _, query := range writingQueries { diff --git a/sql/analyzer/assign_catalog.go b/sql/analyzer/assign_catalog.go index 104451454..1accc5ed8 100644 --- a/sql/analyzer/assign_catalog.go +++ b/sql/analyzer/assign_catalog.go @@ -64,6 +64,10 @@ func assignCatalog(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) nc := *node nc.Catalog = a.Catalog return &nc, nil + case *plan.DropView: + nc := *node + nc.Catalog = a.Catalog + return &nc, nil default: return n, nil } diff --git a/sql/analyzer/assign_catalog_test.go b/sql/analyzer/assign_catalog_test.go index a7fd9dff0..22c66059e 100644 --- a/sql/analyzer/assign_catalog_test.go +++ b/sql/analyzer/assign_catalog_test.go @@ -81,4 +81,10 @@ func TestAssignCatalog(t *testing.T) { cv, ok := node.(*plan.CreateView) require.True(ok) require.Equal(c, cv.Catalog) + + node, err = f.Apply(sql.NewEmptyContext(), a, plan.NewDropView(nil, false)) + require.NoError(err) + dv, ok := node.(*plan.DropView) + require.True(ok) + require.Equal(c, dv.Catalog) } diff --git a/sql/parse/parse.go b/sql/parse/parse.go index aea61f990..5c516dc76 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -37,6 +37,7 @@ var ( describeTablesRegex = regexp.MustCompile(`^(describe|desc)\s+table\s+(.*)`) createIndexRegex = regexp.MustCompile(`^create\s+index\s+`) createViewRegex = regexp.MustCompile(`^create\s+(or\s+replace\s+)?view\s+`) + dropViewRegex = regexp.MustCompile(`^drop\s+(if\s+exists\s+)?view\s+`) dropIndexRegex = regexp.MustCompile(`^drop\s+index\s+`) showIndexRegex = regexp.MustCompile(`^show\s+(index|indexes|keys)\s+(from|in)\s+\S+\s*`) showCreateRegex = regexp.MustCompile(`^show create\s+\S+\s*`) @@ -84,6 +85,8 @@ func Parse(ctx *sql.Context, query string) (sql.Node, error) { return parseCreateIndex(ctx, s) case createViewRegex.MatchString(lowerQuery): return parseCreateView(ctx, s) + case dropViewRegex.MatchString(lowerQuery): + return parseDropView(ctx, s) case dropIndexRegex.MatchString(lowerQuery): return parseDropIndex(s) case showIndexRegex.MatchString(lowerQuery): diff --git a/sql/parse/util.go b/sql/parse/util.go index b438ad349..b14375a2a 100644 --- a/sql/parse/util.go +++ b/sql/parse/util.go @@ -516,3 +516,65 @@ func maybeList(opening, separator, closing rune, list *[]string) parseFunc { } } } + +// A qualifiedName represents an identifier of type "db_name.table_name" +type qualifiedName struct { + qualifier string + name string +} + +// readQualifiedIdentifierList reads a comma-separated list of qualifiedNames. +// Any number of spaces between the qualified names are accepted. The qualifier +// may be empty, in which case the period is optional. +// An example of a correctly formed list is: +// "my_db.myview, db_2.mytable , aTable" +func readQualifiedIdentifierList(list *[]qualifiedName) parseFunc { + return func(rd *bufio.Reader) error { + for { + var newItem []string + err := parseFuncs{ + skipSpaces, + readIdentList('.', &newItem), + skipSpaces, + }.exec(rd) + + if err != nil { + return err + } + + if len(newItem) < 1 || len(newItem) > 2 { + return errUnexpectedSyntax.New( + "[qualifier.]name", + strings.Join(newItem, "."), + ) + } + + var qualifier, name string + + if len(newItem) == 1 { + qualifier = "" + name = newItem[0] + } else { + qualifier = newItem[0] + name = newItem[1] + } + + *list = append(*list, qualifiedName{qualifier, name}) + + r, _, err := rd.ReadRune() + if err != nil { + if err == io.EOF { + return nil + } + return err + } + + switch r { + case ',': + continue + default: + return rd.UnreadRune() + } + } + } +} diff --git a/sql/parse/util_test.go b/sql/parse/util_test.go index b1cd07fe5..3cb5f49f8 100644 --- a/sql/parse/util_test.go +++ b/sql/parse/util_test.go @@ -465,3 +465,60 @@ func TestReadSpaces(t *testing.T) { require.Equal(fixture.expectedRemaining, actualRemaining) } } + +// Tests that readQualifiedIdentifierList correctly parses well-formed lists, +// populating the list of identifiers, and that it errors with partial lists +// and when it does not found any identifiers +func TestReadQualifiedIdentifierList(t *testing.T) { + require := require.New(t) + + testFixtures := []struct { + string string + expectedList []qualifiedName + expectedError bool + expectedRemaining string + }{ + { + "my_db.myview, db_2.mytable , aTable", + []qualifiedName{{"my_db", "myview"}, {"db_2", "mytable"}, {"", "aTable"}}, + false, + "", + }, + { + "single_identifier -remaining", + []qualifiedName{{"", "single_identifier"}}, + false, + "-remaining", + }, + { + "", + nil, + true, + "", + }, + { + "partial_list,", + []qualifiedName{{"", "partial_list"}}, + true, + "", + }, + } + + for _, fixture := range testFixtures { + reader := bufio.NewReader(strings.NewReader(fixture.string)) + var actualList []qualifiedName + + err := readQualifiedIdentifierList(&actualList)(reader) + + if fixture.expectedError { + require.Error(err) + } else { + require.NoError(err) + } + + require.Equal(fixture.expectedList, actualList) + + actualRemaining, _ := reader.ReadString('\n') + require.Equal(fixture.expectedRemaining, actualRemaining) + } +} diff --git a/sql/parse/views.go b/sql/parse/views.go index 11b7bbf2f..d543952bf 100644 --- a/sql/parse/views.go +++ b/sql/parse/views.go @@ -13,6 +13,7 @@ import ( var ErrMalformedViewName = errors.NewKind("the view name '%s' is not correct") var ErrMalformedCreateView = errors.NewKind("view definition %#v is not a SELECT query") +var ErrViewsToDropNotFound = errors.NewKind("the list of views to drop must contain at least one view") // parseCreateView parses // CREATE [OR REPLACE] VIEW [db_name.]view_name AS select_statement @@ -87,3 +88,48 @@ func parseCreateView(ctx *sql.Context, s string) (sql.Node, error) { sql.UnresolvedDatabase(databaseName), viewName, columns, subqueryAlias, isReplace, ), nil } + +// parseDropView parses +// DROP VIEW [IF EXISTS] [db_name1.]view_name1 [, [db_name2.]view_name2, ...] +// [RESTRICT] [CASCADE] +// and returns a DropView node in case of success. As per MySQL specification, +// RESTRICT and CASCADE, if given, are parsed and ignored. +func parseDropView(ctx *sql.Context, s string) (sql.Node, error) { + r := bufio.NewReader(strings.NewReader(s)) + + var ( + views []qualifiedName + ifExists bool + unusedBool bool + ) + + err := parseFuncs{ + expect("drop"), + skipSpaces, + expect("view"), + skipSpaces, + multiMaybe(&ifExists, "if", "exists"), + skipSpaces, + readQualifiedIdentifierList(&views), + skipSpaces, + maybe(&unusedBool, "restrict"), + skipSpaces, + maybe(&unusedBool, "cascade"), + checkEOF, + }.exec(r) + + if err != nil { + return nil, err + } + + if len(views) < 1 { + return nil, ErrViewsToDropNotFound.New() + } + + plans := make([]sql.Node, len(views)) + for i, view := range views { + plans[i] = plan.NewSingleDropView(sql.UnresolvedDatabase(view.qualifier), view.name) + } + + return plan.NewDropView(plans, ifExists), nil +} diff --git a/sql/parse/views_test.go b/sql/parse/views_test.go index 16b0e38b1..c0a5b53f6 100644 --- a/sql/parse/views_test.go +++ b/sql/parse/views_test.go @@ -75,3 +75,77 @@ func TestParseCreateView(t *testing.T) { }) } } + +func TestParseDropView(t *testing.T) { + var fixtures = map[string]sql.Node{ + `DROP VIEW view1`: plan.NewDropView( + []sql.Node{plan.NewSingleDropView(sql.UnresolvedDatabase(""), "view1")}, + false, + ), + `DROP VIEW view1, view2`: plan.NewDropView( + []sql.Node{ + plan.NewSingleDropView(sql.UnresolvedDatabase(""), "view1"), + plan.NewSingleDropView(sql.UnresolvedDatabase(""), "view2"), + }, + false, + ), + `DROP VIEW db1.view1`: plan.NewDropView( + []sql.Node{ + plan.NewSingleDropView(sql.UnresolvedDatabase("db1"), "view1"), + }, + false, + ), + `DROP VIEW db1.view1, view2`: plan.NewDropView( + []sql.Node{ + plan.NewSingleDropView(sql.UnresolvedDatabase("db1"), "view1"), + plan.NewSingleDropView(sql.UnresolvedDatabase(""), "view2"), + }, + false, + ), + `DROP VIEW view1, db2.view2`: plan.NewDropView( + []sql.Node{ + plan.NewSingleDropView(sql.UnresolvedDatabase(""), "view1"), + plan.NewSingleDropView(sql.UnresolvedDatabase("db2"), "view2"), + }, + false, + ), + `DROP VIEW db1.view1, db2.view2`: plan.NewDropView( + []sql.Node{ + plan.NewSingleDropView(sql.UnresolvedDatabase("db1"), "view1"), + plan.NewSingleDropView(sql.UnresolvedDatabase("db2"), "view2"), + }, + false, + ), + `DROP VIEW IF EXISTS myview`: plan.NewDropView( + []sql.Node{plan.NewSingleDropView(sql.UnresolvedDatabase(""), "myview")}, + true, + ), + `DROP VIEW IF EXISTS db1.view1, db2.view2`: plan.NewDropView( + []sql.Node{ + plan.NewSingleDropView(sql.UnresolvedDatabase("db1"), "view1"), + plan.NewSingleDropView(sql.UnresolvedDatabase("db2"), "view2"), + }, + true, + ), + `DROP VIEW IF EXISTS db1.view1, db2.view2 RESTRICT CASCADE`: plan.NewDropView( + []sql.Node{ + plan.NewSingleDropView(sql.UnresolvedDatabase("db1"), "view1"), + plan.NewSingleDropView(sql.UnresolvedDatabase("db2"), "view2"), + }, + true, + ), + } + + for query, expectedPlan := range fixtures { + t.Run(query, func(t *testing.T) { + require := require.New(t) + + ctx := sql.NewEmptyContext() + lowerquery := strings.ToLower(query) + result, err := parseDropView(ctx, lowerquery) + + require.NoError(err) + require.Equal(expectedPlan, result) + }) + } +} diff --git a/sql/plan/drop_view.go b/sql/plan/drop_view.go new file mode 100644 index 000000000..718d2bfcd --- /dev/null +++ b/sql/plan/drop_view.go @@ -0,0 +1,148 @@ +package plan + +import ( + "github.com/src-d/go-mysql-server/sql" + errors "gopkg.in/src-d/go-errors.v1" +) + +var errDropViewChild = errors.NewKind("any child of DropView must be of type SingleDropView") + +type SingleDropView struct { + database sql.Database + viewName string +} + +// NewSingleDropView creates a SingleDropView. +func NewSingleDropView( + database sql.Database, + viewName string, +) *SingleDropView { + return &SingleDropView{database, viewName} +} + +// Children implements the Node interface. It always returns nil. +func (dv *SingleDropView) Children() []sql.Node { + return nil +} + +// Resolved implements the Node interface. This node is resolved if and only if +// its database is resolved. +func (dv *SingleDropView) Resolved() bool { + _, ok := dv.database.(sql.UnresolvedDatabase) + return !ok +} + +// RowIter implements the Node interface. It always returns an empty iterator. +func (dv *SingleDropView) RowIter(ctx *sql.Context) (sql.RowIter, error) { + return sql.RowsToRowIter(), nil +} + +// Schema implements the Node interface. It always returns nil. +func (dv *SingleDropView) Schema() sql.Schema { return nil } + +// String implements the fmt.Stringer interface, using sql.TreePrinter to +// generate the string. +func (dv *SingleDropView) String() string { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("SingleDropView(%s.%s)", dv.database.Name(), dv.viewName) + + return pr.String() +} + +// WithChildren implements the Node interface. It only succeeds if the length +// of the specified children equals 0. +func (dv *SingleDropView) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(dv, len(children), 0) + } + + return dv, nil +} + +// Database implements the Databaser interfacee. It returns the node's database. +func (dv *SingleDropView) Database() sql.Database { + return dv.database +} + +// Database implements the Databaser interface, and it returns a copy of this +// node with the specified database. +func (dv *SingleDropView) WithDatabase(database sql.Database) (sql.Node, error) { + newDrop := *dv + newDrop.database = database + return &newDrop, nil +} + +// DropView is a node representing the removal of a list of views, defined by +// the children member. The flag ifExists represents whether the user wants the +// node to fail if any of the views in children does not exist. +type DropView struct { + children []sql.Node + Catalog *sql.Catalog + ifExists bool +} + +// NewDropView creates a DropView node with the specified parameters, +// setting its catalog to nil. +func NewDropView(children []sql.Node, ifExists bool) *DropView { + return &DropView{children, nil, ifExists} +} + +// Children implements the Node interface. It returns the children of the +// CreateView node; i.e., all the views that will be dropped. +func (dvs *DropView) Children() []sql.Node { + return dvs.children +} + +// Resolved implements the Node interface. This node is resolved if and only if +// all of its children are resolved. +func (dvs *DropView) Resolved() bool { + for _, child := range dvs.children { + if !child.Resolved() { + return false + } + } + return true +} + +// RowIter implements the Node interface. When executed, this function drops +// all the views defined by the node's children. It errors if the flag ifExists +// is set to false and there is some view that does not exist. +func (dvs *DropView) RowIter(ctx *sql.Context) (sql.RowIter, error) { + viewList := make([]sql.ViewKey, len(dvs.children)) + for i, child := range dvs.children { + drop, ok := child.(*SingleDropView) + if !ok { + return sql.RowsToRowIter(), errDropViewChild.New() + } + + viewList[i] = sql.NewViewKey(drop.database.Name(), drop.viewName) + } + + return sql.RowsToRowIter(), dvs.Catalog.ViewRegistry.DeleteList(viewList, !dvs.ifExists) +} + +// Schema implements the Node interface. It always returns nil. +func (dvs *DropView) Schema() sql.Schema { return nil } + +// String implements the fmt.Stringer interface, using sql.TreePrinter to +// generate the string. +func (dvs *DropView) String() string { + childrenStrings := make([]string, len(dvs.children)) + for i, child := range dvs.children { + childrenStrings[i] = child.String() + } + + pr := sql.NewTreePrinter() + _ = pr.WriteNode("DropView") + _ = pr.WriteChildren(childrenStrings...) + + return pr.String() +} + +// WithChildren implements the Node interface. It always suceeds, returning a +// copy of this node with the new array of nodes as children. +func (dvs *DropView) WithChildren(children ...sql.Node) (sql.Node, error) { + newDrop := dvs + newDrop.children = children + return newDrop, nil +} diff --git a/sql/plan/drop_view_test.go b/sql/plan/drop_view_test.go new file mode 100644 index 000000000..0f9287fcc --- /dev/null +++ b/sql/plan/drop_view_test.go @@ -0,0 +1,94 @@ +package plan + +import ( + "testing" + + "github.com/src-d/go-mysql-server/memory" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + + "github.com/stretchr/testify/require" +) + +// Generates a database with a single table called mytable and a catalog with +// the view that is also returned. The context returned is the one used to +// create the view. +func mockData(require *require.Assertions) (sql.Database, *sql.Catalog, *sql.Context, sql.View) { + table := memory.NewTable("mytable", sql.Schema{ + {Name: "i", Source: "mytable", Type: sql.Int32}, + {Name: "s", Source: "mytable", Type: sql.Text}, + }) + + db := memory.NewDatabase("db") + db.AddTable("db", table) + + catalog := sql.NewCatalog() + catalog.AddDatabase(db) + + subqueryAlias := NewSubqueryAlias("myview", + NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(1, sql.Int32, table.Name(), "i", true), + }, + NewUnresolvedTable("dual", ""), + ), + ) + + createView := NewCreateView(db, subqueryAlias.Name(), nil, subqueryAlias, false) + createView.Catalog = catalog + + ctx := sql.NewEmptyContext() + + _, err := createView.RowIter(ctx) + require.NoError(err) + + return db, catalog, ctx, createView.View() +} + +// Tests that DropView works as expected and that the view is dropped in +// the catalog when RowIter is called, regardless of the value of ifExists +func TestDropExistingView(t *testing.T) { + require := require.New(t) + + test := func(ifExists bool) { + db, catalog, ctx, view := mockData(require) + + singleDropView := NewSingleDropView(db, view.Name()) + dropView := NewDropView([]sql.Node{singleDropView}, ifExists) + dropView.Catalog = catalog + + _, err := dropView.RowIter(ctx) + require.NoError(err) + + require.False(catalog.ViewRegistry.Exists(db.Name(), view.Name())) + } + + test(false) + test(true) +} + +// Tests that DropView errors when trying to delete a non-existing view if and +// only if the flag ifExists is set to false +func TestDropNonExistingView(t *testing.T) { + require := require.New(t) + + test := func(ifExists bool) error { + db, catalog, ctx, view := mockData(require) + + singleDropView := NewSingleDropView(db, "non-existing-view") + dropView := NewDropView([]sql.Node{singleDropView}, ifExists) + dropView.Catalog = catalog + + _, err := dropView.RowIter(ctx) + + require.True(catalog.ViewRegistry.Exists(db.Name(), view.Name())) + + return err + } + + err := test(true) + require.NoError(err) + + err = test(false) + require.Error(err) +} diff --git a/sql/viewregistry.go b/sql/viewregistry.go index bf62f1397..d7ebb577c 100644 --- a/sql/viewregistry.go +++ b/sql/viewregistry.go @@ -35,26 +35,26 @@ func (v *View) Definition() Node { // Views are scoped by the databases in which they were defined, so a key in // the view registry is a pair of names: database and view. -type viewKey struct { +type ViewKey struct { dbName, viewName string } -// newViewKey creates a viewKey ensuring both names are lowercase. -func newViewKey(databaseName, viewName string) viewKey { - return viewKey{strings.ToLower(databaseName), strings.ToLower(viewName)} +// NewViewKey creates a ViewKey ensuring both names are lowercase. +func NewViewKey(databaseName, viewName string) ViewKey { + return ViewKey{strings.ToLower(databaseName), strings.ToLower(viewName)} } -// ViewRegistry is a map of viewKey to View whose access is protected by a +// ViewRegistry is a map of ViewKey to View whose access is protected by a // RWMutex. type ViewRegistry struct { mutex sync.RWMutex - views map[viewKey]View + views map[ViewKey]View } // NewViewRegistry creates an empty ViewRegistry. func NewViewRegistry() *ViewRegistry { return &ViewRegistry{ - views: make(map[viewKey]View), + views: make(map[ViewKey]View), } } @@ -64,7 +64,7 @@ func (r *ViewRegistry) Register(database string, view View) error { r.mutex.Lock() defer r.mutex.Unlock() - key := newViewKey(database, view.Name()) + key := NewViewKey(database, view.Name()) if _, ok := r.views[key]; ok { return ErrExistingView.New(database, view.Name()) @@ -80,7 +80,7 @@ func (r *ViewRegistry) Delete(databaseName, viewName string) error { r.mutex.Lock() defer r.mutex.Unlock() - key := newViewKey(databaseName, viewName) + key := NewViewKey(databaseName, viewName) if _, ok := r.views[key]; !ok { return ErrNonExistingView.New(databaseName, viewName) @@ -90,13 +90,36 @@ func (r *ViewRegistry) Delete(databaseName, viewName string) error { return nil } +// DeleteList tries to delete a list of view keys. +// If the list contains views that do exist and views that do not, the existing +// views are deleted if and only if the errIfNotExists flag is set to false; if +// it is set to true, no views are deleted and an error is returned. +func (r *ViewRegistry) DeleteList(keys []ViewKey, errIfNotExists bool) error { + r.mutex.Lock() + defer r.mutex.Unlock() + + if errIfNotExists { + for _, key := range keys { + if !r.exists(key.dbName, key.viewName) { + return ErrNonExistingView.New(key.dbName, key.viewName) + } + } + } + + for _, key := range keys { + delete(r.views, key) + } + + return nil +} + // View returns a pointer to the view specified by the pair {databaseName, // viewName}, returning an error if it does not exist. func (r *ViewRegistry) View(databaseName, viewName string) (*View, error) { r.mutex.RLock() defer r.mutex.RUnlock() - key := newViewKey(databaseName, viewName) + key := NewViewKey(databaseName, viewName) if view, ok := r.views[key]; ok { return &view, nil @@ -106,7 +129,7 @@ func (r *ViewRegistry) View(databaseName, viewName string) (*View, error) { } // AllViews returns the map of all views in the registry. -func (r *ViewRegistry) AllViews() map[viewKey]View { +func (r *ViewRegistry) AllViews() map[ViewKey]View { r.mutex.RLock() defer r.mutex.RUnlock() @@ -127,3 +150,18 @@ func (r *ViewRegistry) ViewsInDatabase(databaseName string) (views []View) { return views } + +func (r *ViewRegistry) exists(databaseName, viewName string) bool { + key := NewViewKey(databaseName, viewName) + _, ok := r.views[key] + + return ok +} + +// Exists returns whether the specified key is already registered +func (r *ViewRegistry) Exists(databaseName, viewName string) bool { + r.mutex.RLock() + defer r.mutex.RUnlock() + + return r.exists(databaseName, viewName) +} diff --git a/sql/viewregistry_test.go b/sql/viewregistry_test.go index f33a88c60..adf1ef6c0 100644 --- a/sql/viewregistry_test.go +++ b/sql/viewregistry_test.go @@ -12,6 +12,16 @@ var ( mockView = NewView(viewName, nil) ) +func newMockRegistry(require *require.Assertions) *ViewRegistry { + registry := NewViewRegistry() + + err := registry.Register(dbName, mockView) + require.NoError(err) + require.Equal(1, len(registry.AllViews())) + + return registry +} + // Tests the creation of an empty ViewRegistry with no views registered. func TestNewViewRegistry(t *testing.T) { require := require.New(t) @@ -24,11 +34,7 @@ func TestNewViewRegistry(t *testing.T) { func TestRegisterNonExistingView(t *testing.T) { require := require.New(t) - registry := NewViewRegistry() - - err := registry.Register(dbName, mockView) - require.NoError(err) - require.Equal(1, len(registry.AllViews())) + registry := newMockRegistry(require) actualView, err := registry.View(dbName, viewName) require.NoError(err) @@ -39,13 +45,9 @@ func TestRegisterNonExistingView(t *testing.T) { func TestRegisterExistingVIew(t *testing.T) { require := require.New(t) - registry := NewViewRegistry() + registry := newMockRegistry(require) err := registry.Register(dbName, mockView) - require.NoError(err) - require.Equal(1, len(registry.AllViews())) - - err = registry.Register(dbName, mockView) require.Error(err) require.True(ErrExistingView.Is(err)) } @@ -54,13 +56,9 @@ func TestRegisterExistingVIew(t *testing.T) { func TestDeleteExistingView(t *testing.T) { require := require.New(t) - registry := NewViewRegistry() - - err := registry.Register(dbName, mockView) - require.NoError(err) - require.Equal(1, len(registry.AllViews())) + registry := newMockRegistry(require) - err = registry.Delete(dbName, viewName) + err := registry.Delete(dbName, viewName) require.NoError(err) require.Equal(0, len(registry.AllViews())) } @@ -81,11 +79,7 @@ func TestDeleteNonExistingView(t *testing.T) { func TestGetExistingView(t *testing.T) { require := require.New(t) - registry := NewViewRegistry() - - err := registry.Register(dbName, mockView) - require.NoError(err) - require.Equal(1, len(registry.AllViews())) + registry := newMockRegistry(require) actualView, err := registry.View(dbName, viewName) require.NoError(err) @@ -131,3 +125,100 @@ func TestViewsInDatabase(t *testing.T) { require.Equal(db.numViews, len(views)) } } + +var viewKeys = []ViewKey{ + { + "db1", + "view1", + }, + { + "db1", + "view2", + }, + { + "db2", + "view1", + }, +} + +func registerKeys(registry *ViewRegistry, require *require.Assertions) { + for _, key := range viewKeys { + err := registry.Register(key.dbName, NewView(key.viewName, nil)) + require.NoError(err) + } + require.Equal(len(viewKeys), len(registry.AllViews())) +} + +func TestDeleteExistingList(t *testing.T) { + require := require.New(t) + + test := func(errIfNotExists bool) { + registry := NewViewRegistry() + + registerKeys(registry, require) + err := registry.DeleteList(viewKeys, errIfNotExists) + require.NoError(err) + require.Equal(0, len(registry.AllViews())) + } + + test(true) + test(false) +} + +func TestDeleteNonExistingList(t *testing.T) { + require := require.New(t) + + test := func(errIfNotExists bool) { + registry := NewViewRegistry() + + registerKeys(registry, require) + err := registry.DeleteList([]ViewKey{{"random", "random"}}, errIfNotExists) + if errIfNotExists { + require.Error(err) + } else { + require.NoError(err) + } + require.Equal(len(viewKeys), len(registry.AllViews())) + } + + test(false) + test(true) +} + +func TestDeletePartiallyExistingList(t *testing.T) { + require := require.New(t) + + test := func(errIfNotExists bool) { + registry := NewViewRegistry() + + registerKeys(registry, require) + toDelete := append(viewKeys, ViewKey{"random", "random"}) + err := registry.DeleteList(toDelete, errIfNotExists) + if errIfNotExists { + require.Error(err) + require.Equal(len(viewKeys), len(registry.AllViews())) + } else { + require.NoError(err) + require.Equal(0, len(registry.AllViews())) + } + } + + test(false) + test(true) +} + +func TestExistsOnExistingView(t *testing.T) { + require := require.New(t) + + registry := newMockRegistry(require) + + require.True(registry.Exists(dbName, viewName)) +} + +func TestExistsOnNonExistingView(t *testing.T) { + require := require.New(t) + + registry := newMockRegistry(require) + + require.False(registry.Exists("non", "existing")) +}