-
Notifications
You must be signed in to change notification settings - Fork 0
/
scaff.go
171 lines (156 loc) · 3.76 KB
/
scaff.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
package goraff
import (
"fmt"
"sync"
)
// Scaff represents a blueprint of blocks.
// When it runs, it creates a graph of data.
type Scaff struct {
// entrypoint is the block execution starts from. It is assigned by
// SetEntrypoint and must be non-nil for validate to succeed.
entrypoint *Block
// joins is created lazily on the first call to Joins().
joins *Joins
// blocks is created lazily on the first call to Blocks().
blocks *Blocks
}
// NewScaff returns an empty Scaff. Its block and join collections are
// allocated on first access, and no entrypoint is set.
func NewScaff() *Scaff {
	return new(Scaff)
}
// Blocks returns the scaff's block collection, allocating it the first
// time it is requested. Subsequent calls return the same pointer.
func (g *Scaff) Blocks() *Blocks {
	if g.blocks != nil {
		return g.blocks
	}
	g.blocks = &Blocks{}
	return g.blocks
}
// Joins returns the scaff's join collection, allocating it the first time
// it is requested. The new collection is wired to the scaff's block
// collection (obtained through Blocks, so that is allocated too if needed).
func (g *Scaff) Joins() *Joins {
	if g.joins != nil {
		return g.joins
	}
	created := &Joins{Blocks: g.Blocks()}
	g.joins = created
	return created
}
// SetEntrypoint selects the block that execution starts from by looking
// name up in the scaff's block collection.
//
// The lookup goes through Blocks() rather than the raw field so that it is
// safe to call on a Scaff whose block collection has not been allocated yet
// (for example straight after NewScaff); the raw field is nil in that case.
//
// Whatever Get returns is stored as-is. If it returns nil (presumably when
// no block has that name -- confirm against Blocks.Get), the entrypoint
// stays unset and validate reports "entrypoint not set".
func (g *Scaff) SetEntrypoint(name string) {
	g.entrypoint = g.Blocks().Get(name)
}
// Go validates the scaff and then executes it against graph.
//
// It returns an error if graph is nil, if validation fails (the validation
// error is wrapped), or if flowMgr reports an error.
func (g *Scaff) Go(graph *Graph) error {
	if graph == nil {
		return fmt.Errorf("graph not provided")
	}
	if verr := g.validate(); verr != nil {
		return fmt.Errorf("error validating graph: %w", verr)
	}
	return g.flowMgr(graph)
}
// validate checks that the scaff is runnable: an entrypoint must be set,
// and both the block collection and the join collection must pass their
// own Validate. Checks run in that order and the first failure is returned,
// wrapped with which stage failed.
func (g *Scaff) validate() error {
	if g.entrypoint == nil {
		return fmt.Errorf("entrypoint not set")
	}

	// Blocks are checked before joins.
	if berr := g.Blocks().Validate(); berr != nil {
		return fmt.Errorf("error validating blocks: %w", berr)
	}

	if jerr := g.Joins().Validate(); jerr != nil {
		return fmt.Errorf("error validating joins: %w", jerr)
	}

	return nil
}
// nextJoin is the work item passed over flowMgr's channel: a join to
// consider, together with the node whose completion queued it.
type nextJoin struct {
// Join is the join to evaluate. flowMgr sends a nil Join when a finished
// block has no outgoing joins, so that its node is still marked done.
Join *Join
// previousNode is the node produced by the block that queued this item.
// It is nil for the initial entrypoint item.
previousNode *Node
}
// flowMgr executes the scaff against graph, starting at the entrypoint.
//
// A single dispatcher goroutine drains a channel of nextJoin items. For
// each item it marks the previous node done, evaluates the join's triggers
// and, if they are met, launches a worker goroutine that runs the target
// block and queues that block's outgoing joins. A WaitGroup counts items
// that are queued or in flight; flowMgr returns once the count reaches zero.
//
// It returns "entrypoint not set" if no entrypoint is configured, otherwise
// the most recently recorded block error (wrapped), or nil. After a block
// fails, blocks that have not started yet are skipped while running ones
// are allowed to finish.
//
// NOTE(review): an error from TriggersMet is only printed and the item is
// dropped; it is not reported to the caller. Confirm this is intended.
func (g *Scaff) flowMgr(graph *Graph) error {
	if g.entrypoint == nil {
		return fmt.Errorf("entrypoint not set")
	}

	completedCh := make(chan nextJoin, 10)
	var wg sync.WaitGroup

	// Seed the queue with a synthetic join into the entrypoint. The channel
	// is buffered, so this send completes before the dispatcher starts.
	completedCh <- nextJoin{
		Join:         &Join{From: nil, To: g.entrypoint},
		previousNode: nil,
	}
	wg.Add(1) // Increment for the initial node
	fmt.Println("starting block", g.entrypoint.Name)

	// foundErr is shared between worker goroutines; every access goes
	// through mut so that reads do not race with writes.
	var (
		mut      sync.Mutex
		foundErr error
	)
	hasFailed := func() bool {
		mut.Lock()
		defer mut.Unlock()
		return foundErr != nil
	}
	recordErr := func(err error) {
		mut.Lock()
		defer mut.Unlock()
		foundErr = err
	}

	go func() {
		for n := range completedCh {
			// Triggers are checked here, in the single dispatcher goroutine,
			// before launching a worker, to prevent join race conditions.
			if n.previousNode != nil {
				n.previousNode.MarkDone()
			}
			if n.Join == nil {
				// Terminal item: nothing to run, only the MarkDone above.
				wg.Done()
				continue
			}
			fmt.Println("considering block", n.Join.To.Name)
			r := NewReadableGraph(graph)
			t, err := n.Join.TriggersMet(r)
			if err != nil {
				fmt.Printf("error checking join condition: %s\n", err.Error())
				wg.Done()
				continue
			}
			if !t {
				fmt.Printf("join condition not met To: %s\n", n.Join.To.Name)
				wg.Done()
				continue
			}
			fmt.Printf("join condition met To: %s\n", n.Join.To.Name)

			// Run the target block in its own goroutine. The WaitGroup count
			// taken when the item was queued is released when it finishes.
			go func(n nextJoin) {
				defer wg.Done() // Ensure we mark this goroutine as done on finish
				block := n.Join.To
				defer fmt.Printf("finished block %s\n", n.Join.To.Name)
				fmt.Println("starting block", block.Name)

				// A previous block failed: do not start new work.
				if hasFailed() {
					return
				}

				var tr *ReadableNode
				if n.previousNode != nil {
					tr = n.previousNode.Get()
				}
				completedNode, err := g.runBlock(graph, block, tr)
				if err != nil {
					fmt.Printf("error running block %s, letting all active blocks drain: %s \n", block.Name, err.Error())
					recordErr(fmt.Errorf("error running block: %w", err))
					return
				}

				joins := g.Joins().Get(block.Name)
				for _, j := range joins {
					fmt.Println("queueing block join", j.To.Name)
					wg.Add(1) // Increment for each new block
					completedCh <- nextJoin{
						previousNode: completedNode,
						Join:         j,
					}
				}
				if len(joins) == 0 {
					// No outgoing joins: queue a terminal item so the
					// dispatcher still marks this node done.
					wg.Add(1)
					completedCh <- nextJoin{
						previousNode: completedNode,
						Join:         nil,
					}
				}
			}(n)
		}
	}()

	wg.Wait()          // Wait for all queued items and workers to finish
	close(completedCh) // Safe to close here as no more writes will happen; stops the dispatcher

	mut.Lock()
	defer mut.Unlock()
	return foundErr
}
// runBlock executes a single block against the graph.
//
// It creates a new node in g named after the block, builds a readable view
// of g, and invokes the block's action with the node, that view, and the
// node that triggered this block (triggeringNS, which may be nil).
//
// On success it returns the new node; if the action fails it returns a nil
// node and the action's error unchanged. The node is not marked done here;
// flowMgr calls MarkDone when it processes the next queued item.
func (s *Scaff) runBlock(g *Graph, b *Block, triggeringNS *ReadableNode) (*Node, error) {
	node := g.NewNode(b.Name, nil)
	view := NewReadableGraph(g)
	if actionErr := b.Action.Do(node, view, triggeringNS); actionErr != nil {
		return nil, actionErr
	}
	return node, nil
}