Skip to content

Commit eaeb2dd

Browse files
committed
sqlite: add custom function support
1 parent 7889254 commit eaeb2dd

File tree

2 files changed

+328
-0
lines changed

2 files changed

+328
-0
lines changed

func.go

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
// Copyright (c) 2018 David Crawshaw <[email protected]>
2+
//
3+
// Permission to use, copy, modify, and distribute this software for any
4+
// purpose with or without fee is hereby granted, provided that the above
5+
// copyright notice and this permission notice appear in all copies.
6+
//
7+
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8+
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9+
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
10+
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11+
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
12+
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
13+
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14+
15+
package sqlite
16+
17+
// #include <sqlite3.h>
18+
// #include <stdlib.h>
19+
// extern void func_tramp(sqlite3_context*, int, sqlite3_value**);
20+
// extern void step_tramp(sqlite3_context*, int, sqlite3_value**);
21+
// extern void final_tramp(sqlite3_context*);
22+
// extern void destroy_tramp(void*);
23+
import "C"
24+
import (
25+
"sync"
26+
"unsafe"
27+
)
28+
29+
// Context is an *sqlite3_context.
30+
// It is used by custom functions to return result values.
31+
// An SQLite context is in no way related to a Go context.Context.
32+
type Context struct {
33+
ptr *C.sqlite3_context
34+
}
35+
36+
func (ctx Context) UserData() interface{} {
37+
return getxfuncs(ctx.ptr).data
38+
}
39+
40+
func (ctx Context) SetUserData(data interface{}) {
41+
getxfuncs(ctx.ptr).data = data
42+
}
43+
44+
func (ctx Context) ResultInt(v int) { C.sqlite3_result_int(ctx.ptr, C.int(v)) }
45+
func (ctx Context) ResultInt64(v int64) { C.sqlite3_result_int64(ctx.ptr, C.sqlite3_int64(v)) }
46+
func (ctx Context) ResultFloat(v float64) { C.sqlite3_result_double(ctx.ptr, C.double(v)) }
47+
func (ctx Context) ResultNull() { C.sqlite3_result_null(ctx.ptr) }
48+
func (ctx Context) ResultValue(v Value) { C.sqlite3_result_value(ctx.ptr, v.ptr) }
49+
func (ctx Context) ResultZeroBlob(n int64) { C.sqlite3_result_zeroblob64(ctx.ptr, C.sqlite_uint64(n)) }
50+
func (ctx Context) ResultText(v string) {
51+
var cv *C.char
52+
if len(v) != 0 {
53+
cv = C.CString(v)
54+
}
55+
C.sqlite3_result_text(ctx.ptr, cv, C.int(len(v)), (*[0]byte)(C.free))
56+
}
57+
func (ctx Context) ResultError(err error) {
58+
if err, isError := err.(Error); isError {
59+
C.sqlite3_result_error_code(ctx.ptr, C.int(err.Code))
60+
return
61+
}
62+
errstr := err.Error()
63+
cerrstr := C.CString(errstr)
64+
defer C.free(unsafe.Pointer(cerrstr))
65+
C.sqlite3_result_error(ctx.ptr, cerrstr, C.int(len(errstr)))
66+
}
67+
68+
type Value struct {
69+
ptr *C.sqlite3_value
70+
}
71+
72+
func (v Value) Int() int { return int(C.sqlite3_value_int(v.ptr)) }
73+
func (v Value) Int64() int64 { return int64(C.sqlite3_value_int64(v.ptr)) }
74+
func (v Value) Float() float64 { return float64(C.sqlite3_value_double(v.ptr)) }
75+
func (v Value) Len() int { return int(C.sqlite3_value_bytes(v.ptr)) }
76+
func (v Value) Text() string {
77+
n := v.Len()
78+
return C.GoStringN((*C.char)(unsafe.Pointer(C.sqlite3_value_text(v.ptr))), C.int(n))
79+
}
80+
func (v Value) Blob() []byte {
81+
panic("TODO")
82+
}
83+
84+
type xfunc struct {
85+
id int
86+
name string
87+
conn *Conn
88+
xFunc func(Context, ...Value)
89+
xStep func(Context, ...Value)
90+
xFinal func(Context)
91+
data interface{}
92+
}
93+
94+
var xfuncs = struct {
95+
mu sync.RWMutex
96+
m map[int]*xfunc
97+
next int
98+
}{
99+
m: make(map[int]*xfunc),
100+
}
101+
102+
// CreateFunction registers a Go function with SQLite
103+
// for use in SQL queries.
104+
//
105+
// To define a scalar function, provide a value for
106+
// xFunc and set xStep/xFinal to nil.
107+
//
108+
// To define an aggregation set xFunc to nil and
109+
// provide values for xStep and xFinal.
110+
//
111+
// State can be stored across function calls by
112+
// using the Context UserData/SetUserData methods.
113+
//
114+
// https://sqlite.org/c3ref/create_function.html
115+
func (conn *Conn) CreateFunction(name string, deterministic bool, numArgs int, xFunc, xStep func(Context, ...Value), xFinal func(Context)) error {
116+
cname := C.CString(name) // TODO: free?
117+
eTextRep := C.int(C.SQLITE_UTF8)
118+
if deterministic {
119+
eTextRep |= C.SQLITE_DETERMINISTIC
120+
}
121+
122+
x := &xfunc{
123+
conn: conn,
124+
name: name,
125+
xFunc: xFunc,
126+
xStep: xStep,
127+
xFinal: xFinal,
128+
}
129+
130+
xfuncs.mu.Lock()
131+
xfuncs.next++
132+
x.id = xfuncs.next
133+
xfuncs.m[x.id] = x
134+
xfuncs.mu.Unlock()
135+
136+
pApp := unsafe.Pointer(uintptr(x.id))
137+
138+
var funcfn, stepfn, finalfn *[0]byte
139+
if xFunc == nil {
140+
stepfn = (*[0]byte)(C.step_tramp)
141+
finalfn = (*[0]byte)(C.final_tramp)
142+
} else {
143+
funcfn = (*[0]byte)(C.func_tramp)
144+
}
145+
146+
res := C.sqlite3_create_function_v2(
147+
conn.conn,
148+
cname,
149+
C.int(numArgs),
150+
eTextRep,
151+
pApp,
152+
funcfn,
153+
stepfn,
154+
finalfn,
155+
(*[0]byte)(C.destroy_tramp),
156+
)
157+
return conn.reserr("Conn.CreateFunction", name, res)
158+
}
159+
160+
func getxfuncs(ctx *C.sqlite3_context) *xfunc {
161+
id := int(uintptr(C.sqlite3_user_data(ctx)))
162+
163+
xfuncs.mu.RLock()
164+
x := xfuncs.m[id]
165+
xfuncs.mu.RUnlock()
166+
167+
return x
168+
}
169+
170+
//export func_tramp
171+
func func_tramp(ctx *C.sqlite3_context, n C.int, valarray **C.sqlite3_value) {
172+
x := getxfuncs(ctx)
173+
var vals []Value
174+
if n > 0 {
175+
vals = (*[127]Value)(unsafe.Pointer(valarray))[:n:n]
176+
}
177+
x.xFunc(Context{ptr: ctx}, vals...)
178+
}
179+
180+
//export step_tramp
181+
func step_tramp(ctx *C.sqlite3_context, n C.int, valarray **C.sqlite3_value) {
182+
x := getxfuncs(ctx)
183+
var vals []Value
184+
if n > 0 {
185+
vals = (*[127]Value)(unsafe.Pointer(valarray))[:n:n]
186+
}
187+
x.xStep(Context{ptr: ctx}, vals...)
188+
}
189+
190+
//export final_tramp
191+
func final_tramp(ctx *C.sqlite3_context) {
192+
x := getxfuncs(ctx)
193+
x.xFinal(Context{ptr: ctx})
194+
}
195+
196+
//export destroy_tramp
197+
func destroy_tramp(ptr unsafe.Pointer) {
198+
id := int(uintptr(ptr))
199+
200+
xfuncs.mu.Lock()
201+
delete(xfuncs.m, id)
202+
xfuncs.mu.Unlock()
203+
}

func_test.go

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
// Copyright (c) 2018 David Crawshaw <[email protected]>
2+
//
3+
// Permission to use, copy, modify, and distribute this software for any
4+
// purpose with or without fee is hereby granted, provided that the above
5+
// copyright notice and this permission notice appear in all copies.
6+
//
7+
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8+
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9+
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
10+
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11+
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
12+
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
13+
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14+
15+
package sqlite_test
16+
17+
import (
18+
"testing"
19+
20+
"crawshaw.io/sqlite"
21+
)
22+
23+
func TestFunc(t *testing.T) {
24+
c, err := sqlite.OpenConn(":memory:", 0)
25+
if err != nil {
26+
t.Fatal(err)
27+
}
28+
defer func() {
29+
if err := c.Close(); err != nil {
30+
t.Error(err)
31+
}
32+
}()
33+
34+
xFunc := func(ctx sqlite.Context, values ...sqlite.Value) {
35+
v := values[0].Int() + values[1].Int()
36+
ctx.ResultInt(v)
37+
}
38+
if err := c.CreateFunction("addints", true, 2, xFunc, nil, nil); err != nil {
39+
t.Fatal(err)
40+
}
41+
42+
stmt, _, err := c.PrepareTransient("SELECT addints(2, 3);")
43+
if err != nil {
44+
t.Fatal(err)
45+
}
46+
if _, err := stmt.Step(); err != nil {
47+
t.Fatal(err)
48+
}
49+
if got, want := stmt.ColumnInt(0), 5; got != want {
50+
t.Errorf("addints(2, 3)=%d, want %d", got, want)
51+
}
52+
stmt.Finalize()
53+
}
54+
55+
func TestAggFunc(t *testing.T) {
56+
c, err := sqlite.OpenConn(":memory:", 0)
57+
if err != nil {
58+
t.Fatal(err)
59+
}
60+
defer func() {
61+
if err := c.Close(); err != nil {
62+
t.Error(err)
63+
}
64+
}()
65+
66+
stmt, _, err := c.PrepareTransient("CREATE TABLE t (c integer);")
67+
if err != nil {
68+
t.Fatal(err)
69+
}
70+
if _, err := stmt.Step(); err != nil {
71+
t.Fatal(err)
72+
}
73+
if err := stmt.Finalize(); err != nil {
74+
t.Error(err)
75+
}
76+
77+
cVals := []int{3, 5, 7}
78+
want := 3 + 5 + 7
79+
80+
stmt, err = c.Prepare("INSERT INTO t (c) VALUES ($c);")
81+
if err != nil {
82+
t.Fatal(err)
83+
}
84+
for _, val := range cVals {
85+
stmt.SetInt64("$c", int64(val))
86+
if _, err = stmt.Step(); err != nil {
87+
t.Errorf("INSERT %q: %v", val, err)
88+
}
89+
if err = stmt.Reset(); err != nil {
90+
t.Errorf("INSERT reset %q: %v", val, err)
91+
}
92+
}
93+
stmt.Finalize()
94+
95+
xStep := func(ctx sqlite.Context, values ...sqlite.Value) {
96+
var sum int
97+
if data := ctx.UserData(); data != nil {
98+
sum = data.(int)
99+
}
100+
sum += values[0].Int()
101+
ctx.SetUserData(sum)
102+
}
103+
xFinal := func(ctx sqlite.Context) {
104+
var sum int
105+
if data := ctx.UserData(); data != nil {
106+
sum = data.(int)
107+
}
108+
ctx.ResultInt(sum)
109+
}
110+
if err := c.CreateFunction("sumints", true, 2, nil, xStep, xFinal); err != nil {
111+
t.Fatal(err)
112+
}
113+
114+
stmt, _, err = c.PrepareTransient("SELECT sum(c) FROM t;")
115+
if err != nil {
116+
t.Fatal(err)
117+
}
118+
if _, err := stmt.Step(); err != nil {
119+
t.Fatal(err)
120+
}
121+
if got := stmt.ColumnInt(0); got != want {
122+
t.Errorf("sum(c)=%d, want %d", got, want)
123+
}
124+
stmt.Finalize()
125+
}

0 commit comments

Comments
 (0)