From 019d92a6d43d100fcd26fd247a13bb8a6bd942aa Mon Sep 17 00:00:00 2001 From: Ariel Mashraki <7413593+a8m@users.noreply.github.com> Date: Thu, 10 Oct 2024 16:19:34 +0300 Subject: [PATCH] cmd/atlas/internal/migratelint: first (or baseline) migration should be loaded fast (#3185) --- cmd/atlas/internal/cmdapi/migrate_test.go | 2 +- cmd/atlas/internal/migratelint/lint.go | 55 ++++++++++++++++++++-- cmd/atlas/internal/migratelint/run.go | 28 ++++++++--- cmd/atlas/internal/migratelint/run_test.go | 3 +- sql/sqlcheck/naming/naming.go | 14 ++++++ 5 files changed, 90 insertions(+), 12 deletions(-) diff --git a/cmd/atlas/internal/cmdapi/migrate_test.go b/cmd/atlas/internal/cmdapi/migrate_test.go index 4c49dce4e91..f68e1f4c3b7 100644 --- a/cmd/atlas/internal/cmdapi/migrate_test.go +++ b/cmd/atlas/internal/cmdapi/migrate_test.go @@ -1665,7 +1665,7 @@ lint { require.Error(t, err) require.Regexp(t, `Analyzing changes from version 1.up to 2.up \(1 migration in total\): - Error: executing statement: near "BORING": syntax error + Error: executing statement: BORING: near "BORING": syntax error ------------------------- -- .+ diff --git a/cmd/atlas/internal/migratelint/lint.go b/cmd/atlas/internal/migratelint/lint.go index 54df932c5c8..9a241804325 100644 --- a/cmd/atlas/internal/migratelint/lint.go +++ b/cmd/atlas/internal/migratelint/lint.go @@ -238,7 +238,15 @@ func (d *DevLoader) LoadChanges(ctx context.Context, base, files []migrate.File) cks = append(cks, i) continue } - if current, err = d.next(ctx, diff.Files[i], current); err != nil { + // A common case is when importing a project to Atlas the baseline + // migration file might be very long. However, since we execute on + // a clean database, the per-statement analysis is not needed. + if len(base) == 0 && i == 0 { + current, err = d.first(ctx, diff.Files[i], current) + } else { + current, err = d.next(ctx, diff.Files[i], current) + } + if err != nil { return nil, err } } @@ -283,17 +291,58 @@ func (d *DevLoader) base(ctx context.Context, base []migrate.File) (*schema.Real return d.inspect(ctx) } +// first is a version of "next" but is used when linting the first migration file. In this case we do not +// need to analyze each statement, but the entire result of the file (much faster). For example, a baseline +// file or the first migration when running 'schema apply' might contain thousands of lines. +func (d *DevLoader) first(ctx context.Context, f *sqlcheck.File, start *schema.Realm) (current *schema.Realm, err error) { + stmts, err := d.stmts(ctx, f.File, true) + if err != nil { + return nil, err + } + // We define the max number of apply-inspect-diff cycles to 10, + // to limit our linting time for baseline/first migration files. + const maxStmtLoop = 10 + if len(stmts) <= maxStmtLoop { + return d.nextStmts(ctx, f, stmts, start) + } + for _, s := range stmts { + if _, err := d.Dev.ExecContext(ctx, s.Text); err != nil { + return nil, &FileError{File: f.Name(), Err: fmt.Errorf("executing statement: %s: %w", s.Text, err), Pos: s.Pos} + } + } + if current, err = d.inspect(ctx); err != nil { + return nil, err + } + changes, err := d.Dev.RealmDiff(start, current) + if err != nil { + return nil, err + } + f.Changes = append(f.Changes, &sqlcheck.Change{ + Changes: changes, + Stmt: &migrate.Stmt{ + Pos: 0, // Beginning of the file. + }, + }) + f.Sum = changes + return current, nil +} + // next returns the next state of the database after executing the statements in // the file. The changes detected by the statements are attached to the file. -func (d *DevLoader) next(ctx context.Context, f *sqlcheck.File, start *schema.Realm) (current *schema.Realm, err error) { +func (d *DevLoader) next(ctx context.Context, f *sqlcheck.File, start *schema.Realm) (*schema.Realm, error) { stmts, err := d.stmts(ctx, f.File, true) if err != nil { return nil, err } + return d.nextStmts(ctx, f, stmts, start) +} + +// nextStmts is a version of "next" but accepts the statements to execute. +func (d *DevLoader) nextStmts(ctx context.Context, f *sqlcheck.File, stmts []*migrate.Stmt, start *schema.Realm) (current *schema.Realm, err error) { current = start for _, s := range stmts { if _, err := d.Dev.ExecContext(ctx, s.Text); err != nil { - return nil, &FileError{File: f.Name(), Err: fmt.Errorf("executing statement: %w", err), Pos: s.Pos} + return nil, &FileError{File: f.Name(), Err: fmt.Errorf("executing statement: %s: %w", s.Text, err), Pos: s.Pos} } next, err := d.inspect(ctx) if err != nil { diff --git a/cmd/atlas/internal/migratelint/run.go b/cmd/atlas/internal/migratelint/run.go index e2139988e66..d8166b45eb5 100644 --- a/cmd/atlas/internal/migratelint/run.go +++ b/cmd/atlas/internal/migratelint/run.go @@ -162,11 +162,22 @@ func (r *Runner) analyze(ctx context.Context, files []*sqlcheck.File) error { continue } for _, az := range r.Analyzers { - err := az.Analyze(ctx, &sqlcheck.Pass{ - File: f, - Dev: r.Dev, - Reporter: nl.reporterFor(fr, az), - }) + err := func(az sqlcheck.Analyzer) (rerr error) { + defer func() { + if rc := recover(); rc != nil { + var name string + if n, ok := az.(sqlcheck.NamedAnalyzer); ok { + name = fmt.Sprintf(" (%s)", n.Name()) + } + rerr = fmt.Errorf("skip crashed analyzer %s: %v", name, rc) + } + }() + return az.Analyze(ctx, &sqlcheck.Pass{ + File: f, + Dev: r.Dev, + Reporter: nl.reporterFor(fr, az), + }) + }(az) // If the last report was skipped, // skip emitting its error. if err != nil && !nl.skipped { @@ -606,8 +617,11 @@ func nolintRules(f *sqlcheck.File) *skipRules { } } for _, c := range f.Changes { - for _, d := range c.Stmt.Directive("nolint") { - s.pos2rules[c.Stmt.Pos] = append(s.pos2rules[c.Stmt.Pos], strings.Split(d, " ")...) + // A list of changes that were loaded in a batch (no statements per change). + if c.Stmt != nil { + for _, d := range c.Stmt.Directive("nolint") { + s.pos2rules[c.Stmt.Pos] = append(s.pos2rules[c.Stmt.Pos], strings.Split(d, " ")...) + } } } return s diff --git a/cmd/atlas/internal/migratelint/run_test.go b/cmd/atlas/internal/migratelint/run_test.go index f78af05d89f..66dccd73d10 100644 --- a/cmd/atlas/internal/migratelint/run_test.go +++ b/cmd/atlas/internal/migratelint/run_test.go @@ -16,6 +16,7 @@ import ( "ariga.io/atlas/sql/sqlcheck" "ariga.io/atlas/sql/sqlclient" + "github.com/fatih/color" "github.com/stretchr/testify/require" ) @@ -24,7 +25,7 @@ func TestRunner_Run(t *testing.T) { b := &bytes.Buffer{} c, err := sqlclient.Open(ctx, "sqlite://run?mode=memory&cache=shared&_fk=1") require.NoError(t, err) - t.Setenv("NO_COLOR", "1") + color.NoColor = true t.Run("checksum mismatch", func(t *testing.T) { var ( diff --git a/sql/sqlcheck/naming/naming.go b/sql/sqlcheck/naming/naming.go index 161c1327a55..f3bf9004128 100644 --- a/sql/sqlcheck/naming/naming.go +++ b/sql/sqlcheck/naming/naming.go @@ -100,6 +100,20 @@ func (a *Analyzer) Analyze(_ context.Context, p *sqlcheck.Pass) error { diags = append(diags, a.match(sc, codeNameS, c.S.Name, "Schema", a.Schema)...) case *schema.AddTable: diags = append(diags, a.match(sc, codeNameT, c.T.Name, "Table", a.Table)...) + for _, c := range c.T.Columns { + diags = append(diags, a.match(sc, codeNameC, c.Name, "Column", a.Column)...) + } + for _, i := range c.T.Indexes { + diags = append(diags, a.match(sc, codeNameI, i.Name, "Index", a.Index)...) + } + for _, f := range c.T.ForeignKeys { + diags = append(diags, a.match(sc, codeNameF, f.Symbol, "Foreign-key constraint", a.ForeignKey)...) + } + for _, at := range c.T.Attrs { + if k, ok := at.(*schema.Check); ok { + diags = append(diags, a.match(sc, codeNameK, k.Name, "Check constraint", a.Check)...) + } + } case *schema.RenameTable: diags = append(diags, a.match(sc, codeNameT, c.To.Name, "Table", a.Table)...) case *schema.ModifyTable: