Skip to content

Commit

Permalink
fix(#46): make memory provider safe-concurrent
Browse files Browse the repository at this point in the history
  • Loading branch information
savsgio committed Oct 24, 2022
1 parent b8ae6df commit da2b08b
Show file tree
Hide file tree
Showing 13 changed files with 415 additions and 86 deletions.
6 changes: 4 additions & 2 deletions encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ var b64Encoding = base64.StdEncoding

// MSGPEncode MessagePack encode
func MSGPEncode(src Dict) ([]byte, error) {
if len(src.D) == 0 {
if len(src.KV) == 0 {
return nil, nil
}

Expand All @@ -22,7 +22,9 @@ func MSGPEncode(src Dict) ([]byte, error) {

// MSGPDecode MessagePack decode
func MSGPDecode(dst *Dict, src []byte) error {
dst.Reset()
for k := range dst.KV {
delete(dst.KV, k)
}

if len(src) == 0 {
return nil
Expand Down
52 changes: 21 additions & 31 deletions encoding_test.go
Original file line number Diff line number Diff line change
@@ -1,68 +1,58 @@
package session

import (
"bytes"
"reflect"
"testing"
)

func getSRC() *Dict {
src := new(Dict)
func getSRC() Dict {
src := newDictValue()

src.Set("k1", 1)
src.Set("k2", 2)
src.KV["k1"] = "1"
src.KV["k2"] = "2"

return src
}

func getDST() *Dict {
return new(Dict)
func getDST() Dict {
return newDictValue()
}

func TestMSGPEncodeDecode(t *testing.T) {
src := getSRC()
dst := getDST()

b1, err := MSGPEncode(*src)
b1, err := MSGPEncode(src)
if err != nil {
t.Fatal(err)
}

err = MSGPDecode(dst, b1)
err = MSGPDecode(&dst, b1)
if err != nil {
t.Fatal(err)
}

b2, err := MSGPEncode(*dst)
if err != nil {
t.Fatal(err)
}

if !bytes.Equal(b1, b2) {
t.Errorf("The bytes results of 'src' and 'dst' must be equals, src = %s; dst = %s", b1, b2)
if !reflect.DeepEqual(src, dst) {
t.Errorf("The results of 'src' and 'dst' must be equals, src = %v; dst = %v", src, dst)
}
}

func TestBase64EncodeDecode(t *testing.T) {
src := getSRC()
dst := getDST()

b1, err := Base64Encode(*src)
if err != nil {
t.Fatal(err)
}

err = Base64Decode(dst, b1)
b1, err := Base64Encode(src)
if err != nil {
t.Fatal(err)
}

b2, err := Base64Encode(*dst)
err = Base64Decode(&dst, b1)
if err != nil {
t.Fatal(err)
}

if !bytes.Equal(b1, b2) {
t.Errorf("The bytes results of 'src' and 'dst' must be equals, src = %s; dst = %s", b1, b2)
if !reflect.DeepEqual(src, dst) {
t.Errorf("The results of 'src' and 'dst' must be equals, src = %v; dst = %v", src, dst)
}
}

Expand All @@ -71,19 +61,19 @@ func BenchmarkMSGPEncode(b *testing.B) {

b.ResetTimer()
for i := 0; i < b.N; i++ {
MSGPEncode(*src)
MSGPEncode(src)
}
}

func BenchmarkMSGPDecode(b *testing.B) {
src := getSRC()
dst := getDST()

srcBytes, _ := MSGPEncode(*src)
srcBytes, _ := MSGPEncode(src)

b.ResetTimer()
for i := 0; i < b.N; i++ {
MSGPDecode(dst, srcBytes)
MSGPDecode(&dst, srcBytes)
}
}

Expand All @@ -92,18 +82,18 @@ func BenchmarkBase64Encode(b *testing.B) {

b.ResetTimer()
for i := 0; i < b.N; i++ {
Base64Encode(*src)
Base64Encode(src)
}
}

func BenchmarkBase64Decode(b *testing.B) {
src := getSRC()
dst := getDST()

srcBytes, _ := Base64Encode(*src)
srcBytes, _ := Base64Encode(src)

b.ResetTimer()
for i := 0; i < b.N; i++ {
Base64Decode(dst, srcBytes)
Base64Decode(&dst, srcBytes)
}
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ require (
github.com/go-sql-driver/mysql v1.6.0
github.com/lib/pq v1.10.7
github.com/mattn/go-sqlite3 v1.14.15
github.com/savsgio/dictpool v0.0.0-20220406081701-03de5edb2e6d
github.com/savsgio/gotils v0.0.0-20220530130905-52f3993e8d6d
github.com/tinylib/msgp v1.1.6
github.com/valyala/bytebufferpool v1.0.0
github.com/valyala/fasthttp v1.40.0
)
3 changes: 0 additions & 3 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,6 @@ github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5h
github.com/philhofer/fwd v1.1.1 h1:GdGcTjf5RNAxwS4QLsiMzJYj5KEvPJD3Abr261yRQXQ=
github.com/philhofer/fwd v1.1.1/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/savsgio/dictpool v0.0.0-20220406081701-03de5edb2e6d h1:ICMDEgNgR5xFW6ZDeMKTtmh07YiLr7GkDw897I2DwKg=
github.com/savsgio/dictpool v0.0.0-20220406081701-03de5edb2e6d/go.mod h1:jrsy/bTK2n5uybo7bAvtLGzmuzAbxp+nKS8bzgrZURE=
github.com/savsgio/gotils v0.0.0-20220401102855-e56b59f40436/go.mod h1:Gy+0tqhJvgGlqnTF8CVGP0AaGRjwBtXs/a5PA0Y3+A4=
github.com/savsgio/gotils v0.0.0-20220530130905-52f3993e8d6d h1:Q+gqLBOPkFGHyCJxXMRqtUgUbTjI8/Ze8vu8GGyNFwo=
github.com/savsgio/gotils v0.0.0-20220530130905-52f3993e8d6d/go.mod h1:Gy+0tqhJvgGlqnTF8CVGP0AaGRjwBtXs/a5PA0Y3+A4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
Expand Down
64 changes: 43 additions & 21 deletions providers/memory/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"sync"
"time"

"github.com/fasthttp/session/v2"
"github.com/savsgio/gotils/strconv"
)

Expand All @@ -30,16 +29,21 @@ func releaseItem(item *item) {
func New(cfg Config) (*Provider, error) {
p := &Provider{
config: cfg,
db: new(session.Dict),
}

return p, nil
}

func (p *Provider) getSessionKey(sessionID []byte) string {
return strconv.B2S(sessionID)
}

// Get returns the data of the given session id
func (p *Provider) Get(id []byte) ([]byte, error) {
val := p.db.GetBytes(id)
if val == nil { // Not exist
key := p.getSessionKey(id)

val, found := p.db.Load(key)
if !found || val == nil { // Not exist
return nil, nil
}

Expand All @@ -50,48 +54,64 @@ func (p *Provider) Get(id []byte) ([]byte, error) {

// Save saves the session data and expiration from the given session id
func (p *Provider) Save(id, data []byte, expiration time.Duration) error {
key := p.getSessionKey(id)

item := acquireItem()
item.data = data
item.lastActiveTime = time.Now().UnixNano()
item.expiration = expiration

p.db.SetBytes(id, item)
p.db.Store(key, item)

return nil
}

// Regenerate updates the session id and expiration with the new session id
// of the the given current session id
func (p *Provider) Regenerate(id, newID []byte, expiration time.Duration) error {
data := p.db.GetBytes(id)
if data != nil {
key := p.getSessionKey(id)

data, found := p.db.LoadAndDelete(key)
if found && data != nil {
item := data.(*item)
item.lastActiveTime = time.Now().UnixNano()
item.expiration = expiration

p.db.SetBytes(newID, item)
p.db.DelBytes(id)
newKey := p.getSessionKey(newID)

p.db.Store(newKey, item)
}

return nil
}

// Destroy destroys the session from the given id
func (p *Provider) Destroy(id []byte) error {
val := p.db.GetBytes(id)
if val == nil {
func (p *Provider) destroy(key string) error {
val, found := p.db.LoadAndDelete(key)
if !found || val == nil {
return nil
}

p.db.DelBytes(id)
releaseItem(val.(*item))

return nil
}

// Destroy destroys the session from the given id
func (p *Provider) Destroy(id []byte) error {
key := p.getSessionKey(id)

return p.destroy(key)
}

// Count returns the total of stored sessions
func (p *Provider) Count() int {
return len(p.db.D)
func (p *Provider) Count() (count int) {
p.db.Range(func(_, _ interface{}) bool {
count++

return true
})

return count
}

// NeedGC indicates if the GC needs to be run
Expand All @@ -103,17 +123,19 @@ func (p *Provider) NeedGC() bool {
func (p *Provider) GC() error {
now := time.Now().UnixNano()

for _, kv := range p.db.D {
item := kv.Value.(*item)
p.db.Range(func(key, value interface{}) bool {
item := value.(*item)

if item.expiration == 0 {
continue
return true
}

if now >= (item.lastActiveTime + item.expiration.Nanoseconds()) {
p.Destroy(strconv.S2B(kv.Key))
_ = p.destroy(key.(string))
}
}

return true
})

return nil
}
5 changes: 2 additions & 3 deletions providers/memory/types.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package memory

import (
"sync"
"time"

"github.com/fasthttp/session/v2"
)

// Config provider settings
Expand All @@ -12,7 +11,7 @@ type Config struct{}
// Provider backend manager
type Provider struct {
config Config
db *session.Dict
db sync.Map
}

type item struct {
Expand Down
2 changes: 1 addition & 1 deletion session.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func (s *Session) Get(ctx *fasthttp.RequestCtx) (*Store, error) {
return nil, err
}

if err := s.config.DecodeFunc(store.data, data); err != nil {
if err := s.config.DecodeFunc(&store.data, data); err != nil {
return store, nil
}
}
Expand Down
Loading

0 comments on commit da2b08b

Please sign in to comment.