diff --git a/cmd_scripting.go b/cmd_scripting.go index f7a01fed..314372b4 100644 --- a/cmd_scripting.go +++ b/cmd_scripting.go @@ -54,65 +54,7 @@ func (m *Miniredis) runLuaScript(c *server.Peer, script string, args []string) { } l.SetGlobal("ARGV", argvTable) - // Register call function to lua VM - redisFuncs := map[string]lua.LGFunction{ - "call": func(l *lua.LState) int { - top := l.GetTop() - - cmd := lua.LVAsString(l.Get(1)) - args := make([]interface{}, top-1) - for i := 2; i <= top; i++ { - arg := l.Get(i) - - dataType := arg.Type() - switch dataType { - case lua.LTBool: - args[i-2] = lua.LVAsBool(arg) - case lua.LTNumber: - value, _ := strconv.ParseFloat(lua.LVAsString(arg), 64) - args[i-2] = value - case lua.LTString: - args[i-2] = lua.LVAsString(arg) - case lua.LTNil: - case lua.LTFunction: - case lua.LTUserData: - case lua.LTThread: - case lua.LTTable: - case lua.LTChannel: - default: - args[i-2] = nil - } - } - res, err := conn.Do(cmd, args...) - if err != nil { - l.Push(lua.LNil) - return 1 - } - - if res == nil { - l.Push(lua.LNil) - } else { - switch r := res.(type) { - case int64: - l.Push(lua.LNumber(r)) - case []uint8: - l.Push(lua.LString(string(r))) - case []interface{}: - l.Push(m.redisToLua(l, r)) - case string: - l.Push(lua.LString(r)) - default: - // TODO: oops? - l.Push(lua.LString(res.(string))) - } - } - - return 1 // Notify that we pushed one value to the stack - }, - } - - redisFuncs["pcall"] = redisFuncs["call"] - + redisFuncs := mkLuaFuncs(conn) // Register command handlers l.Push(l.NewFunction(func(l *lua.LState) int { mod := l.RegisterModule("redis", redisFuncs).(*lua.LTable) @@ -129,82 +71,12 @@ func (m *Miniredis) runLuaScript(c *server.Peer, script string, args []string) { } if l.GetTop() > 0 { - m.luaToRedis(l, c, l.Get(1)) + luaToRedis(l, c, l.Get(1)) } else { c.WriteNull() } } -func (m *Miniredis) redisToLua(l *lua.LState, res []interface{}) *lua.LTable { - rettb := l.NewTable() - for _, e := range res { - var v lua.LValue - if e == nil { - v = lua.LValue(nil) - } else { - switch et := e.(type) { - case int64: - v = lua.LNumber(et) - case []uint8: - v = lua.LString(string(et)) - case []interface{}: - v = m.redisToLua(l, et) - case string: - v = lua.LString(et) - default: - // TODO: oops? - v = lua.LString(e.(string)) - } - } - l.RawSet(rettb, lua.LNumber(rettb.Len()+1), v) - } - return rettb -} - -func (m *Miniredis) luaToRedis(l *lua.LState, c *server.Peer, value lua.LValue) { - if value == nil { - c.WriteNull() - return - } - - switch value.Type() { - case lua.LTNil: - c.WriteNull() - case lua.LTBool: - if lua.LVAsBool(value) { - c.WriteInt(1) - } else { - c.WriteInt(0) - } - case lua.LTNumber: - c.WriteInt(int(lua.LVAsNumber(value))) - case lua.LTString: - c.WriteBulk(lua.LVAsString(value)) - case lua.LTTable: - result := []lua.LValue{} - for j := 1; true; j++ { - val := l.GetTable(value, lua.LNumber(j)) - if val == nil { - result = append(result, val) - continue - } - - if val.Type() == lua.LTNil { - break - } - - result = append(result, val) - } - - c.WriteLen(len(result)) - for _, r := range result { - m.luaToRedis(l, c, r) - } - default: - c.WriteInline(lua.LVAsString(value)) - } -} - func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) { if len(args) < 2 { setDirty(c) diff --git a/cmd_scripting_test.go b/cmd_scripting_test.go index bc5ffc9e..17159228 100644 --- a/cmd_scripting_test.go +++ b/cmd_scripting_test.go @@ -162,7 +162,7 @@ func TestEvalsha(t *testing.T) { mustFail(t, err, msgNoScriptFound) } -func TestCmdEvalReplyConversion(t *testing.T) { +func TestCmdEvalReply(t *testing.T) { s, err := Run() ok(t, err) defer s.Close() @@ -307,6 +307,24 @@ func TestCmdEvalReplyConversion(t *testing.T) { } equals(t, tc.expected, reply) } + + { + _, err := c.Do("EVAL", `return {err="broken"}`, 0) + mustFail(t, err, "broken") + + _, err = c.Do("EVAL", `return redis.error_reply("broken")`, 0) + mustFail(t, err, "broken") + } + + { + v, err := redis.String(c.Do("EVAL", `return {ok="good"}`, 0)) + ok(t, err) + equals(t, "good", v) + + v, err = redis.String(c.Do("EVAL", `return redis.status_reply("good")`, 0)) + ok(t, err) + equals(t, "good", v) + } } func TestCmdEvalResponse(t *testing.T) { diff --git a/lua.go b/lua.go new file mode 100644 index 00000000..77212724 --- /dev/null +++ b/lua.go @@ -0,0 +1,169 @@ +package miniredis + +import ( + "strconv" + + redigo "github.com/garyburd/redigo/redis" + lua "github.com/yuin/gopher-lua" + + "github.com/alicebob/miniredis/server" +) + +func mkLuaFuncs(conn redigo.Conn) map[string]lua.LGFunction { + funcs := map[string]lua.LGFunction{ + "call": func(l *lua.LState) int { + top := l.GetTop() + + cmd := lua.LVAsString(l.Get(1)) + args := make([]interface{}, top-1) + for i := 2; i <= top; i++ { + arg := l.Get(i) + + dataType := arg.Type() + switch dataType { + case lua.LTBool: + args[i-2] = lua.LVAsBool(arg) + case lua.LTNumber: + value, _ := strconv.ParseFloat(lua.LVAsString(arg), 64) + args[i-2] = value + case lua.LTString: + args[i-2] = lua.LVAsString(arg) + case lua.LTNil: + case lua.LTFunction: + case lua.LTUserData: + case lua.LTThread: + case lua.LTTable: + case lua.LTChannel: + default: + args[i-2] = nil + } + } + res, err := conn.Do(cmd, args...) + if err != nil { + l.Push(lua.LNil) + return 1 + } + + if res == nil { + l.Push(lua.LNil) + } else { + switch r := res.(type) { + case int64: + l.Push(lua.LNumber(r)) + case []uint8: + l.Push(lua.LString(string(r))) + case []interface{}: + l.Push(redisToLua(l, r)) + case string: + l.Push(lua.LString(r)) + default: + // TODO: oops? + l.Push(lua.LString(res.(string))) + } + } + + return 1 // Notify that we pushed one value to the stack + }, + "error_reply": func(l *lua.LState) int { + msg := l.CheckAny(1) + res := &lua.LTable{} + res.RawSetString("err", msg) + l.Push(res) + return 1 + }, + "status_reply": func(l *lua.LState) int { + msg := l.CheckAny(1) + res := &lua.LTable{} + res.RawSetString("ok", msg) + l.Push(res) + return 1 + }, + } + funcs["pcall"] = funcs["call"] + return funcs +} + +func luaToRedis(l *lua.LState, c *server.Peer, value lua.LValue) { + if value == nil { + c.WriteNull() + return + } + + switch value.Type() { + case lua.LTNil: + c.WriteNull() + case lua.LTBool: + if lua.LVAsBool(value) { + c.WriteInt(1) + } else { + c.WriteInt(0) + } + case lua.LTNumber: + c.WriteInt(int(lua.LVAsNumber(value))) + case lua.LTString: + c.WriteBulk(lua.LVAsString(value)) + case lua.LTTable: + t := value.(*lua.LTable) + // special case for table with only an 'err' field + var keys []string + t.ForEach(func(k, _ lua.LValue) { + keys = append(keys, k.String()) + }) + if len(keys) == 1 && keys[0] == "err" { + c.WriteError(t.RawGetString("err").String()) + return + } + if len(keys) == 1 && keys[0] == "ok" { + c.WriteInline(t.RawGetString("ok").String()) + return + } + + result := []lua.LValue{} + for j := 1; true; j++ { + val := l.GetTable(value, lua.LNumber(j)) + if val == nil { + result = append(result, val) + continue + } + + if val.Type() == lua.LTNil { + break + } + + result = append(result, val) + } + + c.WriteLen(len(result)) + for _, r := range result { + luaToRedis(l, c, r) + } + default: + c.WriteInline(lua.LVAsString(value)) + } +} + +func redisToLua(l *lua.LState, res []interface{}) *lua.LTable { + rettb := l.NewTable() + for _, e := range res { + var v lua.LValue + if e == nil { + v = lua.LValue(nil) + } else { + switch et := e.(type) { + case int64: + v = lua.LNumber(et) + case []uint8: + v = lua.LString(string(et)) + case []interface{}: + v = redisToLua(l, et) + case string: + v = lua.LString(et) + default: + // TODO: oops? + v = lua.LString(e.(string)) + } + } + l.RawSet(rettb, lua.LNumber(rettb.Len()+1), v) + } + return rettb +}