forked from remind101/empire
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdb_test.go
84 lines (67 loc) · 1.7 KB
/
db_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
package empire
import (
"reflect"
"strings"
"testing"
"time"
gosql "database/sql"
"github.com/jinzhu/gorm"
)
// TestComposedScope checks that a composedScope invokes every scope it
// contains. Each mock scope closes its channel when called, and the test waits
// (with a timeout) for each channel to be closed.
func TestComposedScope(t *testing.T) {
	var scope composedScope
	a, b := make(chan struct{}), make(chan struct{})
	scope = append(scope, MockScope(a))
	scope = append(scope, MockScope(b))

	db := &gorm.DB{}
	go scope.scope(db)

	select {
	case <-a:
	case <-time.After(time.Second):
		t.Fatal("Expected a to be called")
	}

	// The scopes run on another goroutine, so observing a does not mean b has
	// already been closed. Wait for b too instead of polling it with `default`,
	// which made this test fail spuriously.
	select {
	case <-b:
	case <-time.After(time.Second):
		t.Fatal("Expected b to be called")
	}
}
// MockScope returns a scope that signals it was invoked by closing called.
// The *gorm.DB it receives is returned unchanged. Since a channel can only be
// closed once, invoking the returned scope a second time would panic.
func MockScope(called chan struct{}) scope {
	record := func(db *gorm.DB) *gorm.DB {
		defer close(called)
		return db
	}
	return scopeFunc(record)
}
// scopeTest is one table-driven case describing the condition SQL and bind
// variables a scope is expected to produce.
type scopeTest struct {
// scope is the scope under test.
scope scope
// sql is the expected condition SQL; it is compared against the
// whitespace-trimmed output of conditionSql.
sql string
// vars are the expected bind variables generated alongside sql.
vars []interface{}
}
// scopeTests is a table of scopeTest cases. Its Run method generates the
// condition SQL and vars for each case via conditionSql and compares them with
// the expected values.
type scopeTests []scopeTest
// Run generates the condition SQL and bind vars for every scopeTest using
// conditionSql and reports each case whose output differs from the expected
// values. Failures are reported with t.Errorf so that all mismatching cases
// are listed, not just the first.
//
// A nil vars slice and an empty one are treated as equal (both mean "no bind
// vars"). Any other difference, including exactly one side being empty, is
// reported as a failure.
func (tests scopeTests) Run(t testing.TB) {
	for i, tt := range tests {
		sql, vars := conditionSql(tt.scope)

		if sql != tt.sql {
			t.Errorf("#%d: SQL => %v; want %v", i, sql, tt.sql)
		}

		// reflect.DeepEqual distinguishes nil from an empty slice, so only
		// skip the comparison when neither side has any vars. The previous
		// check also skipped it when just one side was empty, hiding real
		// mismatches.
		noVars := len(vars) == 0 && len(tt.vars) == 0
		if !noVars && !reflect.DeepEqual(vars, tt.vars) {
			t.Errorf("#%d: Vars => %v; want %v", i, vars, tt.vars)
		}
	}
}
// conditionSql applies scope to a throwaway postgres-dialect gorm handle and
// returns the resulting condition SQL (with surrounding whitespace trimmed)
// together with the bind vars gorm collected for it.
func conditionSql(scope scope) (sql string, vars []interface{}) {
	// The *gosql.DB passed here is a zero value with nothing behind it; only
	// SQL generation is exercised, so Open's error is deliberately ignored.
	// NOTE(review): confirm gorm.Open does not need a live connection here.
	conn, _ := gorm.Open("postgres", &gosql.DB{})
	s := scope.scope(&conn).NewScope(nil)

	// Build the SQL first and only then read SqlVars; presumably
	// CombinedConditionSql is what fills them in, so keep this ordering.
	sql = strings.TrimSpace(s.CombinedConditionSql())
	vars = s.SqlVars
	return sql, vars
}