Skip to content

Commit

Permalink
rego-v1: Future-proofing repl pkg tests to be 1.0 compatible (#7026)
Browse files Browse the repository at this point in the history
Also making some updates to the repl implementation to properly deal with v1 as the default rego-version.

Signed-off-by: Johan Fylling <[email protected]>
  • Loading branch information
johanfylling committed Sep 19, 2024
1 parent 7d60244 commit acb3272
Show file tree
Hide file tree
Showing 4 changed files with 328 additions and 113 deletions.
11 changes: 6 additions & 5 deletions ast/parser_ext.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,11 +287,12 @@ func ParseCompleteDocRuleFromEqExpr(module *Module, lhs, rhs *Term) (*Rule, erro
setJSONOptions(body, &rhs.jsonOptions)

return &Rule{
Location: lhs.Location,
Head: head,
Body: body,
Module: module,
jsonOptions: lhs.jsonOptions,
Location: lhs.Location,
Head: head,
Body: body,
Module: module,
jsonOptions: lhs.jsonOptions,
generatedBody: true,
}, nil
}

Expand Down
6 changes: 4 additions & 2 deletions repl/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"context"
"fmt"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/repl"
"github.com/open-policy-agent/opa/storage/inmem"
)
Expand All @@ -26,10 +27,11 @@ func ExampleREPL_OneShot() {
var buf bytes.Buffer

// Create a new REPL.
r := repl.New(store, "", &buf, "json", 0, "")
r := repl.New(store, "", &buf, "json", 0, "").
WithRegoVersion(ast.RegoV1)

// Define a rule inside the REPL.
r.OneShot(ctx, "p { a = [1, 2, 3, 4]; a[_] > 3 }")
r.OneShot(ctx, "p if { a = [1, 2, 3, 4]; a[_] > 3 }")

// Query the rule defined above.
r.OneShot(ctx, "p")
Expand Down
40 changes: 29 additions & 11 deletions repl/repl.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ type REPL struct {
profiler bool
strictBuiltinErrors bool
capabilities *ast.Capabilities
v1Compatible bool
regoVersion ast.RegoVersion
initBundles map[string]*bundle.Bundle

// TODO(tsandall): replace this state with rule definitions
Expand Down Expand Up @@ -352,8 +352,20 @@ func (r *REPL) WithRuntime(term *ast.Term) *REPL {
return r
}

// WithRegoVersion sets the Rego version to v.
func (r *REPL) WithRegoVersion(v ast.RegoVersion) *REPL {
r.regoVersion = v
return r
}

// WithV1Compatible sets the Rego version to v1.
// Deprecated: Use WithRegoVersion instead.
func (r *REPL) WithV1Compatible(v1Compatible bool) *REPL {
r.v1Compatible = v1Compatible
if v1Compatible {
r.regoVersion = ast.RegoV1
} else {
r.regoVersion = ast.DefaultRegoVersion
}
return r
}

Expand Down Expand Up @@ -495,7 +507,7 @@ func (r *REPL) cmdShow(args []string) error {
return nil
}
module := r.modules[r.currentModuleID]
bs, err := format.Ast(module)
bs, err := format.AstWithOpts(module, format.Opts{RegoVersion: module.RegoVersion()})
if err != nil {
return err
}
Expand Down Expand Up @@ -774,7 +786,7 @@ func (r *REPL) compileRule(ctx context.Context, rule *ast.Rule) error {

var unset bool

if r.v1Compatible {
if r.regoVersion == ast.RegoV1 {
if errs := ast.CheckRegoV1(rule); errs != nil {
return errs
}
Expand Down Expand Up @@ -910,21 +922,26 @@ func (r *REPL) evalBufferMulti(ctx context.Context) error {
}

func (r *REPL) parserOptions() (ast.ParserOptions, error) {
if r.v1Compatible {
if r.regoVersion == ast.RegoV1 {
return ast.ParserOptions{RegoVersion: ast.RegoV1}, nil
}
if r.currentModuleID != "" {
opts, err := future.ParserOptionsFromFutureImports(r.modules[r.currentModuleID].Imports)
if err == nil {
for _, i := range r.modules[r.currentModuleID].Imports {
if ast.Compare(i.Path.Value, ast.RegoV1CompatibleRef) == 0 {
opts.RegoVersion = ast.RegoV1
opts.RegoVersion = ast.RegoV0CompatV1

// ast.RegoV0CompatV1 sets parsing requirements, but doesn't imply allowed future keywords
if r.capabilities != nil {
opts.FutureKeywords = r.capabilities.FutureKeywords
}
}
}
}
return opts, err
}
return ast.ParserOptions{}, nil
return ast.ParserOptions{RegoVersion: r.regoVersion}, nil
}

func (r *REPL) loadCompiler(ctx context.Context) (*ast.Compiler, error) {
Expand Down Expand Up @@ -1166,9 +1183,11 @@ func (r *REPL) evalPackage(p *ast.Package) error {
return nil
}

r.modules[moduleID] = &ast.Module{
m := ast.Module{
Package: p,
}
m.SetRegoVersion(r.regoVersion)
r.modules[moduleID] = &m

r.currentModuleID = moduleID

Expand Down Expand Up @@ -1282,9 +1301,8 @@ func (r *REPL) loadModules(ctx context.Context, txn storage.Transaction) (map[st
return nil, err
}

popts := ast.ParserOptions{}
if r.v1Compatible {
popts.RegoVersion = ast.RegoV1
popts := ast.ParserOptions{
RegoVersion: r.regoVersion,
}

parsed, err := ast.ParseModuleWithOpts(id, string(bs), popts)
Expand Down
Loading

0 comments on commit acb3272

Please sign in to comment.