From e5e60f0c6bb7983f9bb5829ba8d8d4624217fe79 Mon Sep 17 00:00:00 2001 From: Maria Ines Parnisari Date: Thu, 21 Nov 2024 20:11:04 -0600 Subject: [PATCH] refactor: weighted graph --- pkg/go/graph/weighted_graph.go | 140 ++++++++++++++------ pkg/go/graph/weighted_graph_builder_test.go | 94 +++++++++++++ pkg/go/graph/weighted_graph_edge.go | 55 ++++++-- pkg/go/graph/weighted_graph_node.go | 54 +++++++- pkg/go/graph/weights.go | 36 +++++ pkg/go/graph/weights_test.go | 18 +++ 6 files changed, 345 insertions(+), 52 deletions(-) create mode 100644 pkg/go/graph/weights.go create mode 100644 pkg/go/graph/weights_test.go diff --git a/pkg/go/graph/weighted_graph.go b/pkg/go/graph/weighted_graph.go index 27954f74..c144df18 100644 --- a/pkg/go/graph/weighted_graph.go +++ b/pkg/go/graph/weighted_graph.go @@ -6,6 +6,10 @@ import ( "math" "slices" "strings" + + "gonum.org/v1/gonum/graph/encoding" + "gonum.org/v1/gonum/graph/encoding/dot" + "gonum.org/v1/gonum/graph/multi" ) const Infinite = math.MaxInt32 @@ -16,55 +20,79 @@ var ErrTupleCycle = errors.New("tuple cycle") var ErrContrainstTupleCycle = fmt.Errorf("%w: operands AND or BUT NOT cannot be involved in a cycle", ErrTupleCycle) type WeightedAuthorizationModelGraph struct { + *multi.DirectedGraph + // edges is a convenience field to provide O(1) access to an edge. + // The key is the unique label of a node and the value is the list of edges from that node. edges map[string][]*WeightedAuthorizationModelEdge - nodes map[string]*WeightedAuthorizationModelNode + // nodes is a convenience field to provide O(1) access to a node. + // The key is the unique label of a node and the value is the Node struct. + nodes map[string]*WeightedAuthorizationModelNode + drawingDirection DrawingDirection } -// GetEdges returns the edges map. func (wg *WeightedAuthorizationModelGraph) GetEdges() map[string][]*WeightedAuthorizationModelEdge { return wg.edges } -// GetEdges returns the edges map. -func (wg *WeightedAuthorizationModelGraph) GetEdgesByNode(node *WeightedAuthorizationModelNode) ([]*WeightedAuthorizationModelEdge, bool) { +func (wg *WeightedAuthorizationModelGraph) GetEdgesFromNode(node *WeightedAuthorizationModelNode) ([]*WeightedAuthorizationModelEdge, bool) { v, ok := wg.edges[node.uniqueLabel] return v, ok } -// GetNodes returns the nodes map. func (wg *WeightedAuthorizationModelGraph) GetNodes() map[string]*WeightedAuthorizationModelNode { return wg.nodes } -// GetNodes returns the nodes map. -func (wg *WeightedAuthorizationModelGraph) GetNodeByID(uniqueLabel string) (*WeightedAuthorizationModelNode, bool) { +func (wg *WeightedAuthorizationModelGraph) GetNodeByLabel(uniqueLabel string) (*WeightedAuthorizationModelNode, bool) { v, ok := wg.nodes[uniqueLabel] return v, ok } -// NewWeightedAuthorizationModelGraph creates a new WeightedAuthorizationModelGraph. +func (wg *WeightedAuthorizationModelGraph) GetNodeByID(id int64) (*WeightedAuthorizationModelNode, error) { + node := wg.Node(id) + + casted, ok := node.(*WeightedAuthorizationModelNode) + if !ok { + return nil, fmt.Errorf("%w: could not cast to AuthorizationModelNode", ErrQueryingGraph) + } + + return casted, nil +} + func NewWeightedAuthorizationModelGraph() *WeightedAuthorizationModelGraph { return &WeightedAuthorizationModelGraph{ - nodes: make(map[string]*WeightedAuthorizationModelNode), - edges: make(map[string][]*WeightedAuthorizationModelEdge), + DirectedGraph: multi.NewDirectedGraph(), + nodes: make(map[string]*WeightedAuthorizationModelNode), + edges: make(map[string][]*WeightedAuthorizationModelEdge), + drawingDirection: DrawingDirectionCheck, } } // AddNode adds a node to the graph with optional operationType and weight. func (wg *WeightedAuthorizationModelGraph) AddNode(uniqueLabel, label string, nodeType NodeType) { - wildcards := make([]string, 0) - if nodeType == SpecificTypeWildcard { - wildcards = append(wildcards, uniqueLabel[:len(uniqueLabel)-2]) - } - wg.nodes[uniqueLabel] = &WeightedAuthorizationModelNode{uniqueLabel: uniqueLabel, label: label, nodeType: nodeType, wildcards: wildcards} + innerNode := &AuthorizationModelNode{ + Node: wg.NewNode(), + label: label, + uniqueLabel: uniqueLabel, + nodeType: nodeType, + } + newNode := NewWeightedAuthorizationModelNode(innerNode) + wg.DirectedGraph.AddNode(newNode) + wg.nodes[uniqueLabel] = newNode } func (wg *WeightedAuthorizationModelGraph) AddEdge(fromID, toID string, edgeType EdgeType, condition string) { - wildcards := make([]string, 0) - fromNode := wg.nodes[fromID] - toNode := wg.nodes[toID] - edge := &WeightedAuthorizationModelEdge{from: fromNode, to: toNode, edgeType: edgeType, conditionedOn: condition, wildcards: wildcards} - wg.edges[fromID] = append(wg.edges[fromID], edge) + fromNode, _ := wg.GetNodeByLabel(fromID) + toNode, _ := wg.GetNodeByLabel(toID) + innerEdge := &AuthorizationModelEdge{ + edgeType: edgeType, + conditionedOn: condition, + } + + innerEdge.Line = wg.NewLine(fromNode, toNode) + newWeightedEdge := NewWeightedAuthorizationModelEdge(innerEdge) + wg.DirectedGraph.SetLine(newWeightedEdge) + wg.edges[fromID] = append(wg.edges[fromID], newWeightedEdge) } // AssignWeights assigns weights to all the edges and nodes of the graph. @@ -93,7 +121,8 @@ func (wg *WeightedAuthorizationModelGraph) calculateEdgeWildcards(edge *Weighted if len(edge.wildcards) > 0 { return } - nodeWildcards := wg.nodes[edge.to.uniqueLabel].wildcards + toNode, _ := wg.GetNodeByID(edge.AuthorizationModelEdge.To().ID()) + nodeWildcards := wg.nodes[toNode.uniqueLabel].wildcards if len(nodeWildcards) == 0 { return } @@ -151,10 +180,11 @@ func (wg *WeightedAuthorizationModelGraph) calculateNodeWeight(nodeID string, vi if len(edge.weights) != 0 { continue } - if edge.to.nodeType == SpecificType || edge.to.nodeType == SpecificTypeWildcard { + toNode, _ := wg.GetNodeByID(edge.AuthorizationModelEdge.To().ID()) + if toNode.nodeType == SpecificType || toNode.nodeType == SpecificTypeWildcard { edge.weights = make(map[string]int) - uniqueLabel := edge.to.uniqueLabel - if edge.to.nodeType == SpecificTypeWildcard { + uniqueLabel := toNode.uniqueLabel + if toNode.nodeType == SpecificTypeWildcard { uniqueLabel = uniqueLabel[:len(uniqueLabel)-2] edge.wildcards = append(edge.wildcards, uniqueLabel) wg.addEdgeWildcardsToNode(nodeID, edge) @@ -180,28 +210,30 @@ func (wg *WeightedAuthorizationModelGraph) calculateNodeWeight(nodeID string, vi // Calculate the weight of the edge based on the type of edge and the weight of the node that is connected to. func (wg *WeightedAuthorizationModelGraph) calculateEdgeWeight(edge *WeightedAuthorizationModelEdge, ancestorPath []*WeightedAuthorizationModelEdge, visited map[string]bool, tupleCycleDependencies map[string][]*WeightedAuthorizationModelEdge) ([]string, error) { + fromNode, _ := wg.GetNodeByID(edge.AuthorizationModelEdge.From().ID()) + toNode, _ := wg.GetNodeByID(edge.AuthorizationModelEdge.To().ID()) // if it is a recursive edge, we need to set the weight to infinite and add the edge to the tuple cycle dependencies - if edge.from.uniqueLabel == edge.to.uniqueLabel { + if fromNode.uniqueLabel == toNode.uniqueLabel { edge.weights = make(map[string]int) - edge.weights["R#"+edge.to.uniqueLabel] = Infinite - tupleCycleDependencies[edge.to.uniqueLabel] = append(tupleCycleDependencies[edge.to.uniqueLabel], edge) - return []string{edge.from.uniqueLabel}, nil + edge.weights["R#"+toNode.uniqueLabel] = Infinite + tupleCycleDependencies[toNode.uniqueLabel] = append(tupleCycleDependencies[toNode.uniqueLabel], edge) + return []string{fromNode.uniqueLabel}, nil } // calculate the weight of the node that is connected to the edge ancestorPath = append(ancestorPath, edge) - tupleCycle, err := wg.calculateNodeWeight(edge.to.uniqueLabel, visited, ancestorPath, tupleCycleDependencies) + tupleCycle, err := wg.calculateNodeWeight(toNode.uniqueLabel, visited, ancestorPath, tupleCycleDependencies) if err != nil { return tupleCycle, err } // if the node that is connected to the edge does not have any weight, we need to check if is a tuple cycle or a model cycle - if len(edge.to.weights) == 0 { - if wg.isTupleCycle(edge.to.uniqueLabel, ancestorPath) { + if len(toNode.weights) == 0 { + if wg.isTupleCycle(toNode.uniqueLabel, ancestorPath) { edge.weights = make(map[string]int) - edge.weights["R#"+edge.to.uniqueLabel] = Infinite - tupleCycleDependencies[edge.to.uniqueLabel] = append(tupleCycleDependencies[edge.to.uniqueLabel], edge) - tupleCycle = append(tupleCycle, edge.to.uniqueLabel) + edge.weights["R#"+toNode.uniqueLabel] = Infinite + tupleCycleDependencies[toNode.uniqueLabel] = append(tupleCycleDependencies[toNode.uniqueLabel], edge) + tupleCycle = append(tupleCycle, toNode.uniqueLabel) return tupleCycle, nil } return tupleCycle, ErrModelCycle @@ -215,7 +247,7 @@ func (wg *WeightedAuthorizationModelGraph) calculateEdgeWeight(edge *WeightedAut } weights := make(map[string]int) - for key, value := range edge.to.weights { + for key, value := range toNode.weights { if !isTupleCycle && strings.HasPrefix(key, "R#") { nodeDependency := strings.TrimPrefix(key, "R#") tupleCycleDependencies[nodeDependency] = append(tupleCycleDependencies[nodeDependency], edge) @@ -242,11 +274,13 @@ func (wg *WeightedAuthorizationModelGraph) calculateEdgeWeight(edge *WeightedAut func (wg *WeightedAuthorizationModelGraph) isTupleCycle(nodeID string, ancestorPath []*WeightedAuthorizationModelEdge) bool { startTracking := false for _, edge := range ancestorPath { - if !startTracking && edge.from.uniqueLabel == nodeID { + fromNode, _ := wg.GetNodeByID(edge.AuthorizationModelEdge.From().ID()) + toNode, _ := wg.GetNodeByID(edge.AuthorizationModelEdge.To().ID()) + if !startTracking && fromNode.uniqueLabel == nodeID { startTracking = true } if startTracking { - if edge.edgeType == TTUEdge || (edge.edgeType == DirectEdge && edge.to.nodeType == SpecificTypeAndRelation) { + if edge.edgeType == TTUEdge || (edge.edgeType == DirectEdge && toNode.nodeType == SpecificTypeAndRelation) { return true } } @@ -479,7 +513,7 @@ func (wg *WeightedAuthorizationModelGraph) fixDependantNodesWeight(nodeCycle str node := wg.nodes[nodeCycle] for _, edge := range tupleCycleDependencies[nodeCycle] { - fromNode := wg.nodes[edge.from.uniqueLabel] + fromNode, _ := wg.GetNodeByID(edge.AuthorizationModelEdge.From().ID()) nodeWeights := make(map[string]int) for key1, value1 := range fromNode.weights { if key1 == referenceNodeID { @@ -499,7 +533,7 @@ func (wg *WeightedAuthorizationModelGraph) fixDependantNodesWeight(nodeCycle str } } fromNode.weights = nodeWeights - wg.addReferentialWildcardsToNode(edge.from.uniqueLabel, nodeCycle) + wg.addReferentialWildcardsToNode(fromNode.uniqueLabel, nodeCycle) } } @@ -513,3 +547,31 @@ func (wg *WeightedAuthorizationModelGraph) removeNodeFromTupleCycles(nodeID stri } return result } + +var _ dot.Attributers = (*WeightedAuthorizationModelGraph)(nil) + +func (wg *WeightedAuthorizationModelGraph) DOTAttributers() (encoding.Attributer, encoding.Attributer, encoding.Attributer) { + return wg, nil, nil +} + +func (wg *WeightedAuthorizationModelGraph) Attributes() []encoding.Attribute { + rankdir := "BT" // bottom to top + if wg.drawingDirection == DrawingDirectionCheck { + rankdir = "TB" // top to bottom + } + + return []encoding.Attribute{{ + Key: "rankdir", // https://graphviz.org/docs/attrs/rankdir/ + Value: rankdir, + }} +} + +// GetDOT returns the DOT visualization. It should only be used for debugging. +func (wg *WeightedAuthorizationModelGraph) GetDOT() string { + dotRepresentation, err := dot.MarshalMulti(wg, "", "", "") + if err != nil { + return "" + } + + return string(dotRepresentation) +} diff --git a/pkg/go/graph/weighted_graph_builder_test.go b/pkg/go/graph/weighted_graph_builder_test.go index 2f1f5b36..e8861ddf 100644 --- a/pkg/go/graph/weighted_graph_builder_test.go +++ b/pkg/go/graph/weighted_graph_builder_test.go @@ -484,3 +484,97 @@ func TestValidGraphModel(t *testing.T) { require.NoError(t, err) require.Equal(t, 3, graph.nodes["job#can_read"].weights["user"]) } + +// TODO make output from DOT stable +//func TestWeightedGraphBuilderDOT(t *testing.T) { +// t.Parallel() +// +// testCases := map[string]struct { +// model string +// expectedOutput string // can visualize in https://dreampuf.github.io/GraphvizOnline +// expectedError error +// }{ +// `multigraph`: { +// model: ` +// model +// schema 1.1 +// type user +// type state +// relations +// define can_view: [user] or member +// define member: [user] +// type transition +// relations +// define start: [state] +// define end: [state] +// define can_apply: [user] and can_view from start and can_view from end +// type group +// relations +// define owner: [user, transition#can_apply] +// define max_owner: [group#owner, group#max_owner]`, +// expectedOutput: `digraph { +//graph [ +//rankdir=TB +//]; +// +//// Node definitions. +//0 [label=group]; +//1 [label="group#max_owner - weights:[user=+∞]"]; +//2 [label="group#owner - weights:[user=3]"]; +//3 [label=user]; +//4 [label="transition#can_apply - weights:[user=2]"]; +//5 [label=state]; +//6 [label="state#can_view - weights:[user=1]"]; +//7 [label="union - weights:[user=1]"]; +//8 [label="state#member - weights:[user=1]"]; +//9 [label=transition]; +//10 [label="intersection - weights:[user=2]"]; +//11 [label="transition#end - weights:[state=1]"]; +//12 [label="transition#start - weights:[state=1]"]; +// +//// Edge definitions. +//1 -> 1 [label="direct - weights:[user=+∞]"]; +//1 -> 2 [label="direct - weights:[user=4]"]; +//2 -> 3 [label="direct - weights:[user=1]"]; +//2 -> 4 [label="direct - weights:[user=3]"]; +//4 -> 10 [label="weights:[user=2]"]; +//6 -> 7 [label="weights:[user=1]"]; +//7 -> 3 [label="direct - weights:[user=1]"]; +//7 -> 8 [label="weights:[user=1]"]; +//8 -> 3 [label="direct - weights:[user=1]"]; +//10 -> 3 [label="direct - weights:[user=1]"]; +//10 -> 6 [ +//headlabel="(transition#start)" +//label="weights:[user=2]" +//]; +//10 -> 6 [ +//headlabel="(transition#end)" +//label="weights:[user=2]" +//]; +//11 -> 5 [label="direct - weights:[state=1]"]; +//12 -> 5 [label="direct - weights:[state=1]"]; +//}`, +// }, +// } +// for name, testCase := range testCases { +// t.Run(name, func(t *testing.T) { +// t.Parallel() +// +// model := language.MustTransformDSLToProto(testCase.model) +// weightedGraph, err := NewWeightedAuthorizationModelGraphBuilder().Build(model) +// if testCase.expectedError != nil { +// require.ErrorIs(t, err, testCase.expectedError) +// } else { +// require.NoError(t, err) +// +// actualDOT := weightedGraph.GetDOT() +// actualSorted := getSorted(actualDOT) +// expectedSorted := getSorted(testCase.expectedOutput) +// +// diff := cmp.Diff(expectedSorted, actualSorted) +// +// require.Empty(t, diff, "expected %s\ngot\n%s", testCase.expectedOutput, actualDOT) +// } +// }) +// } +//} diff --git a/pkg/go/graph/weighted_graph_edge.go b/pkg/go/graph/weighted_graph_edge.go index 065601e6..e34ccac6 100644 --- a/pkg/go/graph/weighted_graph_edge.go +++ b/pkg/go/graph/weighted_graph_edge.go @@ -1,12 +1,49 @@ package graph +import ( + "fmt" + + "gonum.org/v1/gonum/graph/encoding" +) + type WeightedAuthorizationModelEdge struct { - weights map[string]int - edgeType EdgeType - conditionedOn string - from *WeightedAuthorizationModelNode - to *WeightedAuthorizationModelNode - wildcards []string + *AuthorizationModelEdge + weights WeightMap + wildcards []string +} + +func NewWeightedAuthorizationModelEdge(edge *AuthorizationModelEdge) *WeightedAuthorizationModelEdge { + return &WeightedAuthorizationModelEdge{ + AuthorizationModelEdge: edge, + weights: make(WeightMap), + wildcards: make([]string, 0), + } +} + +var _ encoding.Attributer = (*WeightedAuthorizationModelEdge)(nil) + +func (edge *WeightedAuthorizationModelEdge) Attributes() []encoding.Attribute { + weightsStr := edge.weights + labelSet := false + attrs := make([]encoding.Attribute, 0, len(edge.AuthorizationModelEdge.Attributes())) + for _, attr := range edge.AuthorizationModelEdge.Attributes() { + if attr.Key == "label" { + labelSet = true + if len(weightsStr) > 0 { + attr.Value += fmt.Sprintf(" - %v", weightsStr) + } + } + attrs = append(attrs, attr) + } + + if !labelSet { + attrs = append(attrs, encoding.Attribute{ + Key: "label", + Value: fmt.Sprintf("%v", weightsStr), + }) + } + + return attrs } // GetWeights returns the entire weights map. @@ -32,12 +69,14 @@ func (edge *WeightedAuthorizationModelEdge) GetConditionedOn() string { // GetFrom returns the from node. func (edge *WeightedAuthorizationModelEdge) GetFrom() *WeightedAuthorizationModelNode { - return edge.from + from, _ := edge.AuthorizationModelEdge.From().(*WeightedAuthorizationModelNode) + return from } // GetTo returns the to node. func (edge *WeightedAuthorizationModelEdge) GetTo() *WeightedAuthorizationModelNode { - return edge.to + to, _ := edge.AuthorizationModelEdge.To().(*WeightedAuthorizationModelNode) + return to } // GetWildcards returns the wildcards. diff --git a/pkg/go/graph/weighted_graph_node.go b/pkg/go/graph/weighted_graph_node.go index 8df4d436..cb44e338 100644 --- a/pkg/go/graph/weighted_graph_node.go +++ b/pkg/go/graph/weighted_graph_node.go @@ -1,11 +1,55 @@ package graph +import ( + "fmt" + + "gonum.org/v1/gonum/graph/encoding" +) + type WeightedAuthorizationModelNode struct { - weights map[string]int - nodeType NodeType - label string // e.g. "group#member", UnionOperator, IntersectionOperator, ExclusionOperator - uniqueLabel string - wildcards []string + *AuthorizationModelNode + weights WeightMap + wildcards []string +} + +func NewWeightedAuthorizationModelNode(node *AuthorizationModelNode) *WeightedAuthorizationModelNode { + n := &WeightedAuthorizationModelNode{ + AuthorizationModelNode: node, + weights: make(WeightMap), + wildcards: make([]string, 0), + } + + if node.nodeType == SpecificTypeWildcard { + n.wildcards = append(n.wildcards, node.uniqueLabel[:len(node.uniqueLabel)-2]) + } + + return n +} + +var _ encoding.Attributer = (*WeightedAuthorizationModelNode)(nil) + +func (node *WeightedAuthorizationModelNode) Attributes() []encoding.Attribute { + weightsStr := node.weights + labelSet := false + attrs := make([]encoding.Attribute, 0, len(node.AuthorizationModelNode.Attributes())) + for _, attr := range node.AuthorizationModelNode.Attributes() { + if attr.Key == "label" { + labelSet = true + if len(weightsStr) > 0 { + attr.Value += fmt.Sprintf(" - %v", weightsStr) + } + } + attrs = append(attrs, attr) + } + + if !labelSet { + attrs = append(attrs, encoding.Attribute{ + Key: "label", + Value: fmt.Sprintf("%v", weightsStr), + }) + } + + return attrs } // GetWeights returns the entire weights map. diff --git a/pkg/go/graph/weights.go b/pkg/go/graph/weights.go new file mode 100644 index 00000000..f96b294e --- /dev/null +++ b/pkg/go/graph/weights.go @@ -0,0 +1,36 @@ +package graph + +import ( + "fmt" + "sort" + "strconv" + "strings" +) + +// WeightMap is a map of where the key is a type (e.g. folder, user) and the value is the weight/complexity to reach that type. +type WeightMap map[string]int + +func (wt WeightMap) String() string { + var sb strings.Builder + + // Extract keys and sort them + keys := make([]string, 0, len(wt)) + for k := range wt { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, k := range keys { + formatV := strconv.Itoa(wt[k]) + if wt[k] == Infinite { + formatV = "+∞" + } + sb.WriteString(fmt.Sprintf("%v=%s,", k, formatV)) + } + formattedWeights := sb.String() + if len(formattedWeights) > 0 { + formattedWeights = formattedWeights[:len(formattedWeights)-1] + } + + return fmt.Sprintf("weights:[%v]", formattedWeights) +} diff --git a/pkg/go/graph/weights_test.go b/pkg/go/graph/weights_test.go new file mode 100644 index 00000000..f9478800 --- /dev/null +++ b/pkg/go/graph/weights_test.go @@ -0,0 +1,18 @@ +package graph + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWeightsMapToString(t *testing.T) { + t.Parallel() + wm := WeightMap{"user": 1, "group": 2} + + require.Equal(t, "weights:[group=2,user=1]", wm.String()) + + wm = WeightMap{"user": 1, "group": Infinite} + + require.Equal(t, "weights:[group=+∞,user=1]", wm.String()) +}