Skip to content

Commit ef76d4d

Browse files
authored
Implement vector switch operator (#5538)
1 parent 1a79b43 commit ef76d4d

13 files changed

+232
-5
lines changed

compiler/kernel/vop.go

+58-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"errors"
55
"fmt"
66

7+
"github.com/brimdata/super"
78
"github.com/brimdata/super/compiler/dag"
89
"github.com/brimdata/super/compiler/optimizer"
910
"github.com/brimdata/super/pkg/field"
@@ -39,10 +40,10 @@ func (b *Builder) compileVam(o dag.Op, parents []vector.Puller) ([]vector.Puller
3940
case *dag.Scope:
4041
//return b.compileVecScope(o, parents)
4142
case *dag.Switch:
42-
//if o.Expr != nil {
43-
// return b.compileVamExprSwitch(o, parents)
44-
//}
45-
//return b.compileVecSwitch(o, parents)
43+
if o.Expr != nil {
44+
return b.compileVamExprSwitch(o, parents)
45+
}
46+
return b.compileVamSwitch(o, parents)
4647
default:
4748
var parent vector.Puller
4849
if len(parents) == 1 {
@@ -114,6 +115,59 @@ func (b *Builder) compileVamScatter(scatter *dag.Scatter, parents []vector.Pulle
114115
return ops, nil
115116
}
116117

118+
func (b *Builder) compileVamExprSwitch(swtch *dag.Switch, parents []vector.Puller) ([]vector.Puller, error) {
119+
parent := parents[0]
120+
if len(parents) > 1 {
121+
parent = vamop.NewCombine(b.rctx, parents)
122+
}
123+
e, err := b.compileVamExpr(swtch.Expr)
124+
if err != nil {
125+
return nil, err
126+
}
127+
s := vamop.NewExprSwitch(b.rctx, parent, e)
128+
var exits []vector.Puller
129+
for _, c := range swtch.Cases {
130+
var val *super.Value
131+
if c.Expr != nil {
132+
val2, err := b.evalAtCompileTime(c.Expr)
133+
if err != nil {
134+
return nil, err
135+
}
136+
if val2.IsError() {
137+
return nil, errors.New("switch case is not a constant expression")
138+
}
139+
val = &val2
140+
}
141+
parents, err := b.compileVamSeq(c.Path, []vector.Puller{s.AddCase(val)})
142+
if err != nil {
143+
return nil, err
144+
}
145+
exits = append(exits, parents...)
146+
}
147+
return exits, nil
148+
}
149+
150+
func (b *Builder) compileVamSwitch(swtch *dag.Switch, parents []vector.Puller) ([]vector.Puller, error) {
151+
parent := parents[0]
152+
if len(parents) > 1 {
153+
parent = vamop.NewCombine(b.rctx, parents)
154+
}
155+
s := vamop.NewSwitch(b.rctx, parent)
156+
var exits []vector.Puller
157+
for _, c := range swtch.Cases {
158+
e, err := b.compileVamExpr(c.Expr)
159+
if err != nil {
160+
return nil, fmt.Errorf("compiling switch case filter: %w", err)
161+
}
162+
exit, err := b.compileVamSeq(c.Path, []vector.Puller{s.AddCase(e)})
163+
if err != nil {
164+
return nil, err
165+
}
166+
exits = append(exits, exit...)
167+
}
168+
return exits, nil
169+
}
170+
117171
func (b *Builder) compileVamLeaf(o dag.Op, parent vector.Puller) (vector.Puller, error) {
118172
switch o := o.(type) {
119173
case *dag.Cut:

runtime/vam/op/exprswitch.go

+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package op
2+
3+
import (
4+
"context"
5+
6+
"github.com/brimdata/super"
7+
"github.com/brimdata/super/runtime/vam/expr"
8+
"github.com/brimdata/super/vector"
9+
"github.com/brimdata/super/zcode"
10+
)
11+
12+
type ExprSwitch struct {
13+
expr expr.Evaluator
14+
router *router
15+
16+
builder zcode.Builder
17+
cases map[string]*route
18+
caseIndexes map[*route][]uint32
19+
defaultRoute *route
20+
}
21+
22+
func NewExprSwitch(ctx context.Context, parent vector.Puller, e expr.Evaluator) *ExprSwitch {
23+
s := &ExprSwitch{expr: e, cases: map[string]*route{}, caseIndexes: map[*route][]uint32{}}
24+
s.router = newRouter(ctx, s, parent)
25+
return s
26+
}
27+
28+
func (s *ExprSwitch) AddCase(val *super.Value) vector.Puller {
29+
r := s.router.addRoute()
30+
if val == nil {
31+
s.defaultRoute = r
32+
} else {
33+
s.cases[string(val.Bytes())] = r
34+
}
35+
return r
36+
}
37+
38+
func (s *ExprSwitch) forward(vec vector.Any) bool {
39+
defer clear(s.caseIndexes)
40+
exprVec := s.expr.Eval(vec)
41+
for i := range exprVec.Len() {
42+
s.builder.Truncate()
43+
exprVec.Serialize(&s.builder, i)
44+
route, ok := s.cases[string(s.builder.Bytes().Body())]
45+
if !ok {
46+
route = s.defaultRoute
47+
}
48+
if route != nil {
49+
s.caseIndexes[route] = append(s.caseIndexes[route], i)
50+
}
51+
}
52+
for route, index := range s.caseIndexes {
53+
view := vector.NewView(vec, index)
54+
if !route.send(view, nil) {
55+
return false
56+
}
57+
}
58+
return true
59+
}

runtime/vam/op/router.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func newRouter(ctx context.Context, f forwarder, parent vector.Puller) *router {
2525
return &router{ctx: ctx, forwarder: f, parent: parent}
2626
}
2727

28-
func (r *router) addRoute() vector.Puller {
28+
func (r *router) addRoute() *route {
2929
route := &route{r, make(chan result), make(chan struct{}), false}
3030
r.routes = append(r.routes, route)
3131
return route

runtime/vam/op/swtich.go

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package op
2+
3+
import (
4+
"context"
5+
6+
"github.com/RoaringBitmap/roaring"
7+
"github.com/brimdata/super"
8+
"github.com/brimdata/super/runtime/vam/expr"
9+
"github.com/brimdata/super/vector"
10+
)
11+
12+
type Switch struct {
13+
router *router
14+
cases []expr.Evaluator
15+
}
16+
17+
func NewSwitch(ctx context.Context, parent vector.Puller) *Switch {
18+
s := &Switch{}
19+
s.router = newRouter(ctx, s, parent)
20+
return s
21+
}
22+
23+
func (s *Switch) AddCase(e expr.Evaluator) vector.Puller {
24+
s.cases = append(s.cases, e)
25+
return s.router.addRoute()
26+
}
27+
28+
func (s *Switch) forward(vec vector.Any) bool {
29+
doneMap := roaring.New()
30+
for i, c := range s.cases {
31+
maskVec := c.Eval(vec)
32+
boolMap, errMap := expr.BoolMask(maskVec)
33+
boolMap.AndNot(doneMap)
34+
errMap.AndNot(doneMap)
35+
doneMap.Or(boolMap)
36+
if !errMap.IsEmpty() {
37+
// Clone because iteration results are undefined if the bitmap is modified.
38+
for it := errMap.Clone().Iterator(); it.HasNext(); {
39+
i := it.Next()
40+
if isErrorMissing(maskVec, i) {
41+
errMap.Remove(i)
42+
}
43+
}
44+
}
45+
var vec2 vector.Any
46+
if errMap.IsEmpty() {
47+
if boolMap.IsEmpty() {
48+
continue
49+
}
50+
vec2 = vector.NewView(vec, boolMap.ToArray())
51+
} else if boolMap.IsEmpty() {
52+
vec2 = vector.NewView(maskVec, errMap.ToArray())
53+
} else {
54+
valIndex := boolMap.ToArray()
55+
errIndex := errMap.ToArray()
56+
tags := make([]uint32, 0, len(valIndex)+len(errIndex))
57+
for len(valIndex) > 0 && len(errIndex) > 0 {
58+
if valIndex[0] < errIndex[0] {
59+
valIndex = valIndex[1:]
60+
tags = append(tags, 0)
61+
} else {
62+
errIndex = errIndex[1:]
63+
tags = append(tags, 1)
64+
}
65+
}
66+
tags = append(tags, valIndex...)
67+
tags = append(tags, errIndex...)
68+
valVec := vector.NewView(vec, valIndex)
69+
errVec := vector.NewView(maskVec, errIndex)
70+
vec2 = vector.NewDynamic(tags, []vector.Any{valVec, errVec})
71+
}
72+
if !s.router.routes[i].send(vec2, nil) {
73+
return false
74+
}
75+
}
76+
return true
77+
}
78+
79+
func isErrorMissing(vec vector.Any, i uint32) bool {
80+
vec = vector.Under(vec)
81+
if dynVec, ok := vec.(*vector.Dynamic); ok {
82+
vec = dynVec.Values[dynVec.Tags[i]]
83+
i = dynVec.TagMap.Forward[i]
84+
}
85+
errVec, ok := vec.(*vector.Error)
86+
if !ok {
87+
return false
88+
}
89+
if errVec.Vals.Type().ID() != super.IDString {
90+
return false
91+
}
92+
s, null := vector.StringValue(errVec.Vals, i)
93+
return !null && s == string(super.Missing)
94+
}

runtime/sam/op/switcher/ztests/switch-chained.yaml runtime/ztests/op/switch-chained.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ zed: |
1212
case this==3 => yield 4
1313
)
1414
15+
vector: true
16+
1517
input: |
1618
1
1719

runtime/sam/op/switcher/ztests/switch-default.yaml runtime/ztests/op/switch-default.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ zed: |
66
default => count:=count() |> put a:=-1
77
) |> sort a
88
9+
vector: true
10+
911
input: |
1012
{a:1,s:"a"}
1113
{a:2,s:"B"}

runtime/sam/op/switcher/ztests/switch-error.yaml runtime/ztests/op/switch-error.yaml

+4
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@ zed: |
22
switch (
33
case a == 1 => put v:='one'
44
case a / 0 => put v:='xxx'
5+
case a % 0 => put v:='yyy'
56
) |> sort this
67
8+
vector: true
9+
710
input: |
811
{a:1,s:"a"}
912
{a:2,s:"b"}
1013
1114
output: |
1215
{a:1,s:"a",v:"one"}
1316
error("divide by zero")
17+
error("divide by zero")

runtime/sam/op/exprswitch/ztests/switch-default.yaml runtime/ztests/op/switch-expr-default.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ zed: |
66
default => count:=count() |> put a:=-1
77
) |> sort a
88
9+
vector: true
10+
911
input: |
1012
{a:1,s:"a"}
1113
{a:2,s:"B"}

runtime/sam/op/exprswitch/ztests/switch-done.yaml runtime/ztests/op/switch-expr-done.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ zed: |
55
default => pass
66
) |> sort b
77
8+
vector: true
9+
810
input: |
911
{a:1,b:1}
1012
{a:2,b:2}

runtime/sam/op/exprswitch/ztests/switch-over.yaml runtime/ztests/op/switch-expr-over.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ zed: |
44
default => over a |> yield {b:this}
55
) |> sort this
66
7+
vector: true
8+
79
input: |
810
{a:[1,2,3]}
911
{a:[6,7,8,9]}

runtime/sam/op/exprswitch/ztests/switch.yaml runtime/ztests/op/switch-expr.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ zed: |
55
case 3 => ? null
66
) |> sort a
77
8+
vector: true
9+
810
input: |
911
{a:1(int32),s:"a"}
1012
{a:2(int32),s:"B"}

runtime/sam/op/switcher/ztests/switch-over.yaml runtime/ztests/op/switch-over.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ zed: |
44
default => over a |> yield {b:this}
55
) |> sort this
66
7+
vector: true
8+
79
input: |
810
{a:[1,2,3]}
911
{a:[6,7,8,9]}

runtime/sam/op/switcher/ztests/switch.yaml runtime/ztests/op/switch.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ zed: |
66
case true => count:=count() |> put a:=-1
77
) |> sort a
88
9+
vector: true
10+
911
input: |
1012
{a:1(int32),s:"a"}
1113
{a:2(int32),s:"B"}

0 commit comments

Comments
 (0)