From 4ce3763ac9eb98ac3943ef577c282babd86564b7 Mon Sep 17 00:00:00 2001 From: Harmen Date: Mon, 15 Jan 2018 21:16:33 +0100 Subject: [PATCH] change error_reply() to work as redis does, not as the docs say Also check function arguments. Same for status_reply() --- cmd_scripting_test.go | 12 ++++++++++++ lua.go | 24 +++++++++++------------- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/cmd_scripting_test.go b/cmd_scripting_test.go index 17159228..0d0f338a 100644 --- a/cmd_scripting_test.go +++ b/cmd_scripting_test.go @@ -325,6 +325,18 @@ func TestCmdEvalReply(t *testing.T) { ok(t, err) equals(t, "good", v) } + + _, err = c.Do("EVAL", `return redis.error_reply()`, 0) + assert(t, err != nil, "no EVAL error") + + _, err = c.Do("EVAL", `return redis.error_reply(1)`, 0) + assert(t, err != nil, "no EVAL error") + + _, err = c.Do("EVAL", `return redis.status_reply()`, 0) + assert(t, err != nil, "no EVAL error") + + _, err = c.Do("EVAL", `return redis.status_reply(1)`, 0) + assert(t, err != nil, "no EVAL error") } func TestCmdEvalResponse(t *testing.T) { diff --git a/lua.go b/lua.go index 77212724..9ce7c1cb 100644 --- a/lua.go +++ b/lua.go @@ -65,16 +65,16 @@ func mkLuaFuncs(conn redigo.Conn) map[string]lua.LGFunction { return 1 // Notify that we pushed one value to the stack }, "error_reply": func(l *lua.LState) int { - msg := l.CheckAny(1) + msg := l.CheckString(1) res := &lua.LTable{} - res.RawSetString("err", msg) + res.RawSetString("err", lua.LString(msg)) l.Push(res) return 1 }, "status_reply": func(l *lua.LState) int { - msg := l.CheckAny(1) + msg := l.CheckString(1) res := &lua.LTable{} - res.RawSetString("ok", msg) + res.RawSetString("ok", lua.LString(msg)) l.Push(res) return 1 }, @@ -104,17 +104,15 @@ func luaToRedis(l *lua.LState, c *server.Peer, value lua.LValue) { c.WriteBulk(lua.LVAsString(value)) case lua.LTTable: t := value.(*lua.LTable) - // special case for table with only an 'err' field - var keys []string - t.ForEach(func(k, _ lua.LValue) { - keys = append(keys, k.String()) - }) - if len(keys) == 1 && keys[0] == "err" { - c.WriteError(t.RawGetString("err").String()) + // special case for tables with an 'err' or 'ok' field + // note: according to the docs this only counts when 'err' or 'ok' is + // the only field. + if s := t.RawGetString("err"); s.Type() != lua.LTNil { + c.WriteError(s.String()) return } - if len(keys) == 1 && keys[0] == "ok" { - c.WriteInline(t.RawGetString("ok").String()) + if s := t.RawGetString("ok"); s.Type() != lua.LTNil { + c.WriteInline(s.String()) return }