diff --git a/db_test.go b/db_test.go index 874c567..6fed8c0 100644 --- a/db_test.go +++ b/db_test.go @@ -19,6 +19,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/surrealdb/surrealdb.go" "github.com/surrealdb/surrealdb.go/pkg/conn/gorilla" + "github.com/surrealdb/surrealdb.go/pkg/conn/nhooyr" "github.com/surrealdb/surrealdb.go/pkg/constants" "github.com/surrealdb/surrealdb.go/pkg/conn" @@ -56,17 +57,31 @@ func TestSurrealDBSuite(t *testing.T) { SurrealDBSuite := new(SurrealDBTestSuite) SurrealDBSuite.connImplementations = make(map[string]conn.Connection) + // // Gorilla // Without options - buff := bytes.NewBufferString("") - logData := createLogger(t, buff) - SurrealDBSuite.connImplementations["gorilla"] = gorilla.Create().Logger(logData) - SurrealDBSuite.logBuffer = buff + gbuff := bytes.NewBufferString("") + glogData := createLogger(t, gbuff) + SurrealDBSuite.connImplementations["gorilla"] = gorilla.Create().Logger(glogData) + SurrealDBSuite.logBuffer = gbuff // With options - buffOpt := bytes.NewBufferString("") - logDataOpt := createLogger(t, buff) - SurrealDBSuite.connImplementations["gorilla_opt"] = gorilla.Create().SetTimeOut(time.Minute).SetCompression(true).Logger(logDataOpt) - SurrealDBSuite.logBuffer = buffOpt + gbuffOpt := bytes.NewBufferString("") + glogDataOpt := createLogger(t, gbuff) + SurrealDBSuite.connImplementations["gorilla_opt"] = gorilla.Create().SetTimeOut(time.Minute).SetCompression(true).Logger(glogDataOpt) + SurrealDBSuite.logBuffer = gbuffOpt + + // // Nhooyr + // Without options + nbuff := bytes.NewBufferString("") + nlogData := createLogger(t, gbuff) + SurrealDBSuite.connImplementations["nhooyr"] = nhooyr.Create().Logger(nlogData) + SurrealDBSuite.logBuffer = nbuff + + // With options + nbuffOpt := bytes.NewBufferString("") + nlogDataOpt := createLogger(t, gbuff) + SurrealDBSuite.connImplementations["nhooyr_opt"] = nhooyr.Create().Logger(nlogDataOpt) + SurrealDBSuite.logBuffer = nbuffOpt RunWsMap(t, SurrealDBSuite) } diff --git a/go.mod b/go.mod index 43ef31e..6979fa3 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.20 require ( github.com/gorilla/websocket v1.5.0 github.com/stretchr/testify v1.8.4 + nhooyr.io/websocket v1.8.10 ) require ( diff --git a/go.sum b/go.sum index ab9f4d8..922228c 100644 --- a/go.sum +++ b/go.sum @@ -23,3 +23,5 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +nhooyr.io/websocket v1.8.10 h1:mv4p+MnGrLDcPlBoWsvPP7XCzTYMXP9F9eIGoKbgx7Q= +nhooyr.io/websocket v1.8.10/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c= diff --git a/pkg/conn/nhooyr/nhooyr.go b/pkg/conn/nhooyr/nhooyr.go new file mode 100644 index 0000000..09eae2b --- /dev/null +++ b/pkg/conn/nhooyr/nhooyr.go @@ -0,0 +1,348 @@ +package nhooyr + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net" + "net/http" + "reflect" + "strconv" + "sync" + "time" + + "github.com/surrealdb/surrealdb.go/pkg/model" + nhooyr "nhooyr.io/websocket" + + "github.com/surrealdb/surrealdb.go/internal/rpc" + "github.com/surrealdb/surrealdb.go/pkg/conn" + "github.com/surrealdb/surrealdb.go/pkg/logger" + "github.com/surrealdb/surrealdb.go/pkg/rand" +) + +const ( + // RequestIDLength size of id sent on WS request + RequestIDLength = 16 + // CloseMessageCode identifier the message id for a close request + CloseMessageCode = 1000 + // DefaultTimeout timeout in seconds + DefaultTimeout = 30 +) + +type Option func(ws *WebSocket) error + +type WebSocket struct { + Conn *nhooyr.Conn + connLock sync.Mutex + Timeout time.Duration + Option []Option + logger logger.Logger + + responseChannels map[string]chan rpc.RPCResponse + responseChannelsLock sync.RWMutex + + notificationChannels map[string]chan model.Notification + notificationChannelsLock sync.RWMutex + + close chan int +} + +func Create() *WebSocket { + return &WebSocket{ + Conn: nil, + close: make(chan int), + responseChannels: make(map[string]chan rpc.RPCResponse), + notificationChannels: make(map[string]chan model.Notification), + Timeout: DefaultTimeout * time.Second, + } +} + +func (ws *WebSocket) Connect(url string) (conn.Connection, error) { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + connection, resp, err := nhooyr.Dial(ctx, url, nil) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusSwitchingProtocols { + resp.Body.Close() + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + ws.Conn = connection + + for _, option := range ws.Option { + if err := option(ws); err != nil { + return ws, err + } + } + + ws.initialize() + return ws, nil +} + +func (ws *WebSocket) SetTimeOut(timeout time.Duration) *WebSocket { + ws.Option = append(ws.Option, func(ws *WebSocket) error { + ws.Timeout = timeout + return nil + }) + return ws +} + +// If path is empty it will use os.stdout/os.stderr +func (ws *WebSocket) Logger(logData logger.Logger) *WebSocket { + ws.logger = logData + return ws +} + +func (ws *WebSocket) RawLogger(logData logger.Logger) *WebSocket { + ws.logger = logData + return ws +} + +func (ws *WebSocket) Close() error { + ws.connLock.Lock() + defer ws.connLock.Unlock() + close(ws.close) + + return ws.Conn.Close(nhooyr.StatusNormalClosure, "") +} + +func (ws *WebSocket) LiveNotifications(liveQueryID string) (chan model.Notification, error) { + c, err := ws.createNotificationChannel(liveQueryID) + if err != nil { + ws.logger.Error(err.Error()) + } + return c, err +} + +var ( + ErrIDInUse = errors.New("id already in use") + ErrTimeout = errors.New("timeout") + ErrInvalidResponseID = errors.New("invalid response id") +) + +func (ws *WebSocket) createResponseChannel(id string) (chan rpc.RPCResponse, error) { + ws.responseChannelsLock.Lock() + defer ws.responseChannelsLock.Unlock() + + if _, ok := ws.responseChannels[id]; ok { + return nil, fmt.Errorf("%w: %v", ErrIDInUse, id) + } + + ch := make(chan rpc.RPCResponse) + ws.responseChannels[id] = ch + + return ch, nil +} + +func (ws *WebSocket) createNotificationChannel(liveQueryID string) (chan model.Notification, error) { + ws.notificationChannelsLock.Lock() + defer ws.notificationChannelsLock.Unlock() + + if _, ok := ws.notificationChannels[liveQueryID]; ok { + return nil, fmt.Errorf("%w: %v", ErrIDInUse, liveQueryID) + } + + ch := make(chan model.Notification) + ws.notificationChannels[liveQueryID] = ch + + return ch, nil +} + +func (ws *WebSocket) removeResponseChannel(id string) { + ws.responseChannelsLock.Lock() + defer ws.responseChannelsLock.Unlock() + delete(ws.responseChannels, id) +} + +func (ws *WebSocket) getResponseChannel(id string) (chan rpc.RPCResponse, bool) { + ws.responseChannelsLock.RLock() + defer ws.responseChannelsLock.RUnlock() + ch, ok := ws.responseChannels[id] + return ch, ok +} + +func (ws *WebSocket) getLiveChannel(id string) (chan model.Notification, bool) { + ws.notificationChannelsLock.RLock() + defer ws.notificationChannelsLock.RUnlock() + ch, ok := ws.notificationChannels[id] + return ch, ok +} + +func (ws *WebSocket) Send(method string, params []interface{}) (interface{}, error) { + id := rand.String(RequestIDLength) + request := &rpc.RPCRequest{ + ID: id, + Method: method, + Params: params, + } + + responseChan, err := ws.createResponseChannel(id) + if err != nil { + return nil, err + } + defer ws.removeResponseChannel(id) + + if err := ws.write(request); err != nil { + return nil, err + } + + timeout := time.After(ws.Timeout) + + select { + case <-timeout: + return nil, ErrTimeout + case res, open := <-responseChan: + if !open { + return nil, errors.New("channel closed") + } + if res.ID != id { + return nil, ErrInvalidResponseID + } + if res.Error != nil { + return nil, res.Error + } + return res.Result, nil + } +} + +func (ws *WebSocket) read(v interface{}) error { + _, data, err := ws.Conn.Read(context.Background()) + if err != nil { + return err + } + return json.Unmarshal(data, v) +} + +func (ws *WebSocket) write(v interface{}) error { + data, err := json.Marshal(v) + if err != nil { + return err + } + + ws.connLock.Lock() + defer ws.connLock.Unlock() + return ws.Conn.Write(context.Background(), nhooyr.MessageText, data) +} + +func (ws *WebSocket) initialize() { + go func() { + for { + select { + case <-ws.close: + return + default: + var res rpc.RPCResponse + err := ws.read(&res) + if err != nil { + if errors.Is(err, net.ErrClosed) { + break + } + ws.logger.Error(err.Error()) + continue + } + go ws.handleResponse(res) + } + } + }() +} + +func (ws *WebSocket) handleResponse(res rpc.RPCResponse) { + if res.ID != nil && res.ID != "" { + // Try to resolve message as response to query + responseChan, ok := ws.getResponseChannel(fmt.Sprintf("%v", res.ID)) + if !ok { + err := fmt.Errorf("unavailable ResponseChannel %+v", res.ID) + ws.logger.Error(err.Error()) + return + } + defer close(responseChan) + responseChan <- res + } else { + // Try to resolve response as live query notification + mappedRes, _ := res.Result.(map[string]interface{}) + resolvedID, ok := mappedRes["id"] + if !ok { + err := fmt.Errorf("response did not contain an 'id' field") + + ws.logger.Error(err.Error(), "result", fmt.Sprint(res.Result)) + return + } + var notification model.Notification + err := unmarshalMapToStruct(mappedRes, ¬ification) + if err != nil { + ws.logger.Error(err.Error(), "result", fmt.Sprint(res.Result)) + return + } + LiveNotificationChan, ok := ws.getLiveChannel(notification.ID) + if !ok { + err := fmt.Errorf("unavailable ResponseChannel %+v", resolvedID) + ws.logger.Error(err.Error(), "result", fmt.Sprint(res.Result)) + return + } + LiveNotificationChan <- notification + } +} + +func unmarshalMapToStruct(data map[string]interface{}, outStruct interface{}) error { + outValue := reflect.ValueOf(outStruct) + if outValue.Kind() != reflect.Ptr || outValue.Elem().Kind() != reflect.Struct { + return fmt.Errorf("outStruct must be a pointer to a struct") + } + + structValue := outValue.Elem() + structType := structValue.Type() + + for i := 0; i < structValue.NumField(); i++ { + field := structType.Field(i) + fieldName := field.Name + jsonTag := field.Tag.Get("json") + if jsonTag != "" { + fieldName = jsonTag + } + mapValue, ok := data[fieldName] + if !ok { + return fmt.Errorf("missing field in map: %s", fieldName) + } + + fieldValue := structValue.Field(i) + if !fieldValue.CanSet() { + return fmt.Errorf("cannot set field: %s", fieldName) + } + + if mapValue == nil { + // Handle nil values appropriately for your struct fields + // For simplicity, we skip nil values in this example + continue + } + + // Type conversion based on the field type + switch fieldValue.Kind() { + case reflect.String: + fieldValue.SetString(fmt.Sprint(mapValue)) + case reflect.Int: + intVal, err := strconv.Atoi(fmt.Sprint(mapValue)) + if err != nil { + return err + } + fieldValue.SetInt(int64(intVal)) + case reflect.Bool: + boolVal, err := strconv.ParseBool(fmt.Sprint(mapValue)) + if err != nil { + return err + } + fieldValue.SetBool(boolVal) + case reflect.Interface: + fieldValue.Set(reflect.ValueOf(mapValue)) + // Add cases for other types as needed + default: + return fmt.Errorf("unsupported field type: %s", fieldName) + } + } + + return nil +}