Skip to content

Commit

Permalink
cmd/atlas/internal/migratelint: first (or baseline) migration should …
Browse files Browse the repository at this point in the history
…be loaded fast (#3185)
  • Loading branch information
a8m authored Oct 10, 2024
1 parent 6591c06 commit 019d92a
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 12 deletions.
2 changes: 1 addition & 1 deletion cmd/atlas/internal/cmdapi/migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1665,7 +1665,7 @@ lint {
require.Error(t, err)
require.Regexp(t, `Analyzing changes from version 1.up to 2.up \(1 migration in total\):
Error: executing statement: near "BORING": syntax error
Error: executing statement: BORING: near "BORING": syntax error
-------------------------
-- .+
Expand Down
55 changes: 52 additions & 3 deletions cmd/atlas/internal/migratelint/lint.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,15 @@ func (d *DevLoader) LoadChanges(ctx context.Context, base, files []migrate.File)
cks = append(cks, i)
continue
}
if current, err = d.next(ctx, diff.Files[i], current); err != nil {
// A common case is when importing a project to Atlas the baseline
// migration file might be very long. However, since we execute on
// a clean database, the per-statement analysis is not needed.
if len(base) == 0 && i == 0 {
current, err = d.first(ctx, diff.Files[i], current)
} else {
current, err = d.next(ctx, diff.Files[i], current)
}
if err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -283,17 +291,58 @@ func (d *DevLoader) base(ctx context.Context, base []migrate.File) (*schema.Real
return d.inspect(ctx)
}

// first is a version of "next" but is used when linting the first migration file. In this case we do not
// need to analyze each statement, but the entire result of the file (much faster). For example, a baseline
// file or the first migration when running 'schema apply' might contain thousands of lines.
func (d *DevLoader) first(ctx context.Context, f *sqlcheck.File, start *schema.Realm) (current *schema.Realm, err error) {
stmts, err := d.stmts(ctx, f.File, true)
if err != nil {
return nil, err
}
// We define the max number of apply-inspect-diff cycles to 10,
// to limit our linting time for baseline/first migration files.
const maxStmtLoop = 10
if len(stmts) <= maxStmtLoop {
return d.nextStmts(ctx, f, stmts, start)
}
for _, s := range stmts {
if _, err := d.Dev.ExecContext(ctx, s.Text); err != nil {
return nil, &FileError{File: f.Name(), Err: fmt.Errorf("executing statement: %s: %w", s.Text, err), Pos: s.Pos}
}
}
if current, err = d.inspect(ctx); err != nil {
return nil, err
}
changes, err := d.Dev.RealmDiff(start, current)
if err != nil {
return nil, err
}
f.Changes = append(f.Changes, &sqlcheck.Change{
Changes: changes,
Stmt: &migrate.Stmt{
Pos: 0, // Beginning of the file.
},
})
f.Sum = changes
return current, nil
}

// next returns the next state of the database after executing the statements in
// the file. The changes detected by the statements are attached to the file.
func (d *DevLoader) next(ctx context.Context, f *sqlcheck.File, start *schema.Realm) (current *schema.Realm, err error) {
func (d *DevLoader) next(ctx context.Context, f *sqlcheck.File, start *schema.Realm) (*schema.Realm, error) {
stmts, err := d.stmts(ctx, f.File, true)
if err != nil {
return nil, err
}
return d.nextStmts(ctx, f, stmts, start)
}

// nextStmts is a version of "next" but accepts the statements to execute.
func (d *DevLoader) nextStmts(ctx context.Context, f *sqlcheck.File, stmts []*migrate.Stmt, start *schema.Realm) (current *schema.Realm, err error) {
current = start
for _, s := range stmts {
if _, err := d.Dev.ExecContext(ctx, s.Text); err != nil {
return nil, &FileError{File: f.Name(), Err: fmt.Errorf("executing statement: %w", err), Pos: s.Pos}
return nil, &FileError{File: f.Name(), Err: fmt.Errorf("executing statement: %s: %w", s.Text, err), Pos: s.Pos}
}
next, err := d.inspect(ctx)
if err != nil {
Expand Down
28 changes: 21 additions & 7 deletions cmd/atlas/internal/migratelint/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,22 @@ func (r *Runner) analyze(ctx context.Context, files []*sqlcheck.File) error {
continue
}
for _, az := range r.Analyzers {
err := az.Analyze(ctx, &sqlcheck.Pass{
File: f,
Dev: r.Dev,
Reporter: nl.reporterFor(fr, az),
})
err := func(az sqlcheck.Analyzer) (rerr error) {
defer func() {
if rc := recover(); rc != nil {
var name string
if n, ok := az.(sqlcheck.NamedAnalyzer); ok {
name = fmt.Sprintf(" (%s)", n.Name())
}
rerr = fmt.Errorf("skip crashed analyzer %s: %v", name, rc)
}
}()
return az.Analyze(ctx, &sqlcheck.Pass{
File: f,
Dev: r.Dev,
Reporter: nl.reporterFor(fr, az),
})
}(az)
// If the last report was skipped,
// skip emitting its error.
if err != nil && !nl.skipped {
Expand Down Expand Up @@ -606,8 +617,11 @@ func nolintRules(f *sqlcheck.File) *skipRules {
}
}
for _, c := range f.Changes {
for _, d := range c.Stmt.Directive("nolint") {
s.pos2rules[c.Stmt.Pos] = append(s.pos2rules[c.Stmt.Pos], strings.Split(d, " ")...)
// A list of changes that were loaded in a batch (no statements per change).
if c.Stmt != nil {
for _, d := range c.Stmt.Directive("nolint") {
s.pos2rules[c.Stmt.Pos] = append(s.pos2rules[c.Stmt.Pos], strings.Split(d, " ")...)
}
}
}
return s
Expand Down
3 changes: 2 additions & 1 deletion cmd/atlas/internal/migratelint/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"ariga.io/atlas/sql/sqlcheck"
"ariga.io/atlas/sql/sqlclient"

"github.com/fatih/color"
"github.com/stretchr/testify/require"
)

Expand All @@ -24,7 +25,7 @@ func TestRunner_Run(t *testing.T) {
b := &bytes.Buffer{}
c, err := sqlclient.Open(ctx, "sqlite://run?mode=memory&cache=shared&_fk=1")
require.NoError(t, err)
t.Setenv("NO_COLOR", "1")
color.NoColor = true

t.Run("checksum mismatch", func(t *testing.T) {
var (
Expand Down
14 changes: 14 additions & 0 deletions sql/sqlcheck/naming/naming.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,20 @@ func (a *Analyzer) Analyze(_ context.Context, p *sqlcheck.Pass) error {
diags = append(diags, a.match(sc, codeNameS, c.S.Name, "Schema", a.Schema)...)
case *schema.AddTable:
diags = append(diags, a.match(sc, codeNameT, c.T.Name, "Table", a.Table)...)
for _, c := range c.T.Columns {
diags = append(diags, a.match(sc, codeNameC, c.Name, "Column", a.Column)...)
}
for _, i := range c.T.Indexes {
diags = append(diags, a.match(sc, codeNameI, i.Name, "Index", a.Index)...)
}
for _, f := range c.T.ForeignKeys {
diags = append(diags, a.match(sc, codeNameF, f.Symbol, "Foreign-key constraint", a.ForeignKey)...)
}
for _, at := range c.T.Attrs {
if k, ok := at.(*schema.Check); ok {
diags = append(diags, a.match(sc, codeNameK, k.Name, "Check constraint", a.Check)...)
}
}
case *schema.RenameTable:
diags = append(diags, a.match(sc, codeNameT, c.To.Name, "Table", a.Table)...)
case *schema.ModifyTable:
Expand Down

0 comments on commit 019d92a

Please sign in to comment.