Skip to content

Commit

Permalink
Lua .call() not via a socket (#126)
Browse files Browse the repository at this point in the history
* don't go over a socker for lua .call()

This solves some locking issues, and is nicer anyway.
  • Loading branch information
alicebob authored Jan 4, 2020
1 parent f930c06 commit dea7631
Show file tree
Hide file tree
Showing 12 changed files with 282 additions and 64 deletions.
4 changes: 4 additions & 0 deletions cmd_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ func (m *Miniredis) cmdAuth(c *server.Peer, cmd string, args []string) {
if m.checkPubsub(c) {
return
}
if getCtx(c).nested {
c.WriteError(msgNotFromScripts)
return
}

pw := args[0]

Expand Down
16 changes: 16 additions & 0 deletions cmd_pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ func (m *Miniredis) cmdSubscribe(c *server.Peer, cmd string, args []string) {
if !m.handleAuth(c) {
return
}
if getCtx(c).nested {
c.WriteError(msgNotFromScripts)
return
}

withTx(m, c, func(c *server.Peer, ctx *connCtx) {
sub := m.subscribedState(c)
Expand All @@ -49,6 +53,10 @@ func (m *Miniredis) cmdUnsubscribe(c *server.Peer, cmd string, args []string) {
if !m.handleAuth(c) {
return
}
if getCtx(c).nested {
c.WriteError(msgNotFromScripts)
return
}

channels := args

Expand Down Expand Up @@ -86,6 +94,10 @@ func (m *Miniredis) cmdPsubscribe(c *server.Peer, cmd string, args []string) {
if !m.handleAuth(c) {
return
}
if getCtx(c).nested {
c.WriteError(msgNotFromScripts)
return
}

withTx(m, c, func(c *server.Peer, ctx *connCtx) {
sub := m.subscribedState(c)
Expand All @@ -106,6 +118,10 @@ func (m *Miniredis) cmdPunsubscribe(c *server.Peer, cmd string, args []string) {
if !m.handleAuth(c) {
return
}
if getCtx(c).nested {
c.WriteError(msgNotFromScripts)
return
}

patterns := args

Expand Down
23 changes: 15 additions & 8 deletions cmd_scripting.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,6 @@ func (m *Miniredis) runLuaScript(c *server.Peer, script string, args []string) {
luajson.Preload(l)
requireGlobal(l, "cjson", "json")

m.Unlock()
conn := m.redigo()
m.Lock()
defer conn.Close()

// set global variable KEYS
keysTable := l.NewTable()
keysS, args := args[0], args[1:]
Expand Down Expand Up @@ -84,7 +79,7 @@ func (m *Miniredis) runLuaScript(c *server.Peer, script string, args []string) {
}
l.SetGlobal("ARGV", argvTable)

redisFuncs := mkLuaFuncs(conn)
redisFuncs := mkLuaFuncs(m.srv, c)
// Register command handlers
l.Push(l.NewFunction(func(l *lua.LState) int {
mod := l.RegisterModule("redis", redisFuncs).(*lua.LTable)
Expand All @@ -97,8 +92,6 @@ func (m *Miniredis) runLuaScript(c *server.Peer, script string, args []string) {
l.Push(lua.LString("redis"))
l.Call(1, 0)

m.Unlock() // This runs in a transaction, but can access our db recursively
defer m.Lock()
if err := l.DoString(script); err != nil {
c.WriteError(errLuaParseError(err))
return
Expand All @@ -120,6 +113,11 @@ func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) {
return
}

if getCtx(c).nested {
c.WriteError(msgNotFromScripts)
return
}

script, args := args[0], args[1:]

withTx(m, c, func(c *server.Peer, ctx *connCtx) {
Expand All @@ -139,6 +137,10 @@ func (m *Miniredis) cmdEvalsha(c *server.Peer, cmd string, args []string) {
if m.checkPubsub(c) {
return
}
if getCtx(c).nested {
c.WriteError(msgNotFromScripts)
return
}

sha, args := args[0], args[1:]

Expand Down Expand Up @@ -166,6 +168,11 @@ func (m *Miniredis) cmdScript(c *server.Peer, cmd string, args []string) {
return
}

if getCtx(c).nested {
c.WriteError(msgNotFromScripts)
return
}

subcmd, args := args[0], args[1:]

withTx(m, c, func(c *server.Peer, ctx *connCtx) {
Expand Down
14 changes: 12 additions & 2 deletions cmd_transactions.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ func (m *Miniredis) cmdMulti(c *server.Peer, cmd string, args []string) {
}

ctx := getCtx(c)

if ctx.nested {
c.WriteError(msgNotFromScripts)
return
}
if inTx(ctx) {
c.WriteError("ERR MULTI calls can not be nested")
return
Expand All @@ -55,7 +58,10 @@ func (m *Miniredis) cmdExec(c *server.Peer, cmd string, args []string) {
}

ctx := getCtx(c)

if ctx.nested {
c.WriteError(msgNotFromScripts)
return
}
if !inTx(ctx) {
c.WriteError("ERR EXEC without MULTI")
return
Expand Down Expand Up @@ -130,6 +136,10 @@ func (m *Miniredis) cmdWatch(c *server.Peer, cmd string, args []string) {
}

ctx := getCtx(c)
if ctx.nested {
c.WriteError(msgNotFromScripts)
return
}
if inTx(ctx) {
c.WriteError("ERR WATCH in MULTI")
return
Expand Down
41 changes: 40 additions & 1 deletion integration/script_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ func TestEval(t *testing.T) {
succ("EVAL", "return redis.call('GET', 'nosuch')==nil", 0),
succ("EVAL", "local a = redis.call('MGET', 'bar'); return a[1] == false", 0),
succ("EVAL", "local a = redis.call('MGET', 'bar'); return a[1] == nil", 0),
succ("EVAL", "return redis.call('ZRANGE', 'q', 0, -1)", 0),
succ("EVAL", "return redis.call('LPOP', 'foo')", 0),

// failure cases
fail("EVAL"),
Expand Down Expand Up @@ -98,6 +100,7 @@ func TestLua(t *testing.T) {
succ("EVAL", "return 3.9999+0.201", 0),
succ("EVAL", "return {{1}}", 0),
succ("EVAL", "return {1,{1,{1,'bar'}}}", 0),
succ("EVAL", "return nil", 0),
)

// special returns
Expand Down Expand Up @@ -277,7 +280,7 @@ func TestLuaCall(t *testing.T) {
succ("GET", "res"),
)

// call() with transaction commands
// call() with non-allowed commands
testCommands(t,
succ("SET", "foo", 1),

Expand All @@ -289,6 +292,42 @@ func TestLuaCall(t *testing.T) {
"This Redis command is not allowed from scripts",
"EVAL", `redis.call("EXEC")`, 0,
),
failWith(
"This Redis command is not allowed from scripts",
"EVAL", `redis.call("EVAL", "redis.call(\"GET\", \"foo\")", 0)`, 0,
),
failWith(
"This Redis command is not allowed from scripts",
"EVAL", `redis.call("SCRIPT", "LOAD", "return 42")`, 0,
),
failWith(
"This Redis command is not allowed from scripts",
"EVAL", `redis.call("EVALSHA", "123", "0")`, 0,
),
failWith(
"This Redis command is not allowed from scripts",
"EVAL", `redis.call("AUTH", "foobar")`, 0,
),
failWith(
"This Redis command is not allowed from scripts",
"EVAL", `redis.call("WATCH", "foobar")`, 0,
),
failWith(
"This Redis command is not allowed from scripts",
"EVAL", `redis.call("SUBSCRIBE", "foo")`, 0,
),
failWith(
"This Redis command is not allowed from scripts",
"EVAL", `redis.call("UNSUBSCRIBE", "foo")`, 0,
),
failWith(
"This Redis command is not allowed from scripts",
"EVAL", `redis.call("PSUBSCRIBE", "foo")`, 0,
),
failWith(
"This Redis command is not allowed from scripts",
"EVAL", `redis.call("PUNSUBSCRIBE", "foo")`, 0,
),
succ("EVAL", `redis.pcall("EXEC")`, 0),
succ("GET", "foo"),
)
Expand Down
45 changes: 28 additions & 17 deletions lua.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,32 @@
package miniredis

import (
"bufio"
"bytes"
"fmt"
"strings"

redigo "github.com/gomodule/redigo/redis"
lua "github.com/yuin/gopher-lua"

"github.com/alicebob/miniredis/v2/server"
)

func mkLuaFuncs(conn redigo.Conn) map[string]lua.LGFunction {
func mkLuaFuncs(srv *server.Server, c *server.Peer) map[string]lua.LGFunction {
mkCall := func(failFast bool) func(l *lua.LState) int {
// one server.Ctx for a single Lua run
pCtx := &connCtx{}
if getCtx(c).authenticated {
pCtx.authenticated = true
}
pCtx.nested = true

return func(l *lua.LState) int {
top := l.GetTop()
if top == 0 {
l.Error(lua.LString("Please specify at least one argument for redis.call()"), 1)
return 0
}
var args []interface{}
var args []string
for i := 1; i <= top; i++ {
switch a := l.Get(i).(type) {
case lua.LNumber:
Expand All @@ -29,22 +38,19 @@ func mkLuaFuncs(conn redigo.Conn) map[string]lua.LGFunction {
return 0
}
}
cmd, ok := args[0].(string)
if !ok {
l.Error(lua.LString("Unknown Redis command called from Lua script"), 1)
if len(args) == 0 {
l.Error(lua.LString(msgNotFromScripts), 1)
return 0
}
switch strings.ToUpper(cmd) {
case "MULTI", "EXEC":
if failFast {
l.Error(lua.LString("This Redis command is not allowed from scripts"), 1)
return 0
}
l.Push(lua.LNil)
return 1
}

res, err := conn.Do(cmd, args[1:]...)
buf := &bytes.Buffer{}
wr := bufio.NewWriter(buf)
peer := server.NewPeer(wr)
peer.Ctx = pCtx
srv.Dispatch(peer, args)
wr.Flush()

res, err := server.ParseReply(bufio.NewReader(buf))
if err != nil {
if failFast {
// call() mode
Expand All @@ -66,14 +72,19 @@ func mkLuaFuncs(conn redigo.Conn) map[string]lua.LGFunction {
switch r := res.(type) {
case int64:
l.Push(lua.LNumber(r))
case int:
l.Push(lua.LNumber(r))
case []uint8:
l.Push(lua.LString(string(r)))
case []interface{}:
l.Push(redisToLua(l, r))
case string:
l.Push(lua.LString(r))
case error:
l.Error(lua.LString(r.Error()), 1)
return 0
default:
panic("type not handled")
panic(fmt.Sprintf("type not handled (%T)", r))
}
}
return 1
Expand Down
25 changes: 9 additions & 16 deletions miniredis.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,10 @@ import (
"context"
"fmt"
"math/rand"
"net"
"strconv"
"sync"
"time"

redigo "github.com/gomodule/redigo/redis"

"github.com/alicebob/miniredis/v2/server"
)

Expand Down Expand Up @@ -80,6 +77,7 @@ type connCtx struct {
dirtyTransaction bool // any error during QUEUEing
watch map[dbKey]uint // WATCHed keys
subscriber *Subscriber // client is in PUBSUB mode if not nil
nested bool // this is called via Lua
}

// NewMiniRedis makes a new, non-started, Miniredis object.
Expand Down Expand Up @@ -287,19 +285,6 @@ func (m *Miniredis) Server() *server.Server {
return m.srv
}

// redigo returns a redigo.Conn, connected using net.Pipe
func (m *Miniredis) redigo() redigo.Conn {
c1, c2 := net.Pipe()
m.srv.ServeConn(c1)
c := redigo.NewConn(c2, 0, 0)
if m.password != "" {
if _, err := c.Do("AUTH", m.password); err != nil {
// ?
}
}
return c
}

// Dump returns a text version of the selected DB, usable for debugging.
func (m *Miniredis) Dump() string {
m.Lock()
Expand Down Expand Up @@ -366,6 +351,10 @@ func (m *Miniredis) SetTime(t time.Time) {

// handleAuth returns false if connection has no access. It sends the reply.
func (m *Miniredis) handleAuth(c *server.Peer) bool {
if getCtx(c).nested {
return true
}

m.Lock()
defer m.Unlock()
if m.password == "" {
Expand All @@ -381,6 +370,10 @@ func (m *Miniredis) handleAuth(c *server.Peer) bool {
// handlePubsub sends an error to the user if the connection is in PUBSUB mode.
// It'll return true if it did.
func (m *Miniredis) checkPubsub(c *server.Peer) bool {
if getCtx(c).nested {
return false
}

m.Lock()
defer m.Unlock()

Expand Down
18 changes: 0 additions & 18 deletions miniredis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,21 +238,3 @@ func TestExpireWithFastForward(t *testing.T) {
s.FastForward(5 * time.Second)
equals(t, 1, len(s.Keys()))
}

func TestRedigo(t *testing.T) {
s, err := Run()
ok(t, err)

r := s.redigo()
defer r.Close()

_, err = r.Do("SELECT", 2)
ok(t, err)

_, err = r.Do("SET", "foo", "bar")
ok(t, err)

v, err := redis.String(r.Do("GET", "foo"))
ok(t, err)
equals(t, "bar", v)
}
Loading

0 comments on commit dea7631

Please sign in to comment.