diff --git a/db.go b/db.go index f73f2da..494ba99 100644 --- a/db.go +++ b/db.go @@ -2,16 +2,15 @@ package surrealdb import ( "fmt" - - "github.com/surrealdb/surrealdb.go/pkg/model" - - "github.com/surrealdb/surrealdb.go/pkg/conn" + "github.com/surrealdb/surrealdb.go/internal/connection" "github.com/surrealdb/surrealdb.go/pkg/constants" + "github.com/surrealdb/surrealdb.go/pkg/model" ) // DB is a client for the SurrealDB database that holds the connection. type DB struct { - conn conn.Connection + conn connection.Connection + liveHandler connection.LiveHandler } // Auth is a struct that holds surrealdb auth data for login. @@ -24,12 +23,24 @@ type Auth struct { } // New creates a new SurrealDB client. -func New(url string, connection conn.Connection) (*DB, error) { - connection, err := connection.Connect(url) +func New(url string, engine string) (*DB, error) { + newParams := connection.NewConnectionParams{ + Encoder: model.GetCborEncoder(), + Decoder: model.GetCborDecoder(), + } + var conn connection.Connection + if engine != "http" { + conn = connection.NewHttp(newParams) + } else { + conn = connection.NewWebSocket(newParams) + } + + connect, err := conn.Connect(url) if err != nil { return nil, err } - return &DB{connection}, nil + + return &DB{conn: connect}, nil } // -------------------------------------------------- @@ -126,8 +137,8 @@ func (db *DB) Insert(what string, data interface{}) (interface{}, error) { } // LiveNotifications returns a channel for live query. -func (db *DB) LiveNotifications(liveQueryID string) (chan model.Notification, error) { - return db.conn.LiveNotifications(liveQueryID) +func (db *DB) LiveNotifications(liveQueryID string) (chan connection.Notification, error) { + return db.liveHandler.LiveNotifications(liveQueryID) //check if implemented } // -------------------------------------------------- @@ -155,6 +166,7 @@ func (db *DB) send(method string, params ...interface{}) (interface{}, error) { // resp is a helper method for parsing the response from a query. func (db *DB) resp(_ string, _ []interface{}, res interface{}) (interface{}, error) { if res == nil { + //return nil, pkg.ErrNoRow return nil, constants.ErrNoRow } return res, nil diff --git a/db_test.go b/db_test.go index def4dd1..78ac38f 100644 --- a/db_test.go +++ b/db_test.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/json" "fmt" + "github.com/surrealdb/surrealdb.go/pkg/constants" + "github.com/surrealdb/surrealdb.go/pkg/logger" "io" rawslog "log/slog" "os" @@ -12,17 +14,10 @@ import ( "time" "github.com/stretchr/testify/assert" - "github.com/surrealdb/surrealdb.go/pkg/logger/slog" - "github.com/surrealdb/surrealdb.go/pkg/model" - "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/surrealdb/surrealdb.go" - "github.com/surrealdb/surrealdb.go/pkg/conn/gorilla" - "github.com/surrealdb/surrealdb.go/pkg/constants" - - "github.com/surrealdb/surrealdb.go/pkg/conn" - "github.com/surrealdb/surrealdb.go/pkg/logger" + conn "github.com/surrealdb/surrealdb.go/internal/connection" "github.com/surrealdb/surrealdb.go/pkg/marshal" ) @@ -68,13 +63,13 @@ func TestSurrealDBSuite(t *testing.T) { // Without options buff := bytes.NewBufferString("") logData := createLogger(t, buff) - SurrealDBSuite.connImplementations["gorilla"] = gorilla.Create().Logger(logData) + SurrealDBSuite.connImplementations["ws"] = conn.NewWebSocket(conn.NewConnectionParams{}).Logger(logData) SurrealDBSuite.logBuffer = buff // With options buffOpt := bytes.NewBufferString("") logDataOpt := createLogger(t, buff) - SurrealDBSuite.connImplementations["gorilla_opt"] = gorilla.Create().SetTimeOut(time.Minute).SetCompression(true).Logger(logDataOpt) + SurrealDBSuite.connImplementations["ws_opt"] = conn.NewWebSocket(conn.NewConnectionParams{}).SetTimeOut(time.Minute).SetCompression(true).Logger(logDataOpt) SurrealDBSuite.logBuffer = buffOpt RunWsMap(t, SurrealDBSuite) @@ -83,7 +78,7 @@ func TestSurrealDBSuite(t *testing.T) { func createLogger(t *testing.T, writer io.Writer) logger.Logger { t.Helper() handler := rawslog.NewJSONHandler(writer, &rawslog.HandlerOptions{Level: rawslog.LevelDebug}) - return slog.New(handler) + return logger.New(handler) } func RunWsMap(t *testing.T, s *SurrealDBTestSuite) { @@ -134,7 +129,7 @@ func (s *SurrealDBTestSuite) createTestDB() *surrealdb.DB { // openConnection opens a new connection to the database func (s *SurrealDBTestSuite) openConnection(url string, impl conn.Connection) *surrealdb.DB { require.NotNil(s.T(), impl) - db, err := surrealdb.New(url, impl) + db, err := surrealdb.New(url, "") s.Require().NoError(err) return db } @@ -177,7 +172,7 @@ func (s *SurrealDBTestSuite) TestLiveViaMethod() { }) s.Require().NoError(e) notification := <-notifications - s.Require().Equal(model.CreateAction, notification.Action) + s.Require().Equal(conn.CreateAction, notification.Action) s.Require().Equal(live, notification.ID) } @@ -208,7 +203,7 @@ func (s *SurrealDBTestSuite) TestLiveWithOptionsViaMethod() { s.Require().NoError(e) notification := <-notifications - s.Require().Equal(model.UpdateAction, notification.Action) + s.Require().Equal(conn.UpdateAction, notification.Action) s.Require().Equal(live, notification.ID) } @@ -236,7 +231,7 @@ func (s *SurrealDBTestSuite) TestLiveViaQuery() { }) s.Require().NoError(e) notification := <-notifications - s.Require().Equal(model.CreateAction, notification.Action) + s.Require().Equal(conn.CreateAction, notification.Action) s.Require().Equal(liveID, notification.ID) } @@ -781,7 +776,7 @@ func (s *SurrealDBTestSuite) TestConcurrentOperations() { } func (s *SurrealDBTestSuite) TestConnectionBreak() { - ws := gorilla.Create() + ws := conn.NewWebSocket(conn.NewConnectionParams{}) var url string if currentURL == "" { url = defaultURL diff --git a/go.mod b/go.mod index 43ef31e..2c4b8d9 100644 --- a/go.mod +++ b/go.mod @@ -9,9 +9,11 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.10.0 // indirect + github.com/x448/float16 v0.8.4 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index ab9f4d8..2adb75d 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= +github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= @@ -18,6 +20,8 @@ github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjR github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/internal/benchmark/benchmark_test.go b/internal/benchmark/benchmark_test.go index 03532ce..2bb07c2 100644 --- a/internal/benchmark/benchmark_test.go +++ b/internal/benchmark/benchmark_test.go @@ -5,7 +5,6 @@ import ( "testing" "github.com/surrealdb/surrealdb.go" - "github.com/surrealdb/surrealdb.go/internal/mock" "github.com/surrealdb/surrealdb.go/pkg/marshal" ) @@ -18,7 +17,7 @@ type testUser struct { } func SetupMockDB() (*surrealdb.DB, error) { - return surrealdb.New("", mock.Create()) + return surrealdb.New("", "") } func BenchmarkCreate(b *testing.B) { diff --git a/internal/connection/connection.go b/internal/connection/connection.go new file mode 100644 index 0000000..80d0052 --- /dev/null +++ b/internal/connection/connection.go @@ -0,0 +1,27 @@ +package connection + +type Connection interface { + Connect(url string) (Connection, error) + Close() error + Send(method string, params []interface{}) (interface{}, error) +} + +type LiveHandler interface { + LiveNotifications(id string) (chan Notification, error) +} + +type Encoder func(value interface{}) ([]byte, error) + +type Decoder func(encoded []byte, value interface{}) error + +type BaseConnection struct { + encode Encoder + decode Decoder + baseURL string +} + +type NewConnectionParams struct { + Encoder Encoder + Decoder Decoder + BaseURL string +} diff --git a/internal/connection/http.go b/internal/connection/http.go new file mode 100644 index 0000000..9c21e13 --- /dev/null +++ b/internal/connection/http.go @@ -0,0 +1,113 @@ +package connection + +import ( + "bytes" + "fmt" + "github.com/surrealdb/surrealdb.go/internal/rand" + "io/ioutil" + "log" + "net/http" + "time" +) + +type Http struct { + BaseConnection + + httpClient *http.Client + + namespace string + database string +} + +func NewHttp(p NewConnectionParams) Connection { + con := Http{ + BaseConnection: BaseConnection{ + encode: p.Encoder, + decode: p.Decoder, + }, + } + + if con.httpClient == nil { + con.httpClient = &http.Client{ + Timeout: 10 * time.Second, // Set a default timeout to avoid hanging requests + } + } + + return &con +} + +func (h *Http) Connect(url string) (Connection, error) { + // TODO: EXTRACT BASE url and set + h.baseURL = url + + _, err := h.MakeRequest(http.MethodGet, "/health", nil) + if err != nil { + return nil, err + } + + return h, nil +} + +func (h *Http) Close() error { + return nil +} + +func (h *Http) SetTimeout(timeout time.Duration) *Http { + h.httpClient.Timeout = timeout + return h +} + +func (h *Http) SetHttpClient(client *http.Client) *Http { + h.httpClient = client + return h +} + +func (h *Http) Send(method string, params []interface{}) (interface{}, error) { + if h.baseURL == "" { + return nil, fmt.Errorf("connection host not set") + } + + rpcReq := &RPCRequest{ + ID: rand.String(RequestIDLength), + Method: method, + Params: params, + } + + reqBody, err := h.encode(rpcReq) + if err != nil { + return nil, err + } + + httpReq, err := http.NewRequest(method, h.baseURL+"rpc", bytes.NewBuffer(reqBody)) + httpReq.Header.Set("Accept", "application/cbor") + httpReq.Header.Set("Content-Type", "application/cbor") + + resp, err := h.MakeRequest(http.MethodPost, "/rpc", reqBody) + if err != nil { + return nil, err + } + + var rpcResponse RPCResponse + err = h.decode(resp, &rpcResponse) + + return &rpcResponse, nil +} + +func (h *Http) MakeRequest(method string, url string, body []byte) ([]byte, error) { + req, err := http.NewRequest(method, url, bytes.NewBuffer(body)) + if err != nil { + log.Fatalf("Error creating request: %v", err) + } + + resp, err := h.httpClient.Do(req) + if err != nil { + log.Fatalf("Error making HTTP request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("request failed with status code %d", resp.StatusCode) + } + + return ioutil.ReadAll(resp.Body) +} diff --git a/internal/connection/http_test.go b/internal/connection/http_test.go new file mode 100644 index 0000000..a3aca81 --- /dev/null +++ b/internal/connection/http_test.go @@ -0,0 +1,48 @@ +package connection + +import ( + "bytes" + "fmt" + "github.com/stretchr/testify/assert" + "io/ioutil" + "net/http" + "testing" +) + +// RoundTripFunc . +type RoundTripFunc func(req *http.Request) *http.Response + +// RoundTrip . +func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req), nil +} + +// NewTestClient returns *http.Client with Transport replaced to avoid making real calls +func NewTestClient(fn RoundTripFunc) *http.Client { + return &http.Client{ + Transport: fn, + } +} + +func TestEngine_MakeRequest(t *testing.T) { + httpClient := NewTestClient(func(req *http.Request) *http.Response { + assert.Equal(t, req.URL.String(), "http://test.surreal/rpc") + + return &http.Response{ + StatusCode: 400, + // Send response to be tested + Body: ioutil.NopCloser(bytes.NewBufferString(`OK`)), + // Must be set to non-nil value or it panics + Header: make(http.Header), + } + }) + + p := NewConnectionParams{} + httpEngine := (NewHttp(p)).(*Http) + httpEngine.SetHttpClient(httpClient) + + resp, err := httpEngine.MakeRequest(http.MethodGet, "http://test.surreal/rpc", nil) + assert.Error(t, err, "should return error for status code 400") + + fmt.Println(resp) +} diff --git a/internal/rpc/rpc.go b/internal/connection/model.go similarity index 59% rename from internal/rpc/rpc.go rename to internal/connection/model.go index 38285fc..98a24d6 100644 --- a/internal/rpc/rpc.go +++ b/internal/connection/model.go @@ -1,4 +1,4 @@ -package rpc +package connection // RPCError represents a JSON-RPC error type RPCError struct { @@ -31,3 +31,28 @@ type RPCNotification struct { Method string `json:"method,omitempty" msgpack:"method,omitempty"` Params []interface{} `json:"params,omitempty" msgpack:"params,omitempty"` } + +type RPCFunction string + +var ( + FUse RPCFunction = "use" + Info RPCFunction = "info" + SignUp RPCFunction = "signup" + SignIn RPCFunction = "signin" + Authenticate RPCFunction = "authenticate" + Invalidate RPCFunction = "invalidate" + Let RPCFunction = "let" + Unset RPCFunction = "unset" + Live RPCFunction = "live" + Kill RPCFunction = "kill" + Query RPCFunction = "query" + Select RPCFunction = "select" + Create RPCFunction = "create" + Insert RPCFunction = "insert" + Update RPCFunction = "update" + Upsert RPCFunction = "upsert" + Relate RPCFunction = "relate" + Merge RPCFunction = "merge" + Patch RPCFunction = "patch" + Delete RPCFunction = "delete" +) diff --git a/pkg/model/notification.go b/internal/connection/notification.go similarity index 93% rename from pkg/model/notification.go rename to internal/connection/notification.go index cf2c280..301ba08 100644 --- a/pkg/model/notification.go +++ b/internal/connection/notification.go @@ -1,4 +1,4 @@ -package model +package connection type Notification struct { ID string `json:"id"` diff --git a/pkg/conn/gorilla/gorilla.go b/internal/connection/ws.go similarity index 87% rename from pkg/conn/gorilla/gorilla.go rename to internal/connection/ws.go index 4dcbf2c..80b6b50 100644 --- a/pkg/conn/gorilla/gorilla.go +++ b/internal/connection/ws.go @@ -1,9 +1,11 @@ -package gorilla +package connection import ( "encoding/json" "errors" "fmt" + "github.com/surrealdb/surrealdb.go/internal/rand" + "github.com/surrealdb/surrealdb.go/pkg/logger" "io" "net" "reflect" @@ -11,13 +13,7 @@ import ( "sync" "time" - "github.com/surrealdb/surrealdb.go/pkg/model" - gorilla "github.com/gorilla/websocket" - "github.com/surrealdb/surrealdb.go/internal/rpc" - "github.com/surrealdb/surrealdb.go/pkg/conn" - "github.com/surrealdb/surrealdb.go/pkg/logger" - "github.com/surrealdb/surrealdb.go/pkg/rand" ) const ( @@ -32,33 +28,41 @@ const ( type Option func(ws *WebSocket) error type WebSocket struct { + BaseConnection + Conn *gorilla.Conn connLock sync.Mutex Timeout time.Duration Option []Option logger logger.Logger - responseChannels map[string]chan rpc.RPCResponse + responseChannels map[string]chan RPCResponse responseChannelsLock sync.RWMutex - notificationChannels map[string]chan model.Notification + notificationChannels map[string]chan Notification notificationChannelsLock sync.RWMutex closeChan chan int closeError error } -func Create() *WebSocket { +func NewWebSocket(p NewConnectionParams) *WebSocket { return &WebSocket{ + BaseConnection: BaseConnection{ + encode: p.Encoder, + decode: p.Decoder, + }, + logger: p.Logger, + Conn: nil, closeChan: make(chan int), - responseChannels: make(map[string]chan rpc.RPCResponse), - notificationChannels: make(map[string]chan model.Notification), + responseChannels: make(map[string]chan RPCResponse), + notificationChannels: make(map[string]chan Notification), Timeout: DefaultTimeout * time.Second, } } -func (ws *WebSocket) Connect(url string) (conn.Connection, error) { +func (ws *WebSocket) Connect(url string) (Connection, error) { dialer := gorilla.DefaultDialer dialer.EnableCompression = true @@ -118,7 +122,7 @@ func (ws *WebSocket) Close() error { return ws.Conn.Close() } -func (ws *WebSocket) LiveNotifications(liveQueryID string) (chan model.Notification, error) { +func (ws *WebSocket) LiveNotifications(liveQueryID string) (chan Notification, error) { c, err := ws.createNotificationChannel(liveQueryID) if err != nil { ws.logger.Error(err.Error()) @@ -132,7 +136,7 @@ var ( ErrInvalidResponseID = errors.New("invalid response id") ) -func (ws *WebSocket) createResponseChannel(id string) (chan rpc.RPCResponse, error) { +func (ws *WebSocket) createResponseChannel(id string) (chan RPCResponse, error) { ws.responseChannelsLock.Lock() defer ws.responseChannelsLock.Unlock() @@ -140,13 +144,13 @@ func (ws *WebSocket) createResponseChannel(id string) (chan rpc.RPCResponse, err return nil, fmt.Errorf("%w: %v", ErrIDInUse, id) } - ch := make(chan rpc.RPCResponse) + ch := make(chan RPCResponse) ws.responseChannels[id] = ch return ch, nil } -func (ws *WebSocket) createNotificationChannel(liveQueryID string) (chan model.Notification, error) { +func (ws *WebSocket) createNotificationChannel(liveQueryID string) (chan Notification, error) { ws.notificationChannelsLock.Lock() defer ws.notificationChannelsLock.Unlock() @@ -154,7 +158,7 @@ func (ws *WebSocket) createNotificationChannel(liveQueryID string) (chan model.N return nil, fmt.Errorf("%w: %v", ErrIDInUse, liveQueryID) } - ch := make(chan model.Notification) + ch := make(chan Notification) ws.notificationChannels[liveQueryID] = ch return ch, nil @@ -166,14 +170,14 @@ func (ws *WebSocket) removeResponseChannel(id string) { delete(ws.responseChannels, id) } -func (ws *WebSocket) getResponseChannel(id string) (chan rpc.RPCResponse, bool) { +func (ws *WebSocket) getResponseChannel(id string) (chan RPCResponse, bool) { ws.responseChannelsLock.RLock() defer ws.responseChannelsLock.RUnlock() ch, ok := ws.responseChannels[id] return ch, ok } -func (ws *WebSocket) getLiveChannel(id string) (chan model.Notification, bool) { +func (ws *WebSocket) getLiveChannel(id string) (chan Notification, bool) { ws.notificationChannelsLock.RLock() defer ws.notificationChannelsLock.RUnlock() ch, ok := ws.notificationChannels[id] @@ -188,7 +192,7 @@ func (ws *WebSocket) Send(method string, params []interface{}) (interface{}, err } id := rand.String(RequestIDLength) - request := &rpc.RPCRequest{ + request := &RPCRequest{ ID: id, Method: method, Params: params, @@ -248,7 +252,7 @@ func (ws *WebSocket) initialize() { case <-ws.closeChan: return default: - var res rpc.RPCResponse + var res RPCResponse err := ws.read(&res) if err != nil { shouldExit := ws.handleError(err) @@ -277,7 +281,7 @@ func (ws *WebSocket) handleError(err error) bool { return false } -func (ws *WebSocket) handleResponse(res rpc.RPCResponse) { +func (ws *WebSocket) handleResponse(res RPCResponse) { if res.ID != nil && res.ID != "" { // Try to resolve message as response to query responseChan, ok := ws.getResponseChannel(fmt.Sprintf("%v", res.ID)) @@ -298,7 +302,7 @@ func (ws *WebSocket) handleResponse(res rpc.RPCResponse) { ws.logger.Error(err.Error(), "result", fmt.Sprint(res.Result)) return } - var notification model.Notification + var notification Notification err := unmarshalMapToStruct(mappedRes, ¬ification) if err != nil { ws.logger.Error(err.Error(), "result", fmt.Sprint(res.Result)) diff --git a/internal/connection/ws_test.go b/internal/connection/ws_test.go new file mode 100644 index 0000000..c2b0392 --- /dev/null +++ b/internal/connection/ws_test.go @@ -0,0 +1 @@ +package connection diff --git a/internal/mock/mock.go b/internal/mock/mock.go index a3b7e3e..746326f 100644 --- a/internal/mock/mock.go +++ b/internal/mock/mock.go @@ -2,9 +2,7 @@ package mock import ( "errors" - - "github.com/surrealdb/surrealdb.go/pkg/conn" - "github.com/surrealdb/surrealdb.go/pkg/model" + conn "github.com/surrealdb/surrealdb.go/internal/connection" ) type ws struct { @@ -22,7 +20,7 @@ func (w *ws) Close() error { return nil } -func (w *ws) LiveNotifications(id string) (chan model.Notification, error) { +func (w *ws) LiveNotifications(id string) (chan conn.Notification, error) { return nil, errors.New("live queries are unimplemented for mocks") } diff --git a/pkg/rand/rand.go b/internal/rand/rand.go similarity index 100% rename from pkg/rand/rand.go rename to internal/rand/rand.go diff --git a/pkg/util/util.go b/internal/util/util.go similarity index 100% rename from pkg/util/util.go rename to internal/util/util.go diff --git a/pkg/conn/conn.go b/pkg/conn/conn.go deleted file mode 100644 index dc619ce..0000000 --- a/pkg/conn/conn.go +++ /dev/null @@ -1,10 +0,0 @@ -package conn - -import "github.com/surrealdb/surrealdb.go/pkg/model" - -type Connection interface { - Connect(url string) (Connection, error) - Send(method string, params []interface{}) (interface{}, error) - Close() error - LiveNotifications(id string) (chan model.Notification, error) -} diff --git a/pkg/logger/slog/slog.go b/pkg/logger/slog.go similarity index 97% rename from pkg/logger/slog/slog.go rename to pkg/logger/slog.go index a8e58e6..2734a08 100644 --- a/pkg/logger/slog/slog.go +++ b/pkg/logger/slog.go @@ -1,4 +1,4 @@ -package slog +package logger import ( "log/slog" diff --git a/pkg/logger/slog/slog_test.go b/pkg/logger/slog_test.go similarity index 94% rename from pkg/logger/slog/slog_test.go rename to pkg/logger/slog_test.go index f86c01b..05174cd 100644 --- a/pkg/logger/slog/slog_test.go +++ b/pkg/logger/slog_test.go @@ -1,4 +1,4 @@ -package slog_test +package logger_test import ( "bytes" @@ -10,7 +10,6 @@ import ( rawslog "log/slog" "github.com/stretchr/testify/require" - "github.com/surrealdb/surrealdb.go/pkg/logger/slog" ) type testMethod struct { @@ -37,7 +36,7 @@ func TestLogger(t *testing.T) { // level needs to be set to debug for log all handler := rawslog.NewJSONHandler(buffer, &rawslog.HandlerOptions{Level: rawslog.LevelDebug}) - logger := slog.New(handler) + logger := New(handler) testMethods := []testMethod{ {fn: logger.Error, level: rawslog.LevelError}, diff --git a/pkg/marshal/marshal.go b/pkg/marshal/marshal.go index 6e3e1ae..06752cd 100644 --- a/pkg/marshal/marshal.go +++ b/pkg/marshal/marshal.go @@ -5,10 +5,9 @@ import ( "encoding/json" "errors" "fmt" - "reflect" - + "github.com/surrealdb/surrealdb.go/internal/util" "github.com/surrealdb/surrealdb.go/pkg/constants" - "github.com/surrealdb/surrealdb.go/pkg/util" + "reflect" ) const StatusOK = "OK" diff --git a/pkg/model/cbor.go b/pkg/model/cbor.go new file mode 100644 index 0000000..5b98d99 --- /dev/null +++ b/pkg/model/cbor.go @@ -0,0 +1,127 @@ +package model + +import ( + "github.com/fxamacker/cbor/v2" + "github.com/surrealdb/surrealdb.go/internal/connection" + "reflect" +) + +type CustomCBORTag uint64 + +var ( + DateTimeStringTag CustomCBORTag = 0 + NoneTag CustomCBORTag = 6 + TableNameTag CustomCBORTag = 7 + RecordIDTag CustomCBORTag = 8 + UUIDStringTag CustomCBORTag = 9 + DecimalStringTag CustomCBORTag = 10 + DateTimeCompactString CustomCBORTag = 12 + DurationStringTag CustomCBORTag = 13 + DurationCompactStringTag CustomCBORTag = 14 + BinaryUUIDTag CustomCBORTag = 37 + GeometryPointTag CustomCBORTag = 88 + GeometryLineTag CustomCBORTag = 89 + GeometryPolygonTag CustomCBORTag = 90 + GeometryMultiPointTag CustomCBORTag = 91 + GeometryMultiLineTag CustomCBORTag = 92 + GeometryMultiPolygonTag CustomCBORTag = 93 + GeometryCollectionTag CustomCBORTag = 94 +) + +func registerCborTags() cbor.TagSet { + customTags := map[CustomCBORTag]interface{}{ + GeometryPointTag: GeometryPoint{}, + GeometryLineTag: GeometryLine{}, + GeometryPolygonTag: GeometryPolygon{}, + GeometryMultiPointTag: GeometryMultiPoint{}, + GeometryMultiLineTag: GeometryMultiLine{}, + GeometryMultiPolygonTag: GeometryMultiPolygon{}, + GeometryCollectionTag: GeometryCollection{}, + + TableNameTag: Table(""), + UUIDStringTag: UUID(""), + BinaryUUIDTag: UUIDBin{}, + } + + tags := cbor.NewTagSet() + for tag, customType := range customTags { + err := tags.Add( + cbor.TagOptions{EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired}, + reflect.TypeOf(customType), + uint64(tag), + ) + if err != nil { + panic(err) + } + } + + return tags +} + +func GetCborEncoder() connection.Encoder { + tags := registerCborTags() + em, err := cbor.EncOptions{}.EncModeWithTags(tags) + if err != nil { + panic(err) + } + + return em.Marshal +} + +func GetCborDecoder() connection.Decoder { + tags := registerCborTags() + dm, err := cbor.DecOptions{}.DecModeWithTags(tags) + if err != nil { + panic(err) + } + + return dm.Unmarshal +} + +func (gp *GeometryPoint) MarshalCBOR() ([]byte, error) { + enc := GetCborEncoder() + + return enc(cbor.Tag{ + Number: uint64(GeometryPointTag), + Content: gp.GetCoordinates(), + }) +} + +func (g *GeometryPoint) UnmarshalCBOR(data []byte) error { + dec := GetCborDecoder() + + var temp [2]float64 + err := dec(data, &temp) + if err != nil { + return err + } + + g.Latitude = temp[0] + g.Longitude = temp[1] + + return nil +} + +func (r *RecordID) MarshalCBOR() ([]byte, error) { + enc := GetCborEncoder() + + return enc(cbor.Tag{ + Number: uint64(RecordIDTag), + Content: []interface{}{r.ID, r.Table}, + }) +} + +func (r *RecordID) UnmarshalCBOR(data []byte) error { + dec := GetCborDecoder() + + var temp []interface{} + err := dec(data, &temp) + if err != nil { + return err + } + + r.Table = temp[0].(string) + r.ID = temp[1] + + return nil +} diff --git a/pkg/model/cbor_test.go b/pkg/model/cbor_test.go new file mode 100644 index 0000000..16ac501 --- /dev/null +++ b/pkg/model/cbor_test.go @@ -0,0 +1,88 @@ +package model + +import ( + "fmt" + "github.com/fxamacker/cbor/v2" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestForGeometryPoint(t *testing.T) { + em := GetCborEncoder() + dm := GetCborDecoder() + + gp := NewGeometryPoint(12.23, 45.65) + encoded, err := em(gp) + assert.Nil(t, err, "Should not encounter an error while encoding") + + decoded := GeometryPoint{} + err = dm(encoded, &decoded) + + assert.Nil(t, err, "Should not encounter an error while decoding") + assert.Equal(t, gp, decoded) +} + +func TestForGeometryLine(t *testing.T) { + em := GetCborEncoder() + dm := GetCborDecoder() + + gp1 := NewGeometryPoint(12.23, 45.65) + gp2 := NewGeometryPoint(23.34, 56.75) + gp3 := NewGeometryPoint(33.45, 86.99) + + gl := GeometryLine{gp1, gp2, gp3} + + encoded, err := em(gl) + assert.Nil(t, err, "Should not encounter an error while encoding") + + decoded := GeometryLine{} + err = dm(encoded, &decoded) + assert.Nil(t, err, "Should not encounter an error while decoding") + assert.Equal(t, gl, decoded) +} + +func TestForGeometryPolygon(t *testing.T) { + em := GetCborEncoder() + dm := GetCborDecoder() + + gl1 := GeometryLine{NewGeometryPoint(12.23, 45.65), NewGeometryPoint(23.33, 44.44)} + gl2 := GeometryLine{GeometryPoint{12.23, 45.65}, GeometryPoint{23.33, 44.44}} + gl3 := GeometryLine{NewGeometryPoint(12.23, 45.65), NewGeometryPoint(23.33, 44.44)} + gp := GeometryPolygon{gl1, gl2, gl3} + + encoded, err := em(gp) + assert.Nil(t, err, "Should not encounter an error while encoding") + + decoded := GeometryPolygon{} + err = dm(encoded, &decoded) + + assert.Nil(t, err, "Should not encounter an error while decoding") + assert.Equal(t, gp, decoded) +} + +func TestForRequestPayload(t *testing.T) { + em := GetCborEncoder() + + params := []interface{}{ + "SELECT marketing, count() FROM $tb GROUP BY marketing", + map[string]interface{}{ + "tb": Table("person"), + "line": GeometryLine{NewGeometryPoint(11.11, 22.22), NewGeometryPoint(33.33, 44.44)}, + }, + } + + requestPayload := map[string]interface{}{ + "id": "2", + "method": "query", + "params": params, + } + + encoded, err := em(requestPayload) + + assert.Nil(t, err, "should not return an error while encoding payload") + + diagStr, err := cbor.Diagnose(encoded) + assert.Nil(t, err, "should not return an error while diagnosing payload") + + fmt.Println(diagStr) +} diff --git a/pkg/model/model.go b/pkg/model/model.go new file mode 100644 index 0000000..9eb323d --- /dev/null +++ b/pkg/model/model.go @@ -0,0 +1,39 @@ +package model + +type GeometryPoint struct { + Latitude float64 + Longitude float64 +} + +func NewGeometryPoint(latitude float64, longitude float64) GeometryPoint { + return GeometryPoint{ + Latitude: latitude, Longitude: longitude, + } +} + +func (g *GeometryPoint) GetCoordinates() [2]float64 { + return [2]float64{g.Latitude, g.Longitude} +} + +type GeometryLine []GeometryPoint + +type GeometryPolygon []GeometryLine + +type GeometryMultiPoint []GeometryPoint + +type GeometryMultiLine []GeometryLine + +type GeometryMultiPolygon []GeometryPolygon + +type GeometryCollection []any + +type Table string + +type UUID string + +type UUIDBin []byte + +type RecordID struct { + Table string + ID interface{} +}