Skip to content

Commit

Permalink
error_reply and status_reply
Browse files Browse the repository at this point in the history
Deal with types related to error_reply and status_reply.
Moved lua code to a separate file.
  • Loading branch information
alicebob committed Jan 15, 2018
1 parent 7333f0a commit 3dc460a
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 131 deletions.
132 changes: 2 additions & 130 deletions cmd_scripting.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,65 +54,7 @@ func (m *Miniredis) runLuaScript(c *server.Peer, script string, args []string) {
}
l.SetGlobal("ARGV", argvTable)

// Register call function to lua VM
redisFuncs := map[string]lua.LGFunction{
"call": func(l *lua.LState) int {
top := l.GetTop()

cmd := lua.LVAsString(l.Get(1))
args := make([]interface{}, top-1)
for i := 2; i <= top; i++ {
arg := l.Get(i)

dataType := arg.Type()
switch dataType {
case lua.LTBool:
args[i-2] = lua.LVAsBool(arg)
case lua.LTNumber:
value, _ := strconv.ParseFloat(lua.LVAsString(arg), 64)
args[i-2] = value
case lua.LTString:
args[i-2] = lua.LVAsString(arg)
case lua.LTNil:
case lua.LTFunction:
case lua.LTUserData:
case lua.LTThread:
case lua.LTTable:
case lua.LTChannel:
default:
args[i-2] = nil
}
}
res, err := conn.Do(cmd, args...)
if err != nil {
l.Push(lua.LNil)
return 1
}

if res == nil {
l.Push(lua.LNil)
} else {
switch r := res.(type) {
case int64:
l.Push(lua.LNumber(r))
case []uint8:
l.Push(lua.LString(string(r)))
case []interface{}:
l.Push(m.redisToLua(l, r))
case string:
l.Push(lua.LString(r))
default:
// TODO: oops?
l.Push(lua.LString(res.(string)))
}
}

return 1 // Notify that we pushed one value to the stack
},
}

redisFuncs["pcall"] = redisFuncs["call"]

redisFuncs := mkLuaFuncs(conn)
// Register command handlers
l.Push(l.NewFunction(func(l *lua.LState) int {
mod := l.RegisterModule("redis", redisFuncs).(*lua.LTable)
Expand All @@ -129,82 +71,12 @@ func (m *Miniredis) runLuaScript(c *server.Peer, script string, args []string) {
}

if l.GetTop() > 0 {
m.luaToRedis(l, c, l.Get(1))
luaToRedis(l, c, l.Get(1))
} else {
c.WriteNull()
}
}

func (m *Miniredis) redisToLua(l *lua.LState, res []interface{}) *lua.LTable {
rettb := l.NewTable()
for _, e := range res {
var v lua.LValue
if e == nil {
v = lua.LValue(nil)
} else {
switch et := e.(type) {
case int64:
v = lua.LNumber(et)
case []uint8:
v = lua.LString(string(et))
case []interface{}:
v = m.redisToLua(l, et)
case string:
v = lua.LString(et)
default:
// TODO: oops?
v = lua.LString(e.(string))
}
}
l.RawSet(rettb, lua.LNumber(rettb.Len()+1), v)
}
return rettb
}

func (m *Miniredis) luaToRedis(l *lua.LState, c *server.Peer, value lua.LValue) {
if value == nil {
c.WriteNull()
return
}

switch value.Type() {
case lua.LTNil:
c.WriteNull()
case lua.LTBool:
if lua.LVAsBool(value) {
c.WriteInt(1)
} else {
c.WriteInt(0)
}
case lua.LTNumber:
c.WriteInt(int(lua.LVAsNumber(value)))
case lua.LTString:
c.WriteBulk(lua.LVAsString(value))
case lua.LTTable:
result := []lua.LValue{}
for j := 1; true; j++ {
val := l.GetTable(value, lua.LNumber(j))
if val == nil {
result = append(result, val)
continue
}

if val.Type() == lua.LTNil {
break
}

result = append(result, val)
}

c.WriteLen(len(result))
for _, r := range result {
m.luaToRedis(l, c, r)
}
default:
c.WriteInline(lua.LVAsString(value))
}
}

func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) {
if len(args) < 2 {
setDirty(c)
Expand Down
20 changes: 19 additions & 1 deletion cmd_scripting_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ func TestEvalsha(t *testing.T) {
mustFail(t, err, msgNoScriptFound)
}

func TestCmdEvalReplyConversion(t *testing.T) {
func TestCmdEvalReply(t *testing.T) {
s, err := Run()
ok(t, err)
defer s.Close()
Expand Down Expand Up @@ -307,6 +307,24 @@ func TestCmdEvalReplyConversion(t *testing.T) {
}
equals(t, tc.expected, reply)
}

{
_, err := c.Do("EVAL", `return {err="broken"}`, 0)
mustFail(t, err, "broken")

_, err = c.Do("EVAL", `return redis.error_reply("broken")`, 0)
mustFail(t, err, "broken")
}

{
v, err := redis.String(c.Do("EVAL", `return {ok="good"}`, 0))
ok(t, err)
equals(t, "good", v)

v, err = redis.String(c.Do("EVAL", `return redis.status_reply("good")`, 0))
ok(t, err)
equals(t, "good", v)
}
}

func TestCmdEvalResponse(t *testing.T) {
Expand Down
169 changes: 169 additions & 0 deletions lua.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
package miniredis

import (
"strconv"

redigo "github.com/garyburd/redigo/redis"
lua "github.com/yuin/gopher-lua"

"github.com/alicebob/miniredis/server"
)

func mkLuaFuncs(conn redigo.Conn) map[string]lua.LGFunction {
funcs := map[string]lua.LGFunction{
"call": func(l *lua.LState) int {
top := l.GetTop()

cmd := lua.LVAsString(l.Get(1))
args := make([]interface{}, top-1)
for i := 2; i <= top; i++ {
arg := l.Get(i)

dataType := arg.Type()
switch dataType {
case lua.LTBool:
args[i-2] = lua.LVAsBool(arg)
case lua.LTNumber:
value, _ := strconv.ParseFloat(lua.LVAsString(arg), 64)
args[i-2] = value
case lua.LTString:
args[i-2] = lua.LVAsString(arg)
case lua.LTNil:
case lua.LTFunction:
case lua.LTUserData:
case lua.LTThread:
case lua.LTTable:
case lua.LTChannel:
default:
args[i-2] = nil
}
}
res, err := conn.Do(cmd, args...)
if err != nil {
l.Push(lua.LNil)
return 1
}

if res == nil {
l.Push(lua.LNil)
} else {
switch r := res.(type) {
case int64:
l.Push(lua.LNumber(r))
case []uint8:
l.Push(lua.LString(string(r)))
case []interface{}:
l.Push(redisToLua(l, r))
case string:
l.Push(lua.LString(r))
default:
// TODO: oops?
l.Push(lua.LString(res.(string)))
}
}

return 1 // Notify that we pushed one value to the stack
},
"error_reply": func(l *lua.LState) int {
msg := l.CheckAny(1)
res := &lua.LTable{}
res.RawSetString("err", msg)
l.Push(res)
return 1
},
"status_reply": func(l *lua.LState) int {
msg := l.CheckAny(1)
res := &lua.LTable{}
res.RawSetString("ok", msg)
l.Push(res)
return 1
},
}
funcs["pcall"] = funcs["call"]
return funcs
}

func luaToRedis(l *lua.LState, c *server.Peer, value lua.LValue) {
if value == nil {
c.WriteNull()
return
}

switch value.Type() {
case lua.LTNil:
c.WriteNull()
case lua.LTBool:
if lua.LVAsBool(value) {
c.WriteInt(1)
} else {
c.WriteInt(0)
}
case lua.LTNumber:
c.WriteInt(int(lua.LVAsNumber(value)))
case lua.LTString:
c.WriteBulk(lua.LVAsString(value))
case lua.LTTable:
t := value.(*lua.LTable)
// special case for table with only an 'err' field
var keys []string
t.ForEach(func(k, _ lua.LValue) {
keys = append(keys, k.String())
})
if len(keys) == 1 && keys[0] == "err" {
c.WriteError(t.RawGetString("err").String())
return
}
if len(keys) == 1 && keys[0] == "ok" {
c.WriteInline(t.RawGetString("ok").String())
return
}

result := []lua.LValue{}
for j := 1; true; j++ {
val := l.GetTable(value, lua.LNumber(j))
if val == nil {
result = append(result, val)
continue
}

if val.Type() == lua.LTNil {
break
}

result = append(result, val)
}

c.WriteLen(len(result))
for _, r := range result {
luaToRedis(l, c, r)
}
default:
c.WriteInline(lua.LVAsString(value))
}
}

func redisToLua(l *lua.LState, res []interface{}) *lua.LTable {
rettb := l.NewTable()
for _, e := range res {
var v lua.LValue
if e == nil {
v = lua.LValue(nil)
} else {
switch et := e.(type) {
case int64:
v = lua.LNumber(et)
case []uint8:
v = lua.LString(string(et))
case []interface{}:
v = redisToLua(l, et)
case string:
v = lua.LString(et)
default:
// TODO: oops?
v = lua.LString(e.(string))
}
}
l.RawSet(rettb, lua.LNumber(rettb.Len()+1), v)
}
return rettb
}

0 comments on commit 3dc460a

Please sign in to comment.