diff --git a/websocket.go b/websocket.go index cecf03b..886fc67 100644 --- a/websocket.go +++ b/websocket.go @@ -70,7 +70,7 @@ func ReloadRemoteDatabase(url string) error { return nil } -func Shutdown(database Database, filename string) error { +func StoreDatabase(database Database, filename string) error { err := Dump(database, filename) if err != nil { return err @@ -103,6 +103,12 @@ func NewContainer(filename string) (Database, *http.ServeMux, error) { log.Println(err) } }) + mux.HandleFunc("/store", func(w http.ResponseWriter, r *http.Request) { + err := StoreDatabase(database, filename) + if err != nil { + log.Println(err) + } + }) mux.HandleFunc("/database", DatabaseHandleFunc(database)) return database, mux, nil @@ -119,7 +125,7 @@ func main() { signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) go func() { <-signals - err := Shutdown(database, filename) + err := StoreDatabase(database, filename) if err != nil { log.Fatalln(err) } diff --git a/websocket_test.go b/websocket_test.go index d85d81b..dcd531b 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -103,21 +103,46 @@ func TestWebsocket_Client(t *testing.T) { assert.Check(t, strings.Contains(string(body), "Database")) } -func TestWebsocket_Shutdown(t *testing.T) { - // Given database with data - database := NewDatabase() - database.Update("username", "john") +func TestWebsocket_Store(t *testing.T) { + // Given empty database server filename := "test.txt" + _, mux, err := NewContainer(filename) + assert.NilError(t, err) + server := httptest.NewServer(mux) + ws, _, err := websocket.DefaultDialer.Dial(strings.Replace(server.URL, "http", "ws", 1)+"/database", nil) + assert.NilError(t, err) + + // When we add value + err = ws.WriteMessage(websocket.TextMessage, []byte("update username john")) + assert.NilError(t, err) + + // Then value is there + err = ws.WriteMessage(websocket.TextMessage, []byte("select username")) + assert.NilError(t, err) + _, message, err := ws.ReadMessage() + assert.NilError(t, err) + assert.Equal(t, string(message), "john") - // When we shut it down - err := Shutdown(database, filename) + // When we store dump + request, err := http.NewRequest("GET", server.URL+"/store", nil) + assert.NilError(t, err) + client := &http.Client{} + _, err = client.Do(request) assert.NilError(t, err) - // Then state is saved - database = NewDatabase() - err = Load(database, filename) + // And when we load it again + _, mux, err = NewContainer("test.txt") + assert.NilError(t, err) + server = httptest.NewServer(mux) + ws, _, err = websocket.DefaultDialer.Dial(strings.Replace(server.URL, "http", "ws", 1)+"/database", nil) assert.NilError(t, err) - assert.DeepEqual(t, database.Select("username"), []string{"john"}) + + // Then value is there + err = ws.WriteMessage(websocket.TextMessage, []byte("select username")) + assert.NilError(t, err) + _, message, err = ws.ReadMessage() + assert.NilError(t, err) + assert.Equal(t, string(message), "john") // Clean up err = os.Remove(filename)