Skip to content

Commit

Permalink
refactor: weighted graph
Browse files Browse the repository at this point in the history
  • Loading branch information
miparnisari committed Nov 22, 2024
1 parent ff7db03 commit e5e60f0
Show file tree
Hide file tree
Showing 6 changed files with 345 additions and 52 deletions.
140 changes: 101 additions & 39 deletions pkg/go/graph/weighted_graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ import (
"math"
"slices"
"strings"

"gonum.org/v1/gonum/graph/encoding"
"gonum.org/v1/gonum/graph/encoding/dot"
"gonum.org/v1/gonum/graph/multi"
)

const Infinite = math.MaxInt32
Expand All @@ -16,55 +20,79 @@ var ErrTupleCycle = errors.New("tuple cycle")
var ErrContrainstTupleCycle = fmt.Errorf("%w: operands AND or BUT NOT cannot be involved in a cycle", ErrTupleCycle)

type WeightedAuthorizationModelGraph struct {
*multi.DirectedGraph
// edges is a convenience field to provide O(1) access to an edge.
// The key is the unique label of a node and the value is the list of edges from that node.
edges map[string][]*WeightedAuthorizationModelEdge
nodes map[string]*WeightedAuthorizationModelNode
// nodes is a convenience field to provide O(1) access to a node.
// The key is the unique label of a node and the value is the Node struct.
nodes map[string]*WeightedAuthorizationModelNode
drawingDirection DrawingDirection
}

// GetEdges returns the edges map.
func (wg *WeightedAuthorizationModelGraph) GetEdges() map[string][]*WeightedAuthorizationModelEdge {
return wg.edges
}

// GetEdges returns the edges map.
func (wg *WeightedAuthorizationModelGraph) GetEdgesByNode(node *WeightedAuthorizationModelNode) ([]*WeightedAuthorizationModelEdge, bool) {
func (wg *WeightedAuthorizationModelGraph) GetEdgesFromNode(node *WeightedAuthorizationModelNode) ([]*WeightedAuthorizationModelEdge, bool) {
v, ok := wg.edges[node.uniqueLabel]
return v, ok
}

// GetNodes returns the nodes map.
func (wg *WeightedAuthorizationModelGraph) GetNodes() map[string]*WeightedAuthorizationModelNode {
return wg.nodes
}

// GetNodes returns the nodes map.
func (wg *WeightedAuthorizationModelGraph) GetNodeByID(uniqueLabel string) (*WeightedAuthorizationModelNode, bool) {
func (wg *WeightedAuthorizationModelGraph) GetNodeByLabel(uniqueLabel string) (*WeightedAuthorizationModelNode, bool) {
v, ok := wg.nodes[uniqueLabel]
return v, ok
}

// NewWeightedAuthorizationModelGraph creates a new WeightedAuthorizationModelGraph.
func (wg *WeightedAuthorizationModelGraph) GetNodeByID(id int64) (*WeightedAuthorizationModelNode, error) {
node := wg.Node(id)

casted, ok := node.(*WeightedAuthorizationModelNode)
if !ok {
return nil, fmt.Errorf("%w: could not cast to AuthorizationModelNode", ErrQueryingGraph)
}

return casted, nil
}

func NewWeightedAuthorizationModelGraph() *WeightedAuthorizationModelGraph {
return &WeightedAuthorizationModelGraph{
nodes: make(map[string]*WeightedAuthorizationModelNode),
edges: make(map[string][]*WeightedAuthorizationModelEdge),
DirectedGraph: multi.NewDirectedGraph(),
nodes: make(map[string]*WeightedAuthorizationModelNode),
edges: make(map[string][]*WeightedAuthorizationModelEdge),
drawingDirection: DrawingDirectionCheck,
}
}

// AddNode adds a node to the graph with optional operationType and weight.
func (wg *WeightedAuthorizationModelGraph) AddNode(uniqueLabel, label string, nodeType NodeType) {
wildcards := make([]string, 0)
if nodeType == SpecificTypeWildcard {
wildcards = append(wildcards, uniqueLabel[:len(uniqueLabel)-2])
}
wg.nodes[uniqueLabel] = &WeightedAuthorizationModelNode{uniqueLabel: uniqueLabel, label: label, nodeType: nodeType, wildcards: wildcards}
innerNode := &AuthorizationModelNode{
Node: wg.NewNode(),
label: label,
uniqueLabel: uniqueLabel,
nodeType: nodeType,
}
newNode := NewWeightedAuthorizationModelNode(innerNode)
wg.DirectedGraph.AddNode(newNode)
wg.nodes[uniqueLabel] = newNode
}

func (wg *WeightedAuthorizationModelGraph) AddEdge(fromID, toID string, edgeType EdgeType, condition string) {
wildcards := make([]string, 0)
fromNode := wg.nodes[fromID]
toNode := wg.nodes[toID]
edge := &WeightedAuthorizationModelEdge{from: fromNode, to: toNode, edgeType: edgeType, conditionedOn: condition, wildcards: wildcards}
wg.edges[fromID] = append(wg.edges[fromID], edge)
fromNode, _ := wg.GetNodeByLabel(fromID)
toNode, _ := wg.GetNodeByLabel(toID)
innerEdge := &AuthorizationModelEdge{
edgeType: edgeType,
conditionedOn: condition,
}

innerEdge.Line = wg.NewLine(fromNode, toNode)
newWeightedEdge := NewWeightedAuthorizationModelEdge(innerEdge)
wg.DirectedGraph.SetLine(newWeightedEdge)
wg.edges[fromID] = append(wg.edges[fromID], newWeightedEdge)
}

// AssignWeights assigns weights to all the edges and nodes of the graph.
Expand Down Expand Up @@ -93,7 +121,8 @@ func (wg *WeightedAuthorizationModelGraph) calculateEdgeWildcards(edge *Weighted
if len(edge.wildcards) > 0 {
return
}
nodeWildcards := wg.nodes[edge.to.uniqueLabel].wildcards
toNode, _ := wg.GetNodeByID(edge.AuthorizationModelEdge.To().ID())
nodeWildcards := wg.nodes[toNode.uniqueLabel].wildcards
if len(nodeWildcards) == 0 {
return
}
Expand Down Expand Up @@ -151,10 +180,11 @@ func (wg *WeightedAuthorizationModelGraph) calculateNodeWeight(nodeID string, vi
if len(edge.weights) != 0 {
continue
}
if edge.to.nodeType == SpecificType || edge.to.nodeType == SpecificTypeWildcard {
toNode, _ := wg.GetNodeByID(edge.AuthorizationModelEdge.To().ID())
if toNode.nodeType == SpecificType || toNode.nodeType == SpecificTypeWildcard {
edge.weights = make(map[string]int)
uniqueLabel := edge.to.uniqueLabel
if edge.to.nodeType == SpecificTypeWildcard {
uniqueLabel := toNode.uniqueLabel
if toNode.nodeType == SpecificTypeWildcard {
uniqueLabel = uniqueLabel[:len(uniqueLabel)-2]
edge.wildcards = append(edge.wildcards, uniqueLabel)
wg.addEdgeWildcardsToNode(nodeID, edge)
Expand All @@ -180,28 +210,30 @@ func (wg *WeightedAuthorizationModelGraph) calculateNodeWeight(nodeID string, vi

// Calculate the weight of the edge based on the type of edge and the weight of the node that is connected to.
func (wg *WeightedAuthorizationModelGraph) calculateEdgeWeight(edge *WeightedAuthorizationModelEdge, ancestorPath []*WeightedAuthorizationModelEdge, visited map[string]bool, tupleCycleDependencies map[string][]*WeightedAuthorizationModelEdge) ([]string, error) {
fromNode, _ := wg.GetNodeByID(edge.AuthorizationModelEdge.From().ID())
toNode, _ := wg.GetNodeByID(edge.AuthorizationModelEdge.To().ID())
// if it is a recursive edge, we need to set the weight to infinite and add the edge to the tuple cycle dependencies
if edge.from.uniqueLabel == edge.to.uniqueLabel {
if fromNode.uniqueLabel == toNode.uniqueLabel {
edge.weights = make(map[string]int)
edge.weights["R#"+edge.to.uniqueLabel] = Infinite
tupleCycleDependencies[edge.to.uniqueLabel] = append(tupleCycleDependencies[edge.to.uniqueLabel], edge)
return []string{edge.from.uniqueLabel}, nil
edge.weights["R#"+toNode.uniqueLabel] = Infinite
tupleCycleDependencies[toNode.uniqueLabel] = append(tupleCycleDependencies[toNode.uniqueLabel], edge)
return []string{fromNode.uniqueLabel}, nil
}

// calculate the weight of the node that is connected to the edge
ancestorPath = append(ancestorPath, edge)
tupleCycle, err := wg.calculateNodeWeight(edge.to.uniqueLabel, visited, ancestorPath, tupleCycleDependencies)
tupleCycle, err := wg.calculateNodeWeight(toNode.uniqueLabel, visited, ancestorPath, tupleCycleDependencies)
if err != nil {
return tupleCycle, err
}

// if the node that is connected to the edge does not have any weight, we need to check if is a tuple cycle or a model cycle
if len(edge.to.weights) == 0 {
if wg.isTupleCycle(edge.to.uniqueLabel, ancestorPath) {
if len(toNode.weights) == 0 {
if wg.isTupleCycle(toNode.uniqueLabel, ancestorPath) {
edge.weights = make(map[string]int)
edge.weights["R#"+edge.to.uniqueLabel] = Infinite
tupleCycleDependencies[edge.to.uniqueLabel] = append(tupleCycleDependencies[edge.to.uniqueLabel], edge)
tupleCycle = append(tupleCycle, edge.to.uniqueLabel)
edge.weights["R#"+toNode.uniqueLabel] = Infinite
tupleCycleDependencies[toNode.uniqueLabel] = append(tupleCycleDependencies[toNode.uniqueLabel], edge)
tupleCycle = append(tupleCycle, toNode.uniqueLabel)
return tupleCycle, nil
}
return tupleCycle, ErrModelCycle
Expand All @@ -215,7 +247,7 @@ func (wg *WeightedAuthorizationModelGraph) calculateEdgeWeight(edge *WeightedAut
}

weights := make(map[string]int)
for key, value := range edge.to.weights {
for key, value := range toNode.weights {
if !isTupleCycle && strings.HasPrefix(key, "R#") {
nodeDependency := strings.TrimPrefix(key, "R#")
tupleCycleDependencies[nodeDependency] = append(tupleCycleDependencies[nodeDependency], edge)
Expand All @@ -242,11 +274,13 @@ func (wg *WeightedAuthorizationModelGraph) calculateEdgeWeight(edge *WeightedAut
func (wg *WeightedAuthorizationModelGraph) isTupleCycle(nodeID string, ancestorPath []*WeightedAuthorizationModelEdge) bool {
startTracking := false
for _, edge := range ancestorPath {
if !startTracking && edge.from.uniqueLabel == nodeID {
fromNode, _ := wg.GetNodeByID(edge.AuthorizationModelEdge.From().ID())
toNode, _ := wg.GetNodeByID(edge.AuthorizationModelEdge.To().ID())
if !startTracking && fromNode.uniqueLabel == nodeID {
startTracking = true
}
if startTracking {
if edge.edgeType == TTUEdge || (edge.edgeType == DirectEdge && edge.to.nodeType == SpecificTypeAndRelation) {
if edge.edgeType == TTUEdge || (edge.edgeType == DirectEdge && toNode.nodeType == SpecificTypeAndRelation) {
return true
}
}
Expand Down Expand Up @@ -479,7 +513,7 @@ func (wg *WeightedAuthorizationModelGraph) fixDependantNodesWeight(nodeCycle str
node := wg.nodes[nodeCycle]

for _, edge := range tupleCycleDependencies[nodeCycle] {
fromNode := wg.nodes[edge.from.uniqueLabel]
fromNode, _ := wg.GetNodeByID(edge.AuthorizationModelEdge.From().ID())
nodeWeights := make(map[string]int)
for key1, value1 := range fromNode.weights {
if key1 == referenceNodeID {
Expand All @@ -499,7 +533,7 @@ func (wg *WeightedAuthorizationModelGraph) fixDependantNodesWeight(nodeCycle str
}
}
fromNode.weights = nodeWeights
wg.addReferentialWildcardsToNode(edge.from.uniqueLabel, nodeCycle)
wg.addReferentialWildcardsToNode(fromNode.uniqueLabel, nodeCycle)
}
}

Expand All @@ -513,3 +547,31 @@ func (wg *WeightedAuthorizationModelGraph) removeNodeFromTupleCycles(nodeID stri
}
return result
}

var _ dot.Attributers = (*WeightedAuthorizationModelGraph)(nil)

func (wg *WeightedAuthorizationModelGraph) DOTAttributers() (encoding.Attributer, encoding.Attributer, encoding.Attributer) {
return wg, nil, nil
}

func (wg *WeightedAuthorizationModelGraph) Attributes() []encoding.Attribute {
rankdir := "BT" // bottom to top
if wg.drawingDirection == DrawingDirectionCheck {
rankdir = "TB" // top to bottom
}

return []encoding.Attribute{{
Key: "rankdir", // https://graphviz.org/docs/attrs/rankdir/
Value: rankdir,
}}
}

// GetDOT returns the DOT visualization. It should only be used for debugging.
func (wg *WeightedAuthorizationModelGraph) GetDOT() string {
dotRepresentation, err := dot.MarshalMulti(wg, "", "", "")
if err != nil {
return ""
}

return string(dotRepresentation)
}
94 changes: 94 additions & 0 deletions pkg/go/graph/weighted_graph_builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -484,3 +484,97 @@ func TestValidGraphModel(t *testing.T) {
require.NoError(t, err)
require.Equal(t, 3, graph.nodes["job#can_read"].weights["user"])
}

// TODO make output from DOT stable
//func TestWeightedGraphBuilderDOT(t *testing.T) {
// t.Parallel()
//
// testCases := map[string]struct {
// model string
// expectedOutput string // can visualize in https://dreampuf.github.io/GraphvizOnline
// expectedError error
// }{
// `multigraph`: {
// model: `
// model
// schema 1.1
// type user
// type state
// relations
// define can_view: [user] or member
// define member: [user]
// type transition
// relations
// define start: [state]
// define end: [state]
// define can_apply: [user] and can_view from start and can_view from end
// type group
// relations
// define owner: [user, transition#can_apply]
// define max_owner: [group#owner, group#max_owner]`,
// expectedOutput: `digraph {
//graph [
//rankdir=TB
//];
//
//// Node definitions.
//0 [label=group];
//1 [label="group#max_owner - weights:[user=+∞]"];
//2 [label="group#owner - weights:[user=3]"];
//3 [label=user];
//4 [label="transition#can_apply - weights:[user=2]"];
//5 [label=state];
//6 [label="state#can_view - weights:[user=1]"];
//7 [label="union - weights:[user=1]"];
//8 [label="state#member - weights:[user=1]"];
//9 [label=transition];
//10 [label="intersection - weights:[user=2]"];
//11 [label="transition#end - weights:[state=1]"];
//12 [label="transition#start - weights:[state=1]"];
//
//// Edge definitions.
//1 -> 1 [label="direct - weights:[user=+∞]"];
//1 -> 2 [label="direct - weights:[user=4]"];
//2 -> 3 [label="direct - weights:[user=1]"];
//2 -> 4 [label="direct - weights:[user=3]"];
//4 -> 10 [label="weights:[user=2]"];
//6 -> 7 [label="weights:[user=1]"];
//7 -> 3 [label="direct - weights:[user=1]"];
//7 -> 8 [label="weights:[user=1]"];
//8 -> 3 [label="direct - weights:[user=1]"];
//10 -> 3 [label="direct - weights:[user=1]"];
//10 -> 6 [
//headlabel="(transition#start)"
//label="weights:[user=2]"
//];
//10 -> 6 [
//headlabel="(transition#end)"
//label="weights:[user=2]"
//];
//11 -> 5 [label="direct - weights:[state=1]"];
//12 -> 5 [label="direct - weights:[state=1]"];
//}`,
// },
// }
// for name, testCase := range testCases {
// t.Run(name, func(t *testing.T) {
// t.Parallel()
//
// model := language.MustTransformDSLToProto(testCase.model)
// weightedGraph, err := NewWeightedAuthorizationModelGraphBuilder().Build(model)
// if testCase.expectedError != nil {
// require.ErrorIs(t, err, testCase.expectedError)
// } else {
// require.NoError(t, err)
//
// actualDOT := weightedGraph.GetDOT()
// actualSorted := getSorted(actualDOT)
// expectedSorted := getSorted(testCase.expectedOutput)
//
// diff := cmp.Diff(expectedSorted, actualSorted)
//
// require.Empty(t, diff, "expected %s\ngot\n%s", testCase.expectedOutput, actualDOT)
// }
// })
// }
//}
Loading

0 comments on commit e5e60f0

Please sign in to comment.