diff --git a/db.go b/db.go index d90329f..4eca70d 100644 --- a/db.go +++ b/db.go @@ -5,18 +5,22 @@ import ( "github.com/surrealdb/surrealdb.go/pkg/model" + "github.com/surrealdb/surrealdb.go/pkg/conn" "github.com/surrealdb/surrealdb.go/pkg/constants" - "github.com/surrealdb/surrealdb.go/pkg/websocket" ) // DB is a client for the SurrealDB database that holds are websocket connection. type DB struct { - ws websocket.WebSocket + conn conn.Connection } -// New creates a new SurrealDB client. -func New(url string, ws websocket.WebSocket) (*DB, error) { - return &DB{ws}, nil +// New creates a new SurrealDB lient. +func New(url string, connection conn.Connection) (*DB, error) { + connection, err := connection.Connect(url) + if err != nil { + return nil, err + } + return &DB{connection}, nil } // -------------------------------------------------- @@ -25,7 +29,7 @@ func New(url string, ws websocket.WebSocket) (*DB, error) { // Close closes the underlying WebSocket connection. func (db *DB) Close() { - _ = db.ws.Close() + _ = db.conn.Close() } // -------------------------------------------------- @@ -114,7 +118,7 @@ func (db *DB) Insert(what string, data interface{}) (interface{}, error) { // LiveNotifications returns a channel for live query. func (db *DB) LiveNotifications(liveQueryID string) (chan model.Notification, error) { - return db.ws.LiveNotifications(liveQueryID) + return db.conn.LiveNotifications(liveQueryID) } // -------------------------------------------------- @@ -124,7 +128,7 @@ func (db *DB) LiveNotifications(liveQueryID string) (chan model.Notification, er // send is a helper method for sending a query to the database. func (db *DB) send(method string, params ...interface{}) (interface{}, error) { // here we send the args through our websocket connection - resp, err := db.ws.Send(method, params) + resp, err := db.conn.Send(method, params) if err != nil { return nil, fmt.Errorf("sending request failed for method '%s': %w", method, err) } diff --git a/db_test.go b/db_test.go index 5eeaf70..e53f697 100644 --- a/db_test.go +++ b/db_test.go @@ -17,20 +17,20 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/surrealdb/surrealdb.go" + "github.com/surrealdb/surrealdb.go/pkg/conn/gorilla" "github.com/surrealdb/surrealdb.go/pkg/constants" - gorilla "github.com/surrealdb/surrealdb.go/pkg/gorilla" + "github.com/surrealdb/surrealdb.go/pkg/conn" "github.com/surrealdb/surrealdb.go/pkg/logger" "github.com/surrealdb/surrealdb.go/pkg/marshal" - "github.com/surrealdb/surrealdb.go/pkg/websocket" ) // TestDBSuite is a test s for the DB struct type SurrealDBTestSuite struct { suite.Suite - db *surrealdb.DB - name string - wsImplementations map[string]websocket.WebSocket + db *surrealdb.DB + name string + connImplementations map[string]conn.Connection } // a simple user struct for testing @@ -43,15 +43,15 @@ type testUser struct { func TestSurrealDBSuite(t *testing.T) { SurrealDBSuite := new(SurrealDBTestSuite) - SurrealDBSuite.wsImplementations = make(map[string]websocket.WebSocket) + SurrealDBSuite.connImplementations = make(map[string]conn.Connection) // Without options logData := createLogger(t) - SurrealDBSuite.wsImplementations["gorilla"] = gorilla.Create().Logger(logData) + SurrealDBSuite.connImplementations["gorilla"] = gorilla.Create().Logger(logData) // With options logData = createLogger(t) - SurrealDBSuite.wsImplementations["gorilla_opt"] = gorilla.Create().SetTimeOut(time.Minute).SetCompression(true).Logger(logData) + SurrealDBSuite.connImplementations["gorilla_opt"] = gorilla.Create().SetTimeOut(time.Minute).SetCompression(true).Logger(logData) RunWsMap(t, SurrealDBSuite) } @@ -64,7 +64,7 @@ func createLogger(t *testing.T) logger.Logger { } func RunWsMap(t *testing.T, s *SurrealDBTestSuite) { - for wsName := range s.wsImplementations { + for wsName := range s.connImplementations { // Run the test suite t.Run(wsName, func(t *testing.T) { s.name = wsName @@ -99,11 +99,9 @@ func (s *SurrealDBTestSuite) openConnection() *surrealdb.DB { if url == "" { url = "ws://localhost:8000/rpc" } - impl := s.wsImplementations[s.name] + impl := s.connImplementations[s.name] require.NotNil(s.T(), impl) - ws, err := impl.Connect(url) - s.Require().NoError(err) - db, err := surrealdb.New(url, ws) + db, err := surrealdb.New(url, impl) s.Require().NoError(err) return db } diff --git a/internal/mock/mock.go b/internal/mock/mock.go index 8c56fea..a3b7e3e 100644 --- a/internal/mock/mock.go +++ b/internal/mock/mock.go @@ -3,14 +3,14 @@ package mock import ( "errors" + "github.com/surrealdb/surrealdb.go/pkg/conn" "github.com/surrealdb/surrealdb.go/pkg/model" - "github.com/surrealdb/surrealdb.go/pkg/websocket" ) type ws struct { } -func (w *ws) Connect(url string) (websocket.WebSocket, error) { +func (w *ws) Connect(url string) (conn.Connection, error) { return w, nil } diff --git a/pkg/websocket/websocket.go b/pkg/conn/conn.go similarity index 50% rename from pkg/websocket/websocket.go rename to pkg/conn/conn.go index b34bf29..dc619ce 100644 --- a/pkg/websocket/websocket.go +++ b/pkg/conn/conn.go @@ -1,11 +1,9 @@ -package websocket +package conn -import ( - "github.com/surrealdb/surrealdb.go/pkg/model" -) +import "github.com/surrealdb/surrealdb.go/pkg/model" -type WebSocket interface { - Connect(url string) (WebSocket, error) +type Connection interface { + Connect(url string) (Connection, error) Send(method string, params []interface{}) (interface{}, error) Close() error LiveNotifications(id string) (chan model.Notification, error) diff --git a/pkg/gorilla/gorilla.go b/pkg/conn/gorilla/gorilla.go similarity index 97% rename from pkg/gorilla/gorilla.go rename to pkg/conn/gorilla/gorilla.go index 1ed4377..5619d14 100644 --- a/pkg/gorilla/gorilla.go +++ b/pkg/conn/gorilla/gorilla.go @@ -13,9 +13,9 @@ import ( gorilla "github.com/gorilla/websocket" "github.com/surrealdb/surrealdb.go/internal/rpc" + "github.com/surrealdb/surrealdb.go/pkg/conn" "github.com/surrealdb/surrealdb.go/pkg/logger" "github.com/surrealdb/surrealdb.go/pkg/rand" - "github.com/surrealdb/surrealdb.go/pkg/websocket" ) const ( @@ -55,16 +55,16 @@ func Create() *WebSocket { } } -func (ws *WebSocket) Connect(url string) (websocket.WebSocket, error) { +func (ws *WebSocket) Connect(url string) (conn.Connection, error) { dialer := gorilla.DefaultDialer dialer.EnableCompression = true - conn, _, err := dialer.Dial(url, nil) + connection, _, err := dialer.Dial(url, nil) if err != nil { return nil, err } - ws.Conn = conn + ws.Conn = connection for _, option := range ws.Option { if err := option(ws); err != nil {