Skip to content

Commit

Permalink
Add rand implementation (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackkleeman authored Jul 15, 2024
1 parent 96a7bd9 commit fe436d0
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 9 deletions.
10 changes: 1 addition & 9 deletions example/checkout.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"fmt"
"math/rand"

"github.com/google/uuid"
restate "github.com/restatedev/sdk-go"
)

Expand All @@ -27,17 +26,10 @@ func (c *checkout) Name() string {
const CheckoutServiceName = "Checkout"

func (c *checkout) Payment(ctx restate.Context, request PaymentRequest) (response PaymentResponse, err error) {
uuid, err := restate.RunAs(ctx, func(ctx restate.RunContext) (string, error) {
uuid := uuid.New()
return uuid.String(), nil
})
uuid := ctx.Rand().UUID().String()

response.ID = uuid

if err != nil {
return response, err
}

// We are a uniform shop where everything costs 30 USD
// that is cheaper than the official example :P
price := len(request.Tickets) * 30
Expand Down
67 changes: 67 additions & 0 deletions internal/rand/rand.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package rand

import (
"crypto/sha256"
"encoding/binary"
"math/rand/v2"

"github.com/google/uuid"
)

type Rand struct {
*rand.Rand
}

func New(invocationID []byte) *Rand {
return &Rand{rand.New(newSource(invocationID))}
}

func (r *Rand) UUID() uuid.UUID {
var uuid [16]byte
binary.LittleEndian.PutUint64(uuid[:8], r.Uint64())
binary.LittleEndian.PutUint64(uuid[8:], r.Uint64())
uuid[6] = (uuid[6] & 0x0f) | 0x40 // Version 4
uuid[8] = (uuid[8] & 0x3f) | 0x80 // Variant is 10
return uuid
}

type Source struct {
state [4]uint64
}

func newSource(invocationID []byte) *Source {
hash := sha256.New()
hash.Write(invocationID)
var sum [32]byte
hash.Sum(sum[:0])

return &Source{state: [4]uint64{
binary.LittleEndian.Uint64(sum[:8]),
binary.LittleEndian.Uint64(sum[8:16]),
binary.LittleEndian.Uint64(sum[16:24]),
binary.LittleEndian.Uint64(sum[24:32]),
}}
}

func (s *Source) Uint64() uint64 {
result := rotl((s.state[0]+s.state[3]), 23) + s.state[0]

t := (s.state[1] << 17)

s.state[2] ^= s.state[0]
s.state[3] ^= s.state[1]
s.state[1] ^= s.state[2]
s.state[0] ^= s.state[3]

s.state[2] ^= t

s.state[3] = rotl(s.state[3], 45)

return result
}

func rotl(x uint64, k uint64) uint64 {
return (x << k) | (x >> (64 - k))
}

var _ rand.Source = (*Source)(nil)
76 changes: 76 additions & 0 deletions internal/rand/rand_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package rand

import (
"encoding/hex"
"math/rand/v2"
"testing"
)

func TestUint64(t *testing.T) {
id, err := hex.DecodeString("f311f1fdcb9863f0018bd3400ecd7d69b547204e776218b2")
if err != nil {
t.Fatal(err)
}
rand := New(id)

expected := []uint64{
6541268553928124324,
1632128201851599825,
3999496359968271420,
9099219592091638755,
2609122094717920550,
16569362788292807660,
14955958648458255954,
15581072429430901841,
4951852598761288088,
2380816196140950843,
}

for _, e := range expected {
if found := rand.Uint64(); e != found {
t.Fatalf("Unexpected uint64 %d, expected %d", found, e)
}
}
}

func TestFloat64(t *testing.T) {
source := &Source{state: [4]uint64{1, 2, 3, 4}}
rand := &Rand{rand.New(source)}

expected := []float64{
4.656612984099695e-9, 6.519269457605503e-9, 0.39843750651926946,
0.3986824029416509, 0.5822761557370711, 0.2997488042907357,
0.5336032865255543, 0.36335061693258097, 0.5968067925950846,
0.18570456306457928,
}

for _, e := range expected {
if found := rand.Float64(); e != found {
t.Fatalf("Unexpected float64 %v, expected %v", found, e)
}
}
}

func TestUUID(t *testing.T) {
source := &Source{state: [4]uint64{1, 2, 3, 4}}
rand := &Rand{rand.New(source)}

expected := []string{
"01008002-0000-4000-a700-800300000000",
"67008003-00c0-4c00-b200-449901c20c00",
"cd33c49a-01a2-4280-ba33-eecd8a97698a",
"bd4a1533-4713-41c2-979e-167991a02bac",
"d83f078f-0a19-43db-a092-22b24af10591",
"677c91f7-146e-4769-a4fd-df3793e717e8",
"f15179b2-f220-4427-8d90-7b5437d9828d",
"9e97720f-42b8-4d09-a449-914cf221df26",
"09d0a109-6f11-4ef9-93fa-f013d0ad3808",
"41eb0e0c-41c9-4828-85d0-59fb901b4df4",
}

for _, e := range expected {
if found := rand.UUID().String(); e != found {
t.Fatalf("Unexpected uuid %s, expected %s", found, e)
}
}
}
8 changes: 8 additions & 0 deletions internal/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/restatedev/sdk-go/internal/errors"
"github.com/restatedev/sdk-go/internal/futures"
"github.com/restatedev/sdk-go/internal/log"
"github.com/restatedev/sdk-go/internal/rand"
"github.com/restatedev/sdk-go/internal/wire"
"github.com/restatedev/sdk-go/rcontext"
)
Expand Down Expand Up @@ -46,6 +47,10 @@ func (c *Context) Log() *slog.Logger {
return c.machine.userLog
}

func (c *Context) Rand() *rand.Rand {
return c.machine.rand
}

func (c *Context) Set(key string, value []byte) {
c.machine.set(key, value)
}
Expand Down Expand Up @@ -172,6 +177,8 @@ type Machine struct {
pendingAcks map[uint32]wire.AckableMessage
pendingMutex sync.RWMutex

rand *rand.Rand

failure any
}

Expand Down Expand Up @@ -204,6 +211,7 @@ func (m *Machine) Start(inner context.Context, dropReplayLogs bool, logHandler s
m.ctx = inner
m.suspensionCtx, m.suspend = context.WithCancelCause(m.ctx)
m.id = start.Id
m.rand = rand.New(m.id)
m.key = start.Key

logHandler = logHandler.WithAttrs([]slog.Attr{slog.String("invocationID", start.DebugId)})
Expand Down
6 changes: 6 additions & 0 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/restatedev/sdk-go/internal"
"github.com/restatedev/sdk-go/internal/futures"
"github.com/restatedev/sdk-go/internal/rand"
"github.com/vmihailenco/msgpack/v5"
)

Expand Down Expand Up @@ -50,6 +51,11 @@ type Selector interface {
type Context interface {
RunContext

// Returns a random source which will give deterministic results for a given invocation
// The source wraps the stdlib rand.Rand but with some extra helper methods
// This source is not safe for use inside .Run()
Rand() *rand.Rand

// Sleep for the duration d
Sleep(d time.Duration)
// Return a handle on a sleep duration which can be combined
Expand Down

0 comments on commit fe436d0

Please sign in to comment.