-
Notifications
You must be signed in to change notification settings - Fork 46
/
Copy pathvisitors.go
499 lines (429 loc) · 15.1 KB
/
visitors.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package iceberg
import (
"fmt"
"maps"
"math"
"slices"
"strings"
"github.com/google/uuid"
)
// BooleanExprVisitor is the set of callbacks invoked while recursively walking
// the nodes of a boolean expression. T is the result type produced for each
// node; the results computed for child nodes are handed to the callback of
// their parent node.
type BooleanExprVisitor[T any] interface {
// VisitTrue produces the result for an AlwaysTrue node.
VisitTrue() T
// VisitFalse produces the result for an AlwaysFalse node.
VisitFalse() T
// VisitNot receives the result already computed for the negated child.
VisitNot(childResult T) T
// VisitAnd receives the results already computed for the left and right operands.
VisitAnd(left, right T) T
// VisitOr receives the results already computed for the left and right operands.
VisitOr(left, right T) T
// VisitUnbound is called for every UnboundPredicate leaf.
VisitUnbound(UnboundPredicate) T
// VisitBound is called for every BoundPredicate leaf.
VisitBound(BoundPredicate) T
}
// BoundBooleanExprVisitor builds on BooleanExprVisitor by adding one method
// per bound predicate operation. VisitBoundPredicate dispatches a
// BoundPredicate to the matching method based on its operation.
//
// Because literals are cast during binding, you can assume that the BoundTerm
// and the Literal passed to a method have the same type.
type BoundBooleanExprVisitor[T any] interface {
BooleanExprVisitor[T]
// Set membership operations: the term plus the predicate's set of literals.
VisitIn(BoundTerm, Set[Literal]) T
VisitNotIn(BoundTerm, Set[Literal]) T
// Unary operations: only the term is supplied.
VisitIsNan(BoundTerm) T
VisitNotNan(BoundTerm) T
VisitIsNull(BoundTerm) T
VisitNotNull(BoundTerm) T
// Literal operations: the term plus the single literal it is compared with.
VisitEqual(BoundTerm, Literal) T
VisitNotEqual(BoundTerm, Literal) T
VisitGreaterEqual(BoundTerm, Literal) T
VisitGreater(BoundTerm, Literal) T
VisitLessEqual(BoundTerm, Literal) T
VisitLess(BoundTerm, Literal) T
VisitStartsWith(BoundTerm, Literal) T
VisitNotStartsWith(BoundTerm, Literal) T
}
// VisitExpr is a convenience function that walks every node of a boolean
// expression with the given visitor. Children are visited before their parent
// (left operand before right operand) and the values they produce are passed
// to the parent's visitor method, effectively "bubbling up" the results.
//
// Visitor methods report failures by panicking. Any panic raised during the
// walk is recovered and returned as err, in which case res is the zero value
// of T:
//   - an error value is returned unchanged,
//   - a string is wrapped in a new error,
//   - any other value is formatted with %v and wrapped in a new error, so that
//     no panic is ever silently discarded.
func VisitExpr[T any](expr BooleanExpression, visitor BooleanExprVisitor[T]) (res T, err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}

		switch e := r.(type) {
		case string:
			err = fmt.Errorf("error encountered during visitExpr: %s", e)
		case error:
			err = e
		default:
			// Previously such values were dropped, leaving err == nil
			// alongside a zero result.
			err = fmt.Errorf("error encountered during visitExpr: %v", e)
		}
	}()

	return visitBoolExpr(expr, visitor), err
}
// visitBoolExpr performs the recursive walk behind VisitExpr. Sub-expressions
// are visited first (left before right) and their results are handed to the
// visitor method for the enclosing node. It panics with ErrNotImplemented when
// the expression is of a kind it does not recognise.
func visitBoolExpr[T any](e BooleanExpression, visitor BooleanExprVisitor[T]) T {
	switch node := e.(type) {
	case AlwaysFalse:
		return visitor.VisitFalse()
	case AlwaysTrue:
		return visitor.VisitTrue()
	case AndExpr:
		lhs := visitBoolExpr(node.left, visitor)
		rhs := visitBoolExpr(node.right, visitor)

		return visitor.VisitAnd(lhs, rhs)
	case OrExpr:
		lhs := visitBoolExpr(node.left, visitor)
		rhs := visitBoolExpr(node.right, visitor)

		return visitor.VisitOr(lhs, rhs)
	case NotExpr:
		return visitor.VisitNot(visitBoolExpr(node.child, visitor))
	case UnboundPredicate:
		return visitor.VisitUnbound(node)
	case BoundPredicate:
		return visitor.VisitBound(node)
	default:
		panic(fmt.Errorf("%w: VisitBooleanExpression type %s", ErrNotImplemented, e))
	}
}
// VisitBoundPredicate dispatches a bound predicate to the visitor method that
// corresponds to the predicate's operation. It is intended as a one-line
// implementation of VisitBound for a BoundBooleanExprVisitor:
//
//	return iceberg.VisitBoundPredicate(pred, this)
//
// Set operations require the predicate to be a BoundSetPredicate and literal
// operations require it to be a BoundLiteralPredicate; the type assertion
// panics otherwise. An operation that is not handled panics with
// ErrNotImplemented.
func VisitBoundPredicate[T any](e BoundPredicate, visitor BoundBooleanExprVisitor[T]) T {
	switch op := e.Op(); op {
	case OpIn, OpNotIn:
		term := e.Term()
		members := e.(BoundSetPredicate).Literals()
		if op == OpIn {
			return visitor.VisitIn(term, members)
		}

		return visitor.VisitNotIn(term, members)

	case OpIsNan:
		return visitor.VisitIsNan(e.Term())
	case OpNotNan:
		return visitor.VisitNotNan(e.Term())
	case OpIsNull:
		return visitor.VisitIsNull(e.Term())
	case OpNotNull:
		return visitor.VisitNotNull(e.Term())

	case OpEQ, OpNEQ, OpGTEQ, OpGT, OpLTEQ, OpLT, OpStartsWith, OpNotStartsWith:
		term := e.Term()
		lit := e.(BoundLiteralPredicate).Literal()

		switch op {
		case OpEQ:
			return visitor.VisitEqual(term, lit)
		case OpNEQ:
			return visitor.VisitNotEqual(term, lit)
		case OpGTEQ:
			return visitor.VisitGreaterEqual(term, lit)
		case OpGT:
			return visitor.VisitGreater(term, lit)
		case OpLTEQ:
			return visitor.VisitLessEqual(term, lit)
		case OpLT:
			return visitor.VisitLess(term, lit)
		case OpStartsWith:
			return visitor.VisitStartsWith(term, lit)
		case OpNotStartsWith:
			return visitor.VisitNotStartsWith(term, lit)
		}
	}

	panic(fmt.Errorf("%w: unhandled bound predicate type: %s", ErrNotImplemented, e))
}
// BindExpr walks the expression and binds every unbound predicate in it to the
// provided schema, honouring caseSensitive during binding. The return type is
// BooleanExpression because binding may simplify the expression all the way
// down to AlwaysTrue or AlwaysFalse. An error is returned when a predicate
// fails to bind or when the expression already contains a bound predicate.
func BindExpr(s *Schema, expr BooleanExpression, caseSensitive bool) (BooleanExpression, error) {
	binder := &bindVisitor{
		schema:        s,
		caseSensitive: caseSensitive,
	}

	return VisitExpr(expr, binder)
}
// bindVisitor is the BooleanExprVisitor used by BindExpr. It rebuilds the
// expression tree, replacing each unbound predicate with the result of binding
// it against schema.
type bindVisitor struct {
	schema        *Schema
	caseSensitive bool
}

// VisitTrue returns AlwaysTrue unchanged.
func (*bindVisitor) VisitTrue() BooleanExpression {
	return AlwaysTrue{}
}

// VisitFalse returns AlwaysFalse unchanged.
func (*bindVisitor) VisitFalse() BooleanExpression {
	return AlwaysFalse{}
}

// VisitNot rebuilds the negation around the already bound child.
func (*bindVisitor) VisitNot(child BooleanExpression) BooleanExpression {
	return NewNot(child)
}

// VisitAnd rebuilds the conjunction from the already bound operands.
func (*bindVisitor) VisitAnd(left, right BooleanExpression) BooleanExpression {
	return NewAnd(left, right)
}

// VisitOr rebuilds the disjunction from the already bound operands.
func (*bindVisitor) VisitOr(left, right BooleanExpression) BooleanExpression {
	return NewOr(left, right)
}

// VisitUnbound binds the predicate to the visitor's schema. A binding failure
// is raised as a panic carrying the error.
func (b *bindVisitor) VisitUnbound(pred UnboundPredicate) BooleanExpression {
	bound, err := pred.Bind(b.schema, b.caseSensitive)
	if err == nil {
		return bound
	}

	panic(err)
}

// VisitBound always panics with ErrInvalidArgument: an expression handed to
// the binder must not contain predicates that are already bound.
func (*bindVisitor) VisitBound(pred BoundPredicate) BooleanExpression {
	err := fmt.Errorf("%w: found already bound predicate: %s", ErrInvalidArgument, pred)
	panic(err)
}
// ExpressionEvaluator binds the given expression to the schema and returns a
// function that evaluates the bound expression against a struct-like value.
// The value passed to the returned function is expected to operate like, and
// match, the schema supplied here. If binding fails, the binding error is
// returned and the function is nil.
func ExpressionEvaluator(s *Schema, unbound BooleanExpression, caseSensitive bool) (func(structLike) (bool, error), error) {
	boundExpr, err := BindExpr(s, unbound, caseSensitive)
	if err != nil {
		return nil, err
	}

	evaluator := &exprEvaluator{bound: boundExpr}

	return evaluator.Eval, nil
}
// exprEvaluator evaluates a bound boolean expression against one struct-like
// value at a time. bound is the expression to evaluate and st is the value
// currently being evaluated.
type exprEvaluator struct {
	bound BooleanExpression
	st    structLike
}

// Eval evaluates the bound expression against st and reports whether it
// matches. The value is stored on the receiver so the visitor methods can read
// it, which means each call overwrites the state of the previous one.
// NOTE(review): concurrent calls on the same evaluator would presumably race
// on that field — confirm that callers serialise evaluation.
func (e *exprEvaluator) Eval(st structLike) (bool, error) {
	e.st = st

	return VisitExpr[bool](e.bound, e)
}

// VisitUnbound always panics: only bound expressions can be evaluated.
func (e *exprEvaluator) VisitUnbound(UnboundPredicate) bool {
	panic("found unbound predicate when evaluating expression")
}

// VisitBound evaluates a bound predicate by dispatching on its operation.
func (e *exprEvaluator) VisitBound(pred BoundPredicate) bool {
	return VisitBoundPredicate[bool](pred, e)
}

// VisitTrue evaluates AlwaysTrue.
func (*exprEvaluator) VisitTrue() bool {
	return true
}

// VisitFalse evaluates AlwaysFalse.
func (*exprEvaluator) VisitFalse() bool {
	return false
}

// VisitNot negates the result of the child expression.
func (*exprEvaluator) VisitNot(child bool) bool {
	if child {
		return false
	}

	return true
}

// VisitAnd combines two already evaluated operands with logical AND.
func (*exprEvaluator) VisitAnd(left, right bool) bool {
	if !left {
		return false
	}

	return right
}

// VisitOr combines two already evaluated operands with logical OR.
func (*exprEvaluator) VisitOr(left, right bool) bool {
	if left {
		return true
	}

	return right
}
// VisitIn reports whether the term's value for the current struct is a member
// of literals. A term that does not evaluate to a valid value is never a
// member.
func (e *exprEvaluator) VisitIn(term BoundTerm, literals Set[Literal]) bool {
	if val := term.evalToLiteral(e.st); val.Valid {
		return literals.Contains(val.Val)
	}

	return false
}
// VisitNotIn is the exact negation of VisitIn, so it reports true when the
// term does not evaluate to a valid value.
func (e *exprEvaluator) VisitNotIn(term BoundTerm, literals Set[Literal]) bool {
	isMember := e.VisitIn(term, literals)

	return !isMember
}
// VisitIsNan reports whether the term evaluates to NaN for the current struct.
// Only Float32Type and Float64Type terms can be NaN; every other type, and any
// term that does not evaluate to a valid value, yields false.
func (e *exprEvaluator) VisitIsNan(term BoundTerm) bool {
	switch term.Type().(type) {
	case Float32Type:
		if v := term.(bound[float32]).eval(e.st); v.Valid {
			return math.IsNaN(float64(v.Val))
		}
	case Float64Type:
		if v := term.(bound[float64]).eval(e.st); v.Valid {
			return math.IsNaN(v.Val)
		}
	}

	return false
}
// VisitNotNan is the exact negation of VisitIsNan.
func (e *exprEvaluator) VisitNotNan(term BoundTerm) bool {
	isNaN := e.VisitIsNan(term)

	return !isNaN
}
// VisitIsNull reports whether the term evaluates to null for the current
// struct.
func (e *exprEvaluator) VisitIsNull(term BoundTerm) bool {
	isNull := term.evalIsNull(e.st)

	return isNull
}
// VisitNotNull reports whether the term evaluates to a non-null value for the
// current struct.
func (e *exprEvaluator) VisitNotNull(term BoundTerm) bool {
	if term.evalIsNull(e.st) {
		return false
	}

	return true
}
// nullsFirstCmp orders two optional values so that null (an invalid Optional)
// sorts before every non-null value and two nulls compare equal. cmp is only
// invoked when both values are valid. The result is negative, zero or positive
// in the usual comparator sense.
func nullsFirstCmp[T LiteralType](cmp Comparator[T], v1, v2 Optional[T]) int {
	switch {
	case v1.Valid && v2.Valid:
		return cmp(v1.Val, v2.Val)
	case v1.Valid:
		// only v2 is null
		return 1
	case v2.Valid:
		// only v1 is null
		return -1
	default:
		// both are null
		return 0
	}
}
// typedCmp evaluates term against st as a value of type T and compares it with
// lit using the literal's own comparator, ordering nulls first.
//
// A nil lit is treated as a null right-hand side. The type assertion and the
// comparator lookup are performed only for a non-nil literal; asserting first
// would panic on a nil literal and leave the nil check unreachable. No
// comparator is needed in the nil case because nullsFirstCmp never invokes it
// when either side is null.
//
// It panics if term does not implement bound[T] or if a non-nil lit does not
// implement TypedLiteral[T].
func typedCmp[T LiteralType](st structLike, term BoundTerm, lit Literal) int {
	v := term.(bound[T]).eval(st)

	var (
		l   Optional[T]
		cmp Comparator[T]
	)

	if lit != nil {
		rhs := lit.(TypedLiteral[T])
		l.Valid = true
		l.Val = rhs.Value()
		cmp = rhs.Comparator()
	}

	return nullsFirstCmp(cmp, v, l)
}
// doCmp compares the value of term for st against lit and returns a negative,
// zero or positive result. The Go type used for the comparison is selected
// from the term's type. It panics with ErrType for a term type that has no
// comparison defined here.
func doCmp(st structLike, term BoundTerm, lit Literal) int {
	// Binding has already cast and converted the literal to the term's type,
	// so the term type alone decides which typed comparison applies.
	var compare func(structLike, BoundTerm, Literal) int

	switch term.Type().(type) {
	case BooleanType:
		compare = typedCmp[bool]
	case Int32Type:
		compare = typedCmp[int32]
	case Int64Type:
		compare = typedCmp[int64]
	case Float32Type:
		compare = typedCmp[float32]
	case Float64Type:
		compare = typedCmp[float64]
	case DateType:
		compare = typedCmp[Date]
	case TimeType:
		compare = typedCmp[Time]
	case TimestampType, TimestampTzType:
		compare = typedCmp[Timestamp]
	case BinaryType, FixedType:
		compare = typedCmp[[]byte]
	case StringType:
		compare = typedCmp[string]
	case UUIDType:
		compare = typedCmp[uuid.UUID]
	case DecimalType:
		compare = typedCmp[Decimal]
	default:
		panic(ErrType)
	}

	return compare(st, term, lit)
}
// VisitEqual reports whether the term's value compares equal to lit.
func (e *exprEvaluator) VisitEqual(term BoundTerm, lit Literal) bool {
	order := doCmp(e.st, term, lit)

	return order == 0
}
// VisitNotEqual reports whether the term's value does not compare equal to
// lit.
func (e *exprEvaluator) VisitNotEqual(term BoundTerm, lit Literal) bool {
	order := doCmp(e.st, term, lit)

	return order != 0
}
// VisitGreater reports whether the term's value compares strictly greater than
// lit.
func (e *exprEvaluator) VisitGreater(term BoundTerm, lit Literal) bool {
	order := doCmp(e.st, term, lit)

	return order > 0
}
// VisitGreaterEqual reports whether the term's value compares greater than or
// equal to lit.
func (e *exprEvaluator) VisitGreaterEqual(term BoundTerm, lit Literal) bool {
	order := doCmp(e.st, term, lit)

	return order >= 0
}
// VisitLess reports whether the term's value compares strictly less than lit.
func (e *exprEvaluator) VisitLess(term BoundTerm, lit Literal) bool {
	order := doCmp(e.st, term, lit)

	return order < 0
}
// VisitLessEqual reports whether the term's value compares less than or equal
// to lit.
func (e *exprEvaluator) VisitLessEqual(term BoundTerm, lit Literal) bool {
	order := doCmp(e.st, term, lit)

	return order <= 0
}
// VisitStartsWith reports whether the term's value for the current struct
// begins with the prefix held by lit. String and byte-slice literals are
// supported; byte slices are compared as raw bytes via string conversion.
// A term that does not evaluate to a valid value never matches.
//
// The literal value is read through the type-switch binding rather than a
// second assertion to a concrete literal type, so any implementation of
// TypedLiteral[string] or TypedLiteral[[]byte] is accepted without risking a
// panic on the second assertion.
//
// NOTE(review): a literal of any other type leaves both value and prefix
// empty, so the method reports true. That behaviour is kept as-is; confirm
// whether unsupported literal types can reach this point after binding.
func (e *exprEvaluator) VisitStartsWith(term BoundTerm, lit Literal) bool {
	var value, prefix string

	switch l := lit.(type) {
	case TypedLiteral[string]:
		val := term.(bound[string]).eval(e.st)
		if !val.Valid {
			return false
		}

		prefix, value = l.Value(), val.Val
	case TypedLiteral[[]byte]:
		val := term.(bound[[]byte]).eval(e.st)
		if !val.Valid {
			return false
		}

		prefix, value = string(l.Value()), string(val.Val)
	}

	return strings.HasPrefix(value, prefix)
}
// VisitNotStartsWith is the exact negation of VisitStartsWith, so it reports
// true when the term does not evaluate to a valid value.
func (e *exprEvaluator) VisitNotStartsWith(term BoundTerm, lit Literal) bool {
	hasPrefix := e.VisitStartsWith(term, lit)

	return !hasPrefix
}
// RewriteNotExpr returns an equivalent expression that contains no "Not"
// nodes. This is needed because Projections assume there are no "not" nodes in
// the tree.
//
// Each Not node is replaced by the result of calling `Negate` on its
// (already rewritten) child.
func RewriteNotExpr(expr BooleanExpression) (BooleanExpression, error) {
	var rewriter rewriteNotVisitor

	return VisitExpr(expr, rewriter)
}
// rewriteNotVisitor is the stateless BooleanExprVisitor behind RewriteNotExpr.
// It rebuilds the expression tree and eliminates Not nodes by negating their
// child.
type rewriteNotVisitor struct{}

// VisitTrue returns AlwaysTrue unchanged.
func (rewriteNotVisitor) VisitTrue() BooleanExpression {
	return AlwaysTrue{}
}

// VisitFalse returns AlwaysFalse unchanged.
func (rewriteNotVisitor) VisitFalse() BooleanExpression {
	return AlwaysFalse{}
}

// VisitNot replaces the Not node with the negation of its child.
func (rewriteNotVisitor) VisitNot(child BooleanExpression) BooleanExpression {
	negated := child.Negate()

	return negated
}

// VisitAnd rebuilds the conjunction from the rewritten operands.
func (rewriteNotVisitor) VisitAnd(left, right BooleanExpression) BooleanExpression {
	return NewAnd(left, right)
}

// VisitOr rebuilds the disjunction from the rewritten operands.
func (rewriteNotVisitor) VisitOr(left, right BooleanExpression) BooleanExpression {
	return NewOr(left, right)
}

// VisitUnbound keeps unbound predicates as they are.
func (rewriteNotVisitor) VisitUnbound(pred UnboundPredicate) BooleanExpression {
	return pred
}

// VisitBound keeps bound predicates as they are.
func (rewriteNotVisitor) VisitBound(pred BoundPredicate) BooleanExpression {
	return pred
}
// ExtractFieldIDs collects the IDs of all fields referenced by the terms of
// the given expression, which makes it possible to determine exactly which
// fields an expression needs. Each ID appears once; the order of the returned
// slice follows map iteration and is therefore unspecified. The expression
// must be bound: an unbound predicate results in an error.
func ExtractFieldIDs(expr BooleanExpression) ([]int, error) {
	idSet, err := VisitExpr(expr, expressionFieldIDs{})
	if err != nil {
		return nil, err
	}

	ids := slices.AppendSeq(make([]int, 0, len(idSet)), maps.Keys(idSet))

	return ids, nil
}
// expressionFieldIDs is the stateless BooleanExprVisitor behind
// ExtractFieldIDs. Every node yields the set of field IDs referenced beneath
// it, represented as a map with empty-struct values.
type expressionFieldIDs struct{}

// VisitTrue yields an empty set: AlwaysTrue references no fields.
func (expressionFieldIDs) VisitTrue() map[int]struct{} {
	return make(map[int]struct{})
}

// VisitFalse yields an empty set: AlwaysFalse references no fields.
func (expressionFieldIDs) VisitFalse() map[int]struct{} {
	return make(map[int]struct{})
}

// VisitNot passes the child's set through; negation references no extra
// fields.
func (expressionFieldIDs) VisitNot(child map[int]struct{}) map[int]struct{} {
	return child
}

// VisitAnd returns the union of both operand sets. The union is built in
// place in left.
func (expressionFieldIDs) VisitAnd(left, right map[int]struct{}) map[int]struct{} {
	maps.Copy(left, right)

	return left
}

// VisitOr returns the union of both operand sets. The union is built in place
// in left.
func (expressionFieldIDs) VisitOr(left, right map[int]struct{}) map[int]struct{} {
	maps.Copy(left, right)

	return left
}

// VisitUnbound always panics: field IDs are only known for bound predicates.
func (expressionFieldIDs) VisitUnbound(UnboundPredicate) map[int]struct{} {
	panic("expression field IDs only works for bound expressions")
}

// VisitBound yields a single-element set holding the ID of the field the
// predicate refers to.
func (expressionFieldIDs) VisitBound(pred BoundPredicate) map[int]struct{} {
	id := pred.Ref().Field().ID

	return map[int]struct{}{id: {}}
}
// TranslateColumnNames rewrites a bound expression into an unbound one whose
// column references use the names found in fileSchema, looked up by field ID.
// A predicate on a column that does not exist in fileSchema is replaced with
// AlwaysTrue or AlwaysFalse depending on its operator. An expression that
// contains an unbound predicate results in an error.
func TranslateColumnNames(expr BooleanExpression, fileSchema *Schema) (BooleanExpression, error) {
	translator := columnNameTranslator{fileSchema: fileSchema}

	return VisitExpr(expr, translator)
}
// columnNameTranslator is the BooleanExprVisitor behind TranslateColumnNames.
// fileSchema is the schema in which column names are looked up by field ID.
type columnNameTranslator struct {
	fileSchema *Schema
}

// VisitTrue returns AlwaysTrue unchanged.
func (columnNameTranslator) VisitTrue() BooleanExpression {
	return AlwaysTrue{}
}

// VisitFalse returns AlwaysFalse unchanged.
func (columnNameTranslator) VisitFalse() BooleanExpression {
	return AlwaysFalse{}
}

// VisitNot rebuilds the negation around the translated child.
func (columnNameTranslator) VisitNot(child BooleanExpression) BooleanExpression {
	return NewNot(child)
}

// VisitAnd rebuilds the conjunction from the translated operands.
func (columnNameTranslator) VisitAnd(left, right BooleanExpression) BooleanExpression {
	return NewAnd(left, right)
}

// VisitOr rebuilds the disjunction from the translated operands.
func (columnNameTranslator) VisitOr(left, right BooleanExpression) BooleanExpression {
	return NewOr(left, right)
}

// VisitUnbound always panics with ErrInvalidArgument: translation starts from
// bound predicates only.
func (columnNameTranslator) VisitUnbound(pred UnboundPredicate) BooleanExpression {
	err := fmt.Errorf("%w: expected bound predicate, got: %s", ErrInvalidArgument, pred.Term())
	panic(err)
}

// VisitBound converts a bound predicate back into an unbound predicate that
// references the column by its name in the file schema. It panics with
// ErrNotImplemented for a predicate kind it does not know how to convert.
func (c columnNameTranslator) VisitBound(pred BoundPredicate) BooleanExpression {
	fieldID := pred.Term().Ref().Field().ID

	name, ok := c.fileSchema.FindColumnName(fieldID)
	if !ok {
		// With schema evolution the column may be absent from the file schema
		// when reading older data: only an is-null test can hold for it.
		var missing BooleanExpression = AlwaysFalse{}
		if pred.Op() == OpIsNull {
			missing = AlwaysTrue{}
		}

		return missing
	}

	ref := Reference(name)

	switch p := pred.(type) {
	case BoundUnaryPredicate:
		return p.AsUnbound(ref)
	case BoundLiteralPredicate:
		return p.AsUnbound(ref, p.Literal())
	case BoundSetPredicate:
		return p.AsUnbound(ref, p.Literals().Members())
	default:
		panic(fmt.Errorf("%w: unsupported predicate: %s", ErrNotImplemented, pred))
	}
}