Skip to content

Commit

Permalink
fix: code improves
Browse files Browse the repository at this point in the history
  • Loading branch information
savsgio committed Mar 16, 2023
1 parent 32b16ca commit f480a08
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 45 deletions.
26 changes: 18 additions & 8 deletions encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ func TestMSGPEncodeDecode(t *testing.T) {
t.Fatal(err)
}

err = MSGPDecode(&dst, b1)
if err != nil {
if err := MSGPDecode(&dst, b1); err != nil {
t.Fatal(err)
}

Expand All @@ -46,8 +45,7 @@ func TestBase64EncodeDecode(t *testing.T) {
t.Fatal(err)
}

err = Base64Decode(&dst, b1)
if err != nil {
if err := Base64Decode(&dst, b1); err != nil {
t.Fatal(err)
}

Expand All @@ -60,8 +58,11 @@ func BenchmarkMSGPEncode(b *testing.B) {
src := getSRC()

b.ResetTimer()

for i := 0; i < b.N; i++ {
MSGPEncode(src)
if _, err := MSGPEncode(src); err != nil {
b.Fatal(err)
}
}
}

Expand All @@ -72,17 +73,23 @@ func BenchmarkMSGPDecode(b *testing.B) {
srcBytes, _ := MSGPEncode(src)

b.ResetTimer()

for i := 0; i < b.N; i++ {
MSGPDecode(&dst, srcBytes)
if err := MSGPDecode(&dst, srcBytes); err != nil {
b.Fatal(err)
}
}
}

func BenchmarkBase64Encode(b *testing.B) {
src := getSRC()

b.ResetTimer()

for i := 0; i < b.N; i++ {
Base64Encode(src)
if _, err := Base64Encode(src); err != nil {
b.Fatal(err)
}
}
}

Expand All @@ -93,7 +100,10 @@ func BenchmarkBase64Decode(b *testing.B) {
srcBytes, _ := Base64Encode(src)

b.ResetTimer()

for i := 0; i < b.N; i++ {
Base64Decode(&dst, srcBytes)
if err := Base64Decode(&dst, srcBytes); err != nil {
b.Fatal(err)
}
}
}
7 changes: 5 additions & 2 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func New(cfg Config) *Session {
config: cfg,
cookie: newCookie(),
log: cfg.Logger,
storePool: &sync.Pool{
storePool: sync.Pool{
New: func() interface{} {
return NewStore()
},
Expand All @@ -79,9 +79,12 @@ func (s *Session) SetProvider(provider Provider) error {
}

func (s *Session) startGC() {
ticker := time.NewTicker(s.config.GCLifetime)
defer ticker.Stop()

for {
select {
case <-time.After(s.config.GCLifetime):
case <-ticker.C:
err := s.provider.GC()
if err != nil {
s.log.Printf("session GC crash: %v", err)
Expand Down
89 changes: 56 additions & 33 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ func Test_New(t *testing.T) {
t.Error("Session.cookie is nil")
}

if s.storePool == nil {
t.Error("Session.storePool is nil")
if v := s.storePool.Get().(*Store); v == nil {
t.Errorf("Session.storePool returns: %v", v)
}
}

Expand All @@ -103,7 +103,10 @@ func TestSession_SetProvider(t *testing.T) {
})
provider := &mockProvider{needGCValue: true}

s.SetProvider(provider)
if err := s.SetProvider(provider); err != nil {
t.Fatalf("unexpected error: %v", err)
}

time.Sleep(s.config.GCLifetime + 100*time.Millisecond)
s.stopGC()

Expand Down Expand Up @@ -266,12 +269,14 @@ func TestSession_GetErrEmptySessionID(t *testing.T) {
return []byte("")
},
})
s.SetProvider(new(mockProvider))

if err := s.SetProvider(new(mockProvider)); err != nil {
t.Fatalf("unexpected error: %v", err)
}

ctx := new(fasthttp.RequestCtx)

store, err := s.Get(ctx)

if err != ErrEmptySessionID {
t.Errorf("Expected error: %v", ErrEmptySessionID)
}
Expand All @@ -284,13 +289,15 @@ func TestSession_GetErrEmptySessionID(t *testing.T) {
func TestSession_GetProviderError(t *testing.T) {
s := New(Config{})
provider := &mockProvider{errGet: errors.New("error from provider")}
s.SetProvider(provider)

if err := s.SetProvider(provider); err != nil {
t.Fatalf("unexpected error: %v", err)
}

ctx := new(fasthttp.RequestCtx)
ctx.Request.Header.SetCookie(s.config.CookieName, "aiasdiasd")

store, err := s.Get(ctx)

if err != provider.errGet {
t.Errorf("Expected error: %v", provider.errGet)
}
Expand All @@ -304,18 +311,20 @@ func TestSession_Get(t *testing.T) {
s := New(Config{})

provider := new(mockProvider)
s.SetProvider(provider)

if err := s.SetProvider(provider); err != nil {
t.Fatalf("unexpected error: %v", err)
}

ctx := new(fasthttp.RequestCtx)

store, err := s.Get(ctx)

if err != nil {
t.Errorf("Unexpected error: %v", err)
}

if store == nil {
t.Error("The store is nil")
t.Fatal("store is nil")
}

if len(store.sessionID) == 0 {
Expand All @@ -330,22 +339,26 @@ func TestSession_Get(t *testing.T) {
func TestSession_SaveProviderError(t *testing.T) {
s := New(Config{})
provider := &mockProvider{errSave: errors.New("error from provider")}
s.SetProvider(provider)

if err := s.SetProvider(provider); err != nil {
t.Fatalf("unexpected error: %v", err)
}

ctx := new(fasthttp.RequestCtx)
store := NewStore()

err := s.Save(ctx, store)

if err != provider.errSave {
if err := s.Save(ctx, store); err != provider.errSave {
t.Errorf("Expected error: %v", provider.errGet)
}
}

func TestSession_Save(t *testing.T) {
s := New(Config{})
provider := new(mockProvider)
s.SetProvider(provider)

if err := s.SetProvider(provider); err != nil {
t.Fatalf("unexpected error: %v", err)
}

ctx := new(fasthttp.RequestCtx)

Expand Down Expand Up @@ -382,7 +395,10 @@ func TestSession_RegenerateErrEmptySessionID(t *testing.T) {
return []byte("")
},
})
s.SetProvider(new(mockProvider))

if err := s.SetProvider(new(mockProvider)); err != nil {
t.Fatalf("unexpected error: %v", err)
}

ctx := new(fasthttp.RequestCtx)
ctx.Request.Header.SetCookie(s.config.CookieName, "d32r2f2ecev")
Expand All @@ -395,7 +411,10 @@ func TestSession_RegenerateErrEmptySessionID(t *testing.T) {
func TestSession_RegenerateProviderError(t *testing.T) {
s := New(Config{})
provider := &mockProvider{errRegenerate: errors.New("error from provider")}
s.SetProvider(provider)

if err := s.SetProvider(provider); err != nil {
t.Fatalf("unexpected error: %v", err)
}

ctx := new(fasthttp.RequestCtx)
ctx.Request.Header.SetCookie(s.config.CookieName, "d32r2f2ecev")
Expand All @@ -408,7 +427,10 @@ func TestSession_RegenerateProviderError(t *testing.T) {
func TestSession_Regenerate(t *testing.T) {
s := New(Config{})
provider := &mockProvider{}
s.SetProvider(provider)

if err := s.SetProvider(provider); err != nil {
t.Fatalf("unexpected error: %v", err)
}

id := "d32r2f2ecev"
ctx := new(fasthttp.RequestCtx)
Expand All @@ -427,53 +449,54 @@ func TestSession_DestroyErrNotProvider(t *testing.T) {
s := New(Config{})
ctx := new(fasthttp.RequestCtx)

err := s.Destroy(ctx)

if err != ErrNotSetProvider {
if err := s.Destroy(ctx); err != ErrNotSetProvider {
t.Errorf("Expected error: %v", ErrNotSetProvider)
}
}

func TestSession_DestroyIDNotExist(t *testing.T) {
s := New(Config{})
provider := new(mockProvider)
s.SetProvider(provider)

ctx := new(fasthttp.RequestCtx)
if err := s.SetProvider(provider); err != nil {
t.Fatalf("unexpected error: %v", err)
}

err := s.Destroy(ctx)
ctx := new(fasthttp.RequestCtx)

if err != nil {
if err := s.Destroy(ctx); err != nil {
t.Errorf("Unexpected error: %v", err)
}
}

func TestSession_DestroyProviderError(t *testing.T) {
s := New(Config{})
provider := &mockProvider{errDestroy: errors.New("error from provider")}
s.SetProvider(provider)

if err := s.SetProvider(provider); err != nil {
t.Fatalf("unexpected error: %v", err)
}

ctx := new(fasthttp.RequestCtx)
ctx.Request.Header.SetCookie(s.config.CookieName, "asd2324n")

err := s.Destroy(ctx)

if err != provider.errDestroy {
if err := s.Destroy(ctx); err != provider.errDestroy {
t.Errorf("Expected error: %v", provider.errDestroy)
}
}

func TestSession_Destroy(t *testing.T) {
s := New(Config{})
provider := new(mockProvider)
s.SetProvider(provider)

if err := s.SetProvider(provider); err != nil {
t.Fatalf("unexpected error: %v", err)
}

ctx := new(fasthttp.RequestCtx)
ctx.Request.Header.SetCookie(s.config.CookieName, "asd2324n")

err := s.Destroy(ctx)

if err != nil {
if err := s.Destroy(ctx); err != nil {
t.Errorf("Unexpected error: %v", err)
}
}
4 changes: 3 additions & 1 deletion store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ func TestStore_SetGetHasExpiration(t *testing.T) {

expiration := 10 * time.Second

store.SetExpiration(expiration)
if err := store.SetExpiration(expiration); err != nil {
t.Errorf("unexpected error: %v", err)
}

if v := store.GetExpiration(); v != expiration {
t.Errorf("Store.GetExpiration() == %d, want %d", v, expiration)
Expand Down
2 changes: 1 addition & 1 deletion types.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ type Session struct {
cookie *cookie
log Logger

storePool *sync.Pool
storePool sync.Pool
stopGCChan chan struct{}
}

Expand Down

0 comments on commit f480a08

Please sign in to comment.