From dea7631fa3df0dadf830d3d5b408ce7ad72baf1f Mon Sep 17 00:00:00 2001 From: Harmen Date: Sat, 4 Jan 2020 16:58:32 +0800 Subject: [PATCH] Lua .call() not via a socket (#126) * don't go over a socker for lua .call() This solves some locking issues, and is nicer anyway. --- cmd_connection.go | 4 +++ cmd_pubsub.go | 16 +++++++++ cmd_scripting.go | 23 +++++++----- cmd_transactions.go | 14 ++++++-- integration/script_test.go | 41 +++++++++++++++++++++- lua.go | 45 +++++++++++++++--------- miniredis.go | 25 +++++--------- miniredis_test.go | 18 ---------- redis.go | 9 +++++ server/proto.go | 71 ++++++++++++++++++++++++++++++++++++++ server/proto_test.go | 70 +++++++++++++++++++++++++++++++++++++ server/server.go | 10 ++++-- 12 files changed, 282 insertions(+), 64 deletions(-) diff --git a/cmd_connection.go b/cmd_connection.go index b9296e26..68ccd2da 100644 --- a/cmd_connection.go +++ b/cmd_connection.go @@ -63,6 +63,10 @@ func (m *Miniredis) cmdAuth(c *server.Peer, cmd string, args []string) { if m.checkPubsub(c) { return } + if getCtx(c).nested { + c.WriteError(msgNotFromScripts) + return + } pw := args[0] diff --git a/cmd_pubsub.go b/cmd_pubsub.go index d1b5a7d5..b799ba5b 100644 --- a/cmd_pubsub.go +++ b/cmd_pubsub.go @@ -29,6 +29,10 @@ func (m *Miniredis) cmdSubscribe(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if getCtx(c).nested { + c.WriteError(msgNotFromScripts) + return + } withTx(m, c, func(c *server.Peer, ctx *connCtx) { sub := m.subscribedState(c) @@ -49,6 +53,10 @@ func (m *Miniredis) cmdUnsubscribe(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if getCtx(c).nested { + c.WriteError(msgNotFromScripts) + return + } channels := args @@ -86,6 +94,10 @@ func (m *Miniredis) cmdPsubscribe(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if getCtx(c).nested { + c.WriteError(msgNotFromScripts) + return + } withTx(m, c, func(c *server.Peer, ctx *connCtx) { sub := m.subscribedState(c) @@ -106,6 +118,10 @@ func (m *Miniredis) cmdPunsubscribe(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if getCtx(c).nested { + c.WriteError(msgNotFromScripts) + return + } patterns := args diff --git a/cmd_scripting.go b/cmd_scripting.go index 6514a330..6dd9e68f 100644 --- a/cmd_scripting.go +++ b/cmd_scripting.go @@ -51,11 +51,6 @@ func (m *Miniredis) runLuaScript(c *server.Peer, script string, args []string) { luajson.Preload(l) requireGlobal(l, "cjson", "json") - m.Unlock() - conn := m.redigo() - m.Lock() - defer conn.Close() - // set global variable KEYS keysTable := l.NewTable() keysS, args := args[0], args[1:] @@ -84,7 +79,7 @@ func (m *Miniredis) runLuaScript(c *server.Peer, script string, args []string) { } l.SetGlobal("ARGV", argvTable) - redisFuncs := mkLuaFuncs(conn) + redisFuncs := mkLuaFuncs(m.srv, c) // Register command handlers l.Push(l.NewFunction(func(l *lua.LState) int { mod := l.RegisterModule("redis", redisFuncs).(*lua.LTable) @@ -97,8 +92,6 @@ func (m *Miniredis) runLuaScript(c *server.Peer, script string, args []string) { l.Push(lua.LString("redis")) l.Call(1, 0) - m.Unlock() // This runs in a transaction, but can access our db recursively - defer m.Lock() if err := l.DoString(script); err != nil { c.WriteError(errLuaParseError(err)) return @@ -120,6 +113,11 @@ func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) { return } + if getCtx(c).nested { + c.WriteError(msgNotFromScripts) + return + } + script, args := args[0], args[1:] withTx(m, c, func(c *server.Peer, ctx *connCtx) { @@ -139,6 +137,10 @@ func (m *Miniredis) cmdEvalsha(c *server.Peer, cmd string, args []string) { if m.checkPubsub(c) { return } + if getCtx(c).nested { + c.WriteError(msgNotFromScripts) + return + } sha, args := args[0], args[1:] @@ -166,6 +168,11 @@ func (m *Miniredis) cmdScript(c *server.Peer, cmd string, args []string) { return } + if getCtx(c).nested { + c.WriteError(msgNotFromScripts) + return + } + subcmd, args := args[0], args[1:] withTx(m, c, func(c *server.Peer, ctx *connCtx) { diff --git a/cmd_transactions.go b/cmd_transactions.go index 8c871d97..d3bf7b82 100644 --- a/cmd_transactions.go +++ b/cmd_transactions.go @@ -29,7 +29,10 @@ func (m *Miniredis) cmdMulti(c *server.Peer, cmd string, args []string) { } ctx := getCtx(c) - + if ctx.nested { + c.WriteError(msgNotFromScripts) + return + } if inTx(ctx) { c.WriteError("ERR MULTI calls can not be nested") return @@ -55,7 +58,10 @@ func (m *Miniredis) cmdExec(c *server.Peer, cmd string, args []string) { } ctx := getCtx(c) - + if ctx.nested { + c.WriteError(msgNotFromScripts) + return + } if !inTx(ctx) { c.WriteError("ERR EXEC without MULTI") return @@ -130,6 +136,10 @@ func (m *Miniredis) cmdWatch(c *server.Peer, cmd string, args []string) { } ctx := getCtx(c) + if ctx.nested { + c.WriteError(msgNotFromScripts) + return + } if inTx(ctx) { c.WriteError("ERR WATCH in MULTI") return diff --git a/integration/script_test.go b/integration/script_test.go index 69975a59..b8d84ea5 100644 --- a/integration/script_test.go +++ b/integration/script_test.go @@ -20,6 +20,8 @@ func TestEval(t *testing.T) { succ("EVAL", "return redis.call('GET', 'nosuch')==nil", 0), succ("EVAL", "local a = redis.call('MGET', 'bar'); return a[1] == false", 0), succ("EVAL", "local a = redis.call('MGET', 'bar'); return a[1] == nil", 0), + succ("EVAL", "return redis.call('ZRANGE', 'q', 0, -1)", 0), + succ("EVAL", "return redis.call('LPOP', 'foo')", 0), // failure cases fail("EVAL"), @@ -98,6 +100,7 @@ func TestLua(t *testing.T) { succ("EVAL", "return 3.9999+0.201", 0), succ("EVAL", "return {{1}}", 0), succ("EVAL", "return {1,{1,{1,'bar'}}}", 0), + succ("EVAL", "return nil", 0), ) // special returns @@ -277,7 +280,7 @@ func TestLuaCall(t *testing.T) { succ("GET", "res"), ) - // call() with transaction commands + // call() with non-allowed commands testCommands(t, succ("SET", "foo", 1), @@ -289,6 +292,42 @@ func TestLuaCall(t *testing.T) { "This Redis command is not allowed from scripts", "EVAL", `redis.call("EXEC")`, 0, ), + failWith( + "This Redis command is not allowed from scripts", + "EVAL", `redis.call("EVAL", "redis.call(\"GET\", \"foo\")", 0)`, 0, + ), + failWith( + "This Redis command is not allowed from scripts", + "EVAL", `redis.call("SCRIPT", "LOAD", "return 42")`, 0, + ), + failWith( + "This Redis command is not allowed from scripts", + "EVAL", `redis.call("EVALSHA", "123", "0")`, 0, + ), + failWith( + "This Redis command is not allowed from scripts", + "EVAL", `redis.call("AUTH", "foobar")`, 0, + ), + failWith( + "This Redis command is not allowed from scripts", + "EVAL", `redis.call("WATCH", "foobar")`, 0, + ), + failWith( + "This Redis command is not allowed from scripts", + "EVAL", `redis.call("SUBSCRIBE", "foo")`, 0, + ), + failWith( + "This Redis command is not allowed from scripts", + "EVAL", `redis.call("UNSUBSCRIBE", "foo")`, 0, + ), + failWith( + "This Redis command is not allowed from scripts", + "EVAL", `redis.call("PSUBSCRIBE", "foo")`, 0, + ), + failWith( + "This Redis command is not allowed from scripts", + "EVAL", `redis.call("PUNSUBSCRIBE", "foo")`, 0, + ), succ("EVAL", `redis.pcall("EXEC")`, 0), succ("GET", "foo"), ) diff --git a/lua.go b/lua.go index 4272d99a..2ad84378 100644 --- a/lua.go +++ b/lua.go @@ -1,23 +1,32 @@ package miniredis import ( + "bufio" + "bytes" + "fmt" "strings" - redigo "github.com/gomodule/redigo/redis" lua "github.com/yuin/gopher-lua" "github.com/alicebob/miniredis/v2/server" ) -func mkLuaFuncs(conn redigo.Conn) map[string]lua.LGFunction { +func mkLuaFuncs(srv *server.Server, c *server.Peer) map[string]lua.LGFunction { mkCall := func(failFast bool) func(l *lua.LState) int { + // one server.Ctx for a single Lua run + pCtx := &connCtx{} + if getCtx(c).authenticated { + pCtx.authenticated = true + } + pCtx.nested = true + return func(l *lua.LState) int { top := l.GetTop() if top == 0 { l.Error(lua.LString("Please specify at least one argument for redis.call()"), 1) return 0 } - var args []interface{} + var args []string for i := 1; i <= top; i++ { switch a := l.Get(i).(type) { case lua.LNumber: @@ -29,22 +38,19 @@ func mkLuaFuncs(conn redigo.Conn) map[string]lua.LGFunction { return 0 } } - cmd, ok := args[0].(string) - if !ok { - l.Error(lua.LString("Unknown Redis command called from Lua script"), 1) + if len(args) == 0 { + l.Error(lua.LString(msgNotFromScripts), 1) return 0 } - switch strings.ToUpper(cmd) { - case "MULTI", "EXEC": - if failFast { - l.Error(lua.LString("This Redis command is not allowed from scripts"), 1) - return 0 - } - l.Push(lua.LNil) - return 1 - } - res, err := conn.Do(cmd, args[1:]...) + buf := &bytes.Buffer{} + wr := bufio.NewWriter(buf) + peer := server.NewPeer(wr) + peer.Ctx = pCtx + srv.Dispatch(peer, args) + wr.Flush() + + res, err := server.ParseReply(bufio.NewReader(buf)) if err != nil { if failFast { // call() mode @@ -66,14 +72,19 @@ func mkLuaFuncs(conn redigo.Conn) map[string]lua.LGFunction { switch r := res.(type) { case int64: l.Push(lua.LNumber(r)) + case int: + l.Push(lua.LNumber(r)) case []uint8: l.Push(lua.LString(string(r))) case []interface{}: l.Push(redisToLua(l, r)) case string: l.Push(lua.LString(r)) + case error: + l.Error(lua.LString(r.Error()), 1) + return 0 default: - panic("type not handled") + panic(fmt.Sprintf("type not handled (%T)", r)) } } return 1 diff --git a/miniredis.go b/miniredis.go index d32540c6..28e0cef0 100644 --- a/miniredis.go +++ b/miniredis.go @@ -18,13 +18,10 @@ import ( "context" "fmt" "math/rand" - "net" "strconv" "sync" "time" - redigo "github.com/gomodule/redigo/redis" - "github.com/alicebob/miniredis/v2/server" ) @@ -80,6 +77,7 @@ type connCtx struct { dirtyTransaction bool // any error during QUEUEing watch map[dbKey]uint // WATCHed keys subscriber *Subscriber // client is in PUBSUB mode if not nil + nested bool // this is called via Lua } // NewMiniRedis makes a new, non-started, Miniredis object. @@ -287,19 +285,6 @@ func (m *Miniredis) Server() *server.Server { return m.srv } -// redigo returns a redigo.Conn, connected using net.Pipe -func (m *Miniredis) redigo() redigo.Conn { - c1, c2 := net.Pipe() - m.srv.ServeConn(c1) - c := redigo.NewConn(c2, 0, 0) - if m.password != "" { - if _, err := c.Do("AUTH", m.password); err != nil { - // ? - } - } - return c -} - // Dump returns a text version of the selected DB, usable for debugging. func (m *Miniredis) Dump() string { m.Lock() @@ -366,6 +351,10 @@ func (m *Miniredis) SetTime(t time.Time) { // handleAuth returns false if connection has no access. It sends the reply. func (m *Miniredis) handleAuth(c *server.Peer) bool { + if getCtx(c).nested { + return true + } + m.Lock() defer m.Unlock() if m.password == "" { @@ -381,6 +370,10 @@ func (m *Miniredis) handleAuth(c *server.Peer) bool { // handlePubsub sends an error to the user if the connection is in PUBSUB mode. // It'll return true if it did. func (m *Miniredis) checkPubsub(c *server.Peer) bool { + if getCtx(c).nested { + return false + } + m.Lock() defer m.Unlock() diff --git a/miniredis_test.go b/miniredis_test.go index c76dccef..10965c1e 100644 --- a/miniredis_test.go +++ b/miniredis_test.go @@ -238,21 +238,3 @@ func TestExpireWithFastForward(t *testing.T) { s.FastForward(5 * time.Second) equals(t, 1, len(s.Keys())) } - -func TestRedigo(t *testing.T) { - s, err := Run() - ok(t, err) - - r := s.redigo() - defer r.Close() - - _, err = r.Do("SELECT", 2) - ok(t, err) - - _, err = r.Do("SET", "foo", "bar") - ok(t, err) - - v, err := redis.String(r.Do("GET", "foo")) - ok(t, err) - equals(t, "bar", v) -} diff --git a/redis.go b/redis.go index 5315de35..6fa49411 100644 --- a/redis.go +++ b/redis.go @@ -36,6 +36,7 @@ const ( msgStreamIDTooSmall = "ERR The ID specified in XADD is equal or smaller than the target stream top item" msgNoScriptFound = "NOSCRIPT No matching script. Please use EVAL." msgUnsupportedUnit = "ERR unsupported unit provided. please use m, km, ft, mi" + msgNotFromScripts = "This Redis command is not allowed from scripts" ) func errWrongNumber(cmd string) string { @@ -54,6 +55,14 @@ func withTx( cb txCmd, ) { ctx := getCtx(c) + + if ctx.nested { + // this is a call via Lua's .call(). It's already locked. + cb(c, ctx) + m.signal.Broadcast() + return + } + if inTx(ctx) { addTxCmd(ctx, cb) c.WriteInline("QUEUED") diff --git a/server/proto.go b/server/proto.go index 27e62d4f..d09d16a1 100644 --- a/server/proto.go +++ b/server/proto.go @@ -82,3 +82,74 @@ func readString(rd *bufio.Reader) (string, error) { return string(buf[:length]), nil } } + +// parse a reply +func ParseReply(rd *bufio.Reader) (interface{}, error) { + line, err := rd.ReadString('\n') + if err != nil { + return nil, err + } + if len(line) < 3 { + return nil, ErrProtocol + } + + switch line[0] { + default: + return nil, ErrProtocol + case '+': + // +: simple string + return string(line[1 : len(line)-2]), nil + case '-': + // -: errors + return nil, errors.New(string(line[1 : len(line)-2])) + case ':': + // :: integer + v := line[1 : len(line)-2] + if v == "" { + return 0, nil + } + n, err := strconv.Atoi(v) + if err != nil { + return nil, ErrProtocol + } + return n, nil + case '$': + // bulk strings are: `$5\r\nhello\r\n` + length, err := strconv.Atoi(line[1 : len(line)-2]) + if err != nil { + return "", err + } + if length < 0 { + // -1 is a nil response + return nil, nil + } + var ( + buf = make([]byte, length+2) + pos = 0 + ) + for pos < length+2 { + n, err := rd.Read(buf[pos:]) + if err != nil { + return "", err + } + pos += n + } + return string(buf[:length]), nil + case '*': + // array + l, err := strconv.Atoi(line[1 : len(line)-2]) + if err != nil { + return nil, ErrProtocol + } + // l can be -1 + var fields []interface{} + for ; l > 0; l-- { + s, err := ParseReply(rd) + if err != nil { + return nil, err + } + fields = append(fields, s) + } + return fields, nil + } +} diff --git a/server/proto_test.go b/server/proto_test.go index 8223b961..ff3206d5 100644 --- a/server/proto_test.go +++ b/server/proto_test.go @@ -3,6 +3,7 @@ package server import ( "bufio" "bytes" + "errors" "fmt" "io" "reflect" @@ -119,3 +120,72 @@ func TestReadString(t *testing.T) { } } } + +func TestParseReply(t *testing.T) { + type cas struct { + payload string + err error + res interface{} + } + bigPayload := strings.Repeat("X", 1<<24) + for i, c := range []cas{ + { + payload: "+hello world\r\n", + res: "hello world", + }, + { + payload: "-some error\r\n", + err: errors.New("some error"), + }, + { + payload: ":42\r\n", + res: 42, + }, + { + payload: ":\r\n", + res: 0, + }, + { + payload: "$4\r\nabcd\r\n", + res: "abcd", + }, + { + payload: fmt.Sprintf("$%d\r\n%s\r\n", len(bigPayload), bigPayload), + res: bigPayload, + }, + + { + payload: "", + err: io.EOF, + }, + { + payload: ":42", + err: io.EOF, + }, + { + payload: "XXX", + err: io.EOF, + }, + { + payload: "XXXXXX", + err: io.EOF, + }, + { + payload: "\r\n", + err: ErrProtocol, + }, + { + payload: "XXXX\r\n", + err: ErrProtocol, + }, + } { + res, err := ParseReply(bufio.NewReader(bytes.NewBufferString(c.payload))) + if have, want := err, c.err; !reflect.DeepEqual(have, want) { + t.Errorf("err %d: have %#v, want %#v", i, have, want) + continue + } + if have, want := res, c.res; !reflect.DeepEqual(have, want) { + t.Errorf("case %d: have %#v, want %#v", i, have, want) + } + } +} diff --git a/server/server.go b/server/server.go index c924bb3c..2d315722 100644 --- a/server/server.go +++ b/server/server.go @@ -140,7 +140,7 @@ func (s *Server) servePeer(c net.Conn) { if err != nil { return } - s.dispatch(peer, args) + s.Dispatch(peer, args) peer.Flush() s.mu.Lock() closed := peer.closed @@ -151,7 +151,7 @@ func (s *Server) servePeer(c net.Conn) { } } -func (s *Server) dispatch(c *Peer, args []string) { +func (s *Server) Dispatch(c *Peer, args []string) { cmd, args := args[0], args[1:] cmdUp := strings.ToUpper(cmd) s.mu.Lock() @@ -199,6 +199,12 @@ type Peer struct { mu sync.Mutex // for Block() } +func NewPeer(w *bufio.Writer) *Peer { + return &Peer{ + w: w, + } +} + // Flush the write buffer. Called automatically after every redis command func (c *Peer) Flush() { c.mu.Lock()