Skip to content

Commit

Permalink
Fix #124 (#125)
Browse files Browse the repository at this point in the history
  • Loading branch information
ElecTwix authored Feb 5, 2024
1 parent e78588e commit 7c2584a
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 13 deletions.
4 changes: 2 additions & 2 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ func New(url string, connection conn.Connection) (*DB, error) {
// --------------------------------------------------

// Close closes the underlying WebSocket connection.
func (db *DB) Close() {
_ = db.conn.Close()
func (db *DB) Close() error {
return db.conn.Close()
}

// --------------------------------------------------
Expand Down
24 changes: 17 additions & 7 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"encoding/json"
"fmt"
"io"
rawslog "log/slog"
"os"
"sync"
Expand Down Expand Up @@ -31,6 +32,7 @@ type SurrealDBTestSuite struct {
db *surrealdb.DB
name string
connImplementations map[string]conn.Connection
logBuffer *bytes.Buffer
}

// a simple user struct for testing
Expand All @@ -55,20 +57,23 @@ func TestSurrealDBSuite(t *testing.T) {
SurrealDBSuite.connImplementations = make(map[string]conn.Connection)

// Without options
logData := createLogger(t)
buff := bytes.NewBufferString("")
logData := createLogger(t, buff)
SurrealDBSuite.connImplementations["gorilla"] = gorilla.Create().Logger(logData)
SurrealDBSuite.logBuffer = buff

// With options
logData = createLogger(t)
SurrealDBSuite.connImplementations["gorilla_opt"] = gorilla.Create().SetTimeOut(time.Minute).SetCompression(true).Logger(logData)
buffOpt := bytes.NewBufferString("")
logDataOpt := createLogger(t, buff)
SurrealDBSuite.connImplementations["gorilla_opt"] = gorilla.Create().SetTimeOut(time.Minute).SetCompression(true).Logger(logDataOpt)
SurrealDBSuite.logBuffer = buffOpt

RunWsMap(t, SurrealDBSuite)
}

func createLogger(t *testing.T) logger.Logger {
func createLogger(t *testing.T, writer io.Writer) logger.Logger {
t.Helper()
buff := bytes.NewBuffer([]byte{})
handler := rawslog.NewJSONHandler(buff, &rawslog.HandlerOptions{Level: rawslog.LevelDebug})
handler := rawslog.NewJSONHandler(writer, &rawslog.HandlerOptions{Level: rawslog.LevelDebug})
return slog.New(handler)
}

Expand All @@ -86,11 +91,16 @@ func RunWsMap(t *testing.T, s *SurrealDBTestSuite) {
func (s *SurrealDBTestSuite) TearDownTest() {
_, err := s.db.Delete("users")
s.Require().NoError(err)

if s.logBuffer.Len() > 0 {
s.T().Logf("Log output:\n%s", s.logBuffer.String())
}
}

// TearDownSuite is called after the s has finished running
func (s *SurrealDBTestSuite) TearDownSuite() {
s.db.Close()
err := s.db.Close()
s.Require().NoError(err)
}

func (t testUser) String() (str string, err error) {
Expand Down
16 changes: 12 additions & 4 deletions pkg/conn/gorilla/gorilla.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net"
"reflect"
"strconv"
"sync"
Expand Down Expand Up @@ -104,11 +105,15 @@ func (ws *WebSocket) SetCompression(compress bool) *WebSocket {
}

func (ws *WebSocket) Close() error {
defer func() {
close(ws.close)
}()
ws.connLock.Lock()
defer ws.connLock.Unlock()
close(ws.close)
err := ws.Conn.WriteMessage(gorilla.CloseMessage, gorilla.FormatCloseMessage(CloseMessageCode, ""))
if err != nil {
return err
}

return ws.Conn.WriteMessage(gorilla.CloseMessage, gorilla.FormatCloseMessage(CloseMessageCode, ""))
return ws.Conn.Close()
}

func (ws *WebSocket) LiveNotifications(liveQueryID string) (chan model.Notification, error) {
Expand Down Expand Up @@ -239,6 +244,9 @@ func (ws *WebSocket) initialize() {
var res rpc.RPCResponse
err := ws.read(&res)
if err != nil {
if errors.Is(err, net.ErrClosed) {
break
}
ws.logger.Error(err.Error())
continue
}
Expand Down

0 comments on commit 7c2584a

Please sign in to comment.