Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cmd/atlas/internal/migratelint: first (or baseline) migration should … #3185

Merged
merged 1 commit into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/atlas/internal/cmdapi/migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1665,7 +1665,7 @@ lint {
require.Error(t, err)
require.Regexp(t, `Analyzing changes from version 1.up to 2.up \(1 migration in total\):
Error: executing statement: near "BORING": syntax error
Error: executing statement: BORING: near "BORING": syntax error
-------------------------
-- .+
Expand Down
55 changes: 52 additions & 3 deletions cmd/atlas/internal/migratelint/lint.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,15 @@ func (d *DevLoader) LoadChanges(ctx context.Context, base, files []migrate.File)
cks = append(cks, i)
continue
}
if current, err = d.next(ctx, diff.Files[i], current); err != nil {
// A common case is when importing a project to Atlas the baseline
// migration file might be very long. However, since we execute on
// a clean database, the per-statement analysis is not needed.
if len(base) == 0 && i == 0 {
current, err = d.first(ctx, diff.Files[i], current)
} else {
current, err = d.next(ctx, diff.Files[i], current)
}
if err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -283,17 +291,58 @@ func (d *DevLoader) base(ctx context.Context, base []migrate.File) (*schema.Real
return d.inspect(ctx)
}

// first is a version of "next" but is used when linting the first migration file. In this case we do not
// need to analyze each statement, but the entire result of the file (much faster). For example, a baseline
// file or the first migration when running 'schema apply' might contain thousands of lines.
func (d *DevLoader) first(ctx context.Context, f *sqlcheck.File, start *schema.Realm) (current *schema.Realm, err error) {
stmts, err := d.stmts(ctx, f.File, true)
if err != nil {
return nil, err
}
// We define the max number of apply-inspect-diff cycles to 10,
// to limit our linting time for baseline/first migration files.
const maxStmtLoop = 10
if len(stmts) <= maxStmtLoop {
return d.nextStmts(ctx, f, stmts, start)
}
for _, s := range stmts {
if _, err := d.Dev.ExecContext(ctx, s.Text); err != nil {
return nil, &FileError{File: f.Name(), Err: fmt.Errorf("executing statement: %s: %w", s.Text, err), Pos: s.Pos}
}
}
if current, err = d.inspect(ctx); err != nil {
return nil, err
}
changes, err := d.Dev.RealmDiff(start, current)
if err != nil {
return nil, err
}
f.Changes = append(f.Changes, &sqlcheck.Change{
Changes: changes,
Stmt: &migrate.Stmt{
Pos: 0, // Beginning of the file.
},
})
f.Sum = changes
return current, nil
}

// next returns the next state of the database after executing the statements in
// the file. The changes detected by the statements are attached to the file.
func (d *DevLoader) next(ctx context.Context, f *sqlcheck.File, start *schema.Realm) (current *schema.Realm, err error) {
func (d *DevLoader) next(ctx context.Context, f *sqlcheck.File, start *schema.Realm) (*schema.Realm, error) {
stmts, err := d.stmts(ctx, f.File, true)
if err != nil {
return nil, err
}
return d.nextStmts(ctx, f, stmts, start)
}

// nextStmts is a version of "next" but accepts the statements to execute.
func (d *DevLoader) nextStmts(ctx context.Context, f *sqlcheck.File, stmts []*migrate.Stmt, start *schema.Realm) (current *schema.Realm, err error) {
current = start
for _, s := range stmts {
if _, err := d.Dev.ExecContext(ctx, s.Text); err != nil {
return nil, &FileError{File: f.Name(), Err: fmt.Errorf("executing statement: %w", err), Pos: s.Pos}
return nil, &FileError{File: f.Name(), Err: fmt.Errorf("executing statement: %s: %w", s.Text, err), Pos: s.Pos}
}
next, err := d.inspect(ctx)
if err != nil {
Expand Down
28 changes: 21 additions & 7 deletions cmd/atlas/internal/migratelint/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,22 @@ func (r *Runner) analyze(ctx context.Context, files []*sqlcheck.File) error {
continue
}
for _, az := range r.Analyzers {
err := az.Analyze(ctx, &sqlcheck.Pass{
File: f,
Dev: r.Dev,
Reporter: nl.reporterFor(fr, az),
})
err := func(az sqlcheck.Analyzer) (rerr error) {
defer func() {
if rc := recover(); rc != nil {
var name string
if n, ok := az.(sqlcheck.NamedAnalyzer); ok {
name = fmt.Sprintf(" (%s)", n.Name())
}
rerr = fmt.Errorf("skip crashed analyzer %s: %v", name, rc)
}
}()
return az.Analyze(ctx, &sqlcheck.Pass{
File: f,
Dev: r.Dev,
Reporter: nl.reporterFor(fr, az),
})
}(az)
// If the last report was skipped,
// skip emitting its error.
if err != nil && !nl.skipped {
Expand Down Expand Up @@ -606,8 +617,11 @@ func nolintRules(f *sqlcheck.File) *skipRules {
}
}
for _, c := range f.Changes {
for _, d := range c.Stmt.Directive("nolint") {
s.pos2rules[c.Stmt.Pos] = append(s.pos2rules[c.Stmt.Pos], strings.Split(d, " ")...)
// A list of changes that were loaded in a batch (no statements per change).
if c.Stmt != nil {
for _, d := range c.Stmt.Directive("nolint") {
s.pos2rules[c.Stmt.Pos] = append(s.pos2rules[c.Stmt.Pos], strings.Split(d, " ")...)
}
}
}
return s
Expand Down
3 changes: 2 additions & 1 deletion cmd/atlas/internal/migratelint/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"ariga.io/atlas/sql/sqlcheck"
"ariga.io/atlas/sql/sqlclient"

"github.com/fatih/color"
"github.com/stretchr/testify/require"
)

Expand All @@ -24,7 +25,7 @@ func TestRunner_Run(t *testing.T) {
b := &bytes.Buffer{}
c, err := sqlclient.Open(ctx, "sqlite://run?mode=memory&cache=shared&_fk=1")
require.NoError(t, err)
t.Setenv("NO_COLOR", "1")
color.NoColor = true

t.Run("checksum mismatch", func(t *testing.T) {
var (
Expand Down
14 changes: 14 additions & 0 deletions sql/sqlcheck/naming/naming.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,20 @@ func (a *Analyzer) Analyze(_ context.Context, p *sqlcheck.Pass) error {
diags = append(diags, a.match(sc, codeNameS, c.S.Name, "Schema", a.Schema)...)
case *schema.AddTable:
diags = append(diags, a.match(sc, codeNameT, c.T.Name, "Table", a.Table)...)
for _, c := range c.T.Columns {
diags = append(diags, a.match(sc, codeNameC, c.Name, "Column", a.Column)...)
}
for _, i := range c.T.Indexes {
diags = append(diags, a.match(sc, codeNameI, i.Name, "Index", a.Index)...)
}
for _, f := range c.T.ForeignKeys {
diags = append(diags, a.match(sc, codeNameF, f.Symbol, "Foreign-key constraint", a.ForeignKey)...)
}
for _, at := range c.T.Attrs {
if k, ok := at.(*schema.Check); ok {
diags = append(diags, a.match(sc, codeNameK, k.Name, "Check constraint", a.Check)...)
}
}
case *schema.RenameTable:
diags = append(diags, a.match(sc, codeNameT, c.To.Name, "Table", a.Table)...)
case *schema.ModifyTable:
Expand Down