Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: weighted graph #387

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 101 additions & 39 deletions pkg/go/graph/weighted_graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ import (
"math"
"slices"
"strings"

"gonum.org/v1/gonum/graph/encoding"
"gonum.org/v1/gonum/graph/encoding/dot"
"gonum.org/v1/gonum/graph/multi"
)

const Infinite = math.MaxInt32
Expand All @@ -16,55 +20,79 @@ var ErrTupleCycle = errors.New("tuple cycle")
var ErrContrainstTupleCycle = fmt.Errorf("%w: operands AND or BUT NOT cannot be involved in a cycle", ErrTupleCycle)

type WeightedAuthorizationModelGraph struct {
*multi.DirectedGraph
// edges is a convenience field to provide O(1) access to an edge.
// The key is the unique label of a node and the value is the list of edges from that node.
edges map[string][]*WeightedAuthorizationModelEdge
nodes map[string]*WeightedAuthorizationModelNode
// nodes is a convenience field to provide O(1) access to a node.
// The key is the unique label of a node and the value is the Node struct.
nodes map[string]*WeightedAuthorizationModelNode
drawingDirection DrawingDirection
}

// GetEdges returns the edges map.
func (wg *WeightedAuthorizationModelGraph) GetEdges() map[string][]*WeightedAuthorizationModelEdge {
return wg.edges
}

// GetEdges returns the edges map.
func (wg *WeightedAuthorizationModelGraph) GetEdgesByNode(node *WeightedAuthorizationModelNode) ([]*WeightedAuthorizationModelEdge, bool) {
func (wg *WeightedAuthorizationModelGraph) GetEdgesFromNode(node *WeightedAuthorizationModelNode) ([]*WeightedAuthorizationModelEdge, bool) {
v, ok := wg.edges[node.uniqueLabel]
return v, ok
}

// GetNodes returns the nodes map.
func (wg *WeightedAuthorizationModelGraph) GetNodes() map[string]*WeightedAuthorizationModelNode {
return wg.nodes
}

// GetNodes returns the nodes map.
func (wg *WeightedAuthorizationModelGraph) GetNodeByID(uniqueLabel string) (*WeightedAuthorizationModelNode, bool) {
func (wg *WeightedAuthorizationModelGraph) GetNodeByLabel(uniqueLabel string) (*WeightedAuthorizationModelNode, bool) {
v, ok := wg.nodes[uniqueLabel]
return v, ok
}

// NewWeightedAuthorizationModelGraph creates a new WeightedAuthorizationModelGraph.
func (wg *WeightedAuthorizationModelGraph) GetNodeByID(id int64) (*WeightedAuthorizationModelNode, error) {
node := wg.Node(id)

casted, ok := node.(*WeightedAuthorizationModelNode)
if !ok {
return nil, fmt.Errorf("%w: could not cast to AuthorizationModelNode", ErrQueryingGraph)
}

return casted, nil
}

func NewWeightedAuthorizationModelGraph() *WeightedAuthorizationModelGraph {
return &WeightedAuthorizationModelGraph{
nodes: make(map[string]*WeightedAuthorizationModelNode),
edges: make(map[string][]*WeightedAuthorizationModelEdge),
DirectedGraph: multi.NewDirectedGraph(),
nodes: make(map[string]*WeightedAuthorizationModelNode),
edges: make(map[string][]*WeightedAuthorizationModelEdge),
drawingDirection: DrawingDirectionCheck,
}
}

// AddNode adds a node to the graph with optional operationType and weight.
func (wg *WeightedAuthorizationModelGraph) AddNode(uniqueLabel, label string, nodeType NodeType) {
wildcards := make([]string, 0)
if nodeType == SpecificTypeWildcard {
wildcards = append(wildcards, uniqueLabel[:len(uniqueLabel)-2])
}
wg.nodes[uniqueLabel] = &WeightedAuthorizationModelNode{uniqueLabel: uniqueLabel, label: label, nodeType: nodeType, wildcards: wildcards}
innerNode := &AuthorizationModelNode{
Node: wg.NewNode(),
label: label,
uniqueLabel: uniqueLabel,
nodeType: nodeType,
}
newNode := NewWeightedAuthorizationModelNode(innerNode)
wg.DirectedGraph.AddNode(newNode)
wg.nodes[uniqueLabel] = newNode
}

func (wg *WeightedAuthorizationModelGraph) AddEdge(fromID, toID string, edgeType EdgeType, condition string) {
wildcards := make([]string, 0)
fromNode := wg.nodes[fromID]
toNode := wg.nodes[toID]
edge := &WeightedAuthorizationModelEdge{from: fromNode, to: toNode, edgeType: edgeType, conditionedOn: condition, wildcards: wildcards}
wg.edges[fromID] = append(wg.edges[fromID], edge)
fromNode, _ := wg.GetNodeByLabel(fromID)
toNode, _ := wg.GetNodeByLabel(toID)
innerEdge := &AuthorizationModelEdge{
edgeType: edgeType,
conditionedOn: condition,
}

innerEdge.Line = wg.NewLine(fromNode, toNode)
newWeightedEdge := NewWeightedAuthorizationModelEdge(innerEdge)
wg.DirectedGraph.SetLine(newWeightedEdge)
wg.edges[fromID] = append(wg.edges[fromID], newWeightedEdge)
}

// AssignWeights assigns weights to all the edges and nodes of the graph.
Expand Down Expand Up @@ -93,7 +121,8 @@ func (wg *WeightedAuthorizationModelGraph) calculateEdgeWildcards(edge *Weighted
if len(edge.wildcards) > 0 {
return
}
nodeWildcards := wg.nodes[edge.to.uniqueLabel].wildcards
toNode, _ := wg.GetNodeByID(edge.AuthorizationModelEdge.To().ID())
nodeWildcards := wg.nodes[toNode.uniqueLabel].wildcards
if len(nodeWildcards) == 0 {
return
}
Expand Down Expand Up @@ -151,10 +180,11 @@ func (wg *WeightedAuthorizationModelGraph) calculateNodeWeight(nodeID string, vi
if len(edge.weights) != 0 {
continue
}
if edge.to.nodeType == SpecificType || edge.to.nodeType == SpecificTypeWildcard {
toNode, _ := wg.GetNodeByID(edge.AuthorizationModelEdge.To().ID())
if toNode.nodeType == SpecificType || toNode.nodeType == SpecificTypeWildcard {
edge.weights = make(map[string]int)
uniqueLabel := edge.to.uniqueLabel
if edge.to.nodeType == SpecificTypeWildcard {
uniqueLabel := toNode.uniqueLabel
if toNode.nodeType == SpecificTypeWildcard {
uniqueLabel = uniqueLabel[:len(uniqueLabel)-2]
edge.wildcards = append(edge.wildcards, uniqueLabel)
wg.addEdgeWildcardsToNode(nodeID, edge)
Expand All @@ -180,28 +210,30 @@ func (wg *WeightedAuthorizationModelGraph) calculateNodeWeight(nodeID string, vi

// Calculate the weight of the edge based on the type of edge and the weight of the node that is connected to.
func (wg *WeightedAuthorizationModelGraph) calculateEdgeWeight(edge *WeightedAuthorizationModelEdge, ancestorPath []*WeightedAuthorizationModelEdge, visited map[string]bool, tupleCycleDependencies map[string][]*WeightedAuthorizationModelEdge) ([]string, error) {
fromNode, _ := wg.GetNodeByID(edge.AuthorizationModelEdge.From().ID())
toNode, _ := wg.GetNodeByID(edge.AuthorizationModelEdge.To().ID())
// if it is a recursive edge, we need to set the weight to infinite and add the edge to the tuple cycle dependencies
if edge.from.uniqueLabel == edge.to.uniqueLabel {
if fromNode.uniqueLabel == toNode.uniqueLabel {
edge.weights = make(map[string]int)
edge.weights["R#"+edge.to.uniqueLabel] = Infinite
tupleCycleDependencies[edge.to.uniqueLabel] = append(tupleCycleDependencies[edge.to.uniqueLabel], edge)
return []string{edge.from.uniqueLabel}, nil
edge.weights["R#"+toNode.uniqueLabel] = Infinite
tupleCycleDependencies[toNode.uniqueLabel] = append(tupleCycleDependencies[toNode.uniqueLabel], edge)
return []string{fromNode.uniqueLabel}, nil
}

// calculate the weight of the node that is connected to the edge
ancestorPath = append(ancestorPath, edge)
tupleCycle, err := wg.calculateNodeWeight(edge.to.uniqueLabel, visited, ancestorPath, tupleCycleDependencies)
tupleCycle, err := wg.calculateNodeWeight(toNode.uniqueLabel, visited, ancestorPath, tupleCycleDependencies)
if err != nil {
return tupleCycle, err
}

// if the node that is connected to the edge does not have any weight, we need to check if is a tuple cycle or a model cycle
if len(edge.to.weights) == 0 {
if wg.isTupleCycle(edge.to.uniqueLabel, ancestorPath) {
if len(toNode.weights) == 0 {
if wg.isTupleCycle(toNode.uniqueLabel, ancestorPath) {
edge.weights = make(map[string]int)
edge.weights["R#"+edge.to.uniqueLabel] = Infinite
tupleCycleDependencies[edge.to.uniqueLabel] = append(tupleCycleDependencies[edge.to.uniqueLabel], edge)
tupleCycle = append(tupleCycle, edge.to.uniqueLabel)
edge.weights["R#"+toNode.uniqueLabel] = Infinite
tupleCycleDependencies[toNode.uniqueLabel] = append(tupleCycleDependencies[toNode.uniqueLabel], edge)
tupleCycle = append(tupleCycle, toNode.uniqueLabel)
return tupleCycle, nil
}
return tupleCycle, ErrModelCycle
Expand All @@ -215,7 +247,7 @@ func (wg *WeightedAuthorizationModelGraph) calculateEdgeWeight(edge *WeightedAut
}

weights := make(map[string]int)
for key, value := range edge.to.weights {
for key, value := range toNode.weights {
if !isTupleCycle && strings.HasPrefix(key, "R#") {
nodeDependency := strings.TrimPrefix(key, "R#")
tupleCycleDependencies[nodeDependency] = append(tupleCycleDependencies[nodeDependency], edge)
Expand All @@ -242,11 +274,13 @@ func (wg *WeightedAuthorizationModelGraph) calculateEdgeWeight(edge *WeightedAut
func (wg *WeightedAuthorizationModelGraph) isTupleCycle(nodeID string, ancestorPath []*WeightedAuthorizationModelEdge) bool {
startTracking := false
for _, edge := range ancestorPath {
if !startTracking && edge.from.uniqueLabel == nodeID {
fromNode, _ := wg.GetNodeByID(edge.AuthorizationModelEdge.From().ID())
toNode, _ := wg.GetNodeByID(edge.AuthorizationModelEdge.To().ID())
if !startTracking && fromNode.uniqueLabel == nodeID {
startTracking = true
}
if startTracking {
if edge.edgeType == TTUEdge || (edge.edgeType == DirectEdge && edge.to.nodeType == SpecificTypeAndRelation) {
if edge.edgeType == TTUEdge || (edge.edgeType == DirectEdge && toNode.nodeType == SpecificTypeAndRelation) {
return true
}
}
Expand Down Expand Up @@ -479,7 +513,7 @@ func (wg *WeightedAuthorizationModelGraph) fixDependantNodesWeight(nodeCycle str
node := wg.nodes[nodeCycle]

for _, edge := range tupleCycleDependencies[nodeCycle] {
fromNode := wg.nodes[edge.from.uniqueLabel]
fromNode, _ := wg.GetNodeByID(edge.AuthorizationModelEdge.From().ID())
nodeWeights := make(map[string]int)
for key1, value1 := range fromNode.weights {
if key1 == referenceNodeID {
Expand All @@ -499,7 +533,7 @@ func (wg *WeightedAuthorizationModelGraph) fixDependantNodesWeight(nodeCycle str
}
}
fromNode.weights = nodeWeights
wg.addReferentialWildcardsToNode(edge.from.uniqueLabel, nodeCycle)
wg.addReferentialWildcardsToNode(fromNode.uniqueLabel, nodeCycle)
}
}

Expand All @@ -513,3 +547,31 @@ func (wg *WeightedAuthorizationModelGraph) removeNodeFromTupleCycles(nodeID stri
}
return result
}

var _ dot.Attributers = (*WeightedAuthorizationModelGraph)(nil)

func (wg *WeightedAuthorizationModelGraph) DOTAttributers() (encoding.Attributer, encoding.Attributer, encoding.Attributer) {
return wg, nil, nil
}

func (wg *WeightedAuthorizationModelGraph) Attributes() []encoding.Attribute {
rankdir := "BT" // bottom to top
if wg.drawingDirection == DrawingDirectionCheck {
rankdir = "TB" // top to bottom
}

return []encoding.Attribute{{
Key: "rankdir", // https://graphviz.org/docs/attrs/rankdir/
Value: rankdir,
}}
}

// GetDOT returns the DOT visualization. It should only be used for debugging.
func (wg *WeightedAuthorizationModelGraph) GetDOT() string {
dotRepresentation, err := dot.MarshalMulti(wg, "", "", "")
if err != nil {
return ""
}

return string(dotRepresentation)
}
94 changes: 94 additions & 0 deletions pkg/go/graph/weighted_graph_builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -484,3 +484,97 @@ func TestValidGraphModel(t *testing.T) {
require.NoError(t, err)
require.Equal(t, 3, graph.nodes["job#can_read"].weights["user"])
}

// TODO make output from DOT stable
//func TestWeightedGraphBuilderDOT(t *testing.T) {
// t.Parallel()
//
// testCases := map[string]struct {
// model string
// expectedOutput string // can visualize in https://dreampuf.github.io/GraphvizOnline
// expectedError error
// }{
// `multigraph`: {
// model: `
// model
// schema 1.1
// type user
// type state
// relations
// define can_view: [user] or member
// define member: [user]
// type transition
// relations
// define start: [state]
// define end: [state]
// define can_apply: [user] and can_view from start and can_view from end
// type group
// relations
// define owner: [user, transition#can_apply]
// define max_owner: [group#owner, group#max_owner]`,
// expectedOutput: `digraph {
//graph [
//rankdir=TB
//];
//
//// Node definitions.
//0 [label=group];
//1 [label="group#max_owner - weights:[user=+∞]"];
//2 [label="group#owner - weights:[user=3]"];
//3 [label=user];
//4 [label="transition#can_apply - weights:[user=2]"];
//5 [label=state];
//6 [label="state#can_view - weights:[user=1]"];
//7 [label="union - weights:[user=1]"];
//8 [label="state#member - weights:[user=1]"];
//9 [label=transition];
//10 [label="intersection - weights:[user=2]"];
//11 [label="transition#end - weights:[state=1]"];
//12 [label="transition#start - weights:[state=1]"];
//
//// Edge definitions.
//1 -> 1 [label="direct - weights:[user=+∞]"];
//1 -> 2 [label="direct - weights:[user=4]"];
//2 -> 3 [label="direct - weights:[user=1]"];
//2 -> 4 [label="direct - weights:[user=3]"];
//4 -> 10 [label="weights:[user=2]"];
//6 -> 7 [label="weights:[user=1]"];
//7 -> 3 [label="direct - weights:[user=1]"];
//7 -> 8 [label="weights:[user=1]"];
//8 -> 3 [label="direct - weights:[user=1]"];
//10 -> 3 [label="direct - weights:[user=1]"];
//10 -> 6 [
//headlabel="(transition#start)"
//label="weights:[user=2]"
//];
//10 -> 6 [
//headlabel="(transition#end)"
//label="weights:[user=2]"
//];
//11 -> 5 [label="direct - weights:[state=1]"];
//12 -> 5 [label="direct - weights:[state=1]"];
//}`,
// },
// }
// for name, testCase := range testCases {
// t.Run(name, func(t *testing.T) {
// t.Parallel()
//
// model := language.MustTransformDSLToProto(testCase.model)
// weightedGraph, err := NewWeightedAuthorizationModelGraphBuilder().Build(model)
// if testCase.expectedError != nil {
// require.ErrorIs(t, err, testCase.expectedError)
// } else {
// require.NoError(t, err)
//
// actualDOT := weightedGraph.GetDOT()
// actualSorted := getSorted(actualDOT)
// expectedSorted := getSorted(testCase.expectedOutput)
//
// diff := cmp.Diff(expectedSorted, actualSorted)
//
// require.Empty(t, diff, "expected %s\ngot\n%s", testCase.expectedOutput, actualDOT)
// }
// })
// }
//}
Loading