diff --git a/stream.go b/stream.go index 5970877..2b0a150 100644 --- a/stream.go +++ b/stream.go @@ -7,18 +7,28 @@ import ( "io/ioutil" "log" "net/http" - "sync" "time" ) +const ( + streamErrChanLen = 10 + minStreamBackoff = 1 * time.Millisecond +) + +type streamConnectFunc func(string) (io.ReadCloser, error) + // Stream handles a connection for receiving Server Sent Events. // It will try and reconnect if the connection is lost, respecting both // received retry delays and event id's. type Stream struct { - c *http.Client - req *http.Request + // connectFunc is the function used to connect to the event stream + connectFunc streamConnectFunc + lastEventId string retry time.Duration + // closeChan is used to notify the streamer goroutine that the stream has been closed + closeChan chan struct{} + // Events emits the events received by the stream Events chan Event // Errors emits any errors encountered while reading events from the stream. @@ -28,10 +38,6 @@ type Stream struct { Errors chan error // Logger is a logger that, when set, will be used for logging debug messages Logger *log.Logger - // isClosed is a marker that the stream is/should be closed - isClosed bool - // isClosedMutex is a mutex protecting concurrent read/write access of isClosed - isClosedMutex sync.RWMutex } type SubscriptionError struct { @@ -62,135 +68,168 @@ func SubscribeWithRequest(lastEventId string, request *http.Request) (*Stream, e // SubscribeWith takes a http client and request providing customization over both headers and // control over the http client settings (timeouts, tls, etc) func SubscribeWith(lastEventId string, client *http.Client, request *http.Request) (*Stream, error) { + return subscribe(lastEventId, connectHTTP(lastEventId, client, request)) +} + +func subscribe(lastEventId string, connectFunc streamConnectFunc) (*Stream, error) { stream := &Stream{ - c: client, - req: request, + connectFunc: connectFunc, + lastEventId: lastEventId, - retry: time.Millisecond * 3000, - Events: make(chan Event), - Errors: make(chan error), - } - stream.c.CheckRedirect = checkRedirect + retry: 3000 * time.Millisecond, + closeChan: make(chan struct{}), - r, err := stream.connect() - if err != nil { - return nil, err + Events: make(chan Event), + Errors: make(chan error, streamErrChanLen), } - go stream.stream(r) + + go stream.run() + return stream, nil } // Close will close the stream. It is safe for concurrent access and can be called multiple times. func (stream *Stream) Close() { - if stream.isStreamClosed() { - return + select { + case <-stream.closeChan: + default: + close(stream.closeChan) } - - stream.markStreamClosed() - close(stream.Errors) - close(stream.Events) } -func (stream *Stream) isStreamClosed() bool { - stream.isClosedMutex.RLock() - defer stream.isClosedMutex.RUnlock() - return stream.isClosed -} +// connectHTTP connects to an event stream using the provided http request and client +func connectHTTP(lastEventID string, client *http.Client, request *http.Request) func(string) (io.ReadCloser, error) { + return func(lastEvtID string) (io.ReadCloser, error) { + client.CheckRedirect = checkRedirect -func (stream *Stream) markStreamClosed() { - stream.isClosedMutex.Lock() - defer stream.isClosedMutex.Unlock() - stream.isClosed = true -} + request.Header.Set("Cache-Control", "no-cache") + request.Header.Set("Accept", "text/event-stream") -// Go's http package doesn't copy headers across when it encounters -// redirects so we need to do that manually. -func checkRedirect(req *http.Request, via []*http.Request) error { - if len(via) >= 10 { - return errors.New("stopped after 10 redirects") - } - for k, vv := range via[0].Header { - for _, v := range vv { - req.Header.Add(k, v) + if lastEventID != "" { + request.Header.Set("Last-Event-ID", lastEventID) } - } - return nil -} -func (stream *Stream) connect() (r io.ReadCloser, err error) { - var resp *http.Response - stream.req.Header.Set("Cache-Control", "no-cache") - stream.req.Header.Set("Accept", "text/event-stream") - if len(stream.lastEventId) > 0 { - stream.req.Header.Set("Last-Event-ID", stream.lastEventId) - } - if resp, err = stream.c.Do(stream.req); err != nil { - return - } - if resp.StatusCode != 200 { - message, _ := ioutil.ReadAll(resp.Body) - err = SubscriptionError{ - Code: resp.StatusCode, - Message: string(message), + res, err := client.Do(request) + if err != nil { + return nil, err } + + if res.StatusCode != 200 { + message, _ := ioutil.ReadAll(res.Body) + res.Body.Close() + + return nil, SubscriptionError{ + Code: res.StatusCode, + Message: string(message), + } + } + + return res.Body, nil } - r = resp.Body - return } -func (stream *Stream) stream(r io.ReadCloser) { - defer r.Close() +func (stream *Stream) run() { + backoff := minStreamBackoff + +runLoop: + for { + reader, err := stream.connectFunc(stream.lastEventId) + if err != nil { + if stream.Logger != nil { + stream.Logger.Printf("Reconnecting in %0.4f secs\n", backoff.Seconds()) + } + time.Sleep(backoff) + backoff = capStreamBackoff(backoff * 2) + continue + } - // receives events until an error is encountered - stream.receiveEvents(r) + // We connected successfully so reset the backoff + backoff = minStreamBackoff + stream.stream(reader) + + // Check if we're supposed to close after the stream finishes. If not + // just make a new connection and start all over! + select { + case <-stream.closeChan: + break runLoop + default: + } + } - // tries to reconnect and start the stream again - stream.retryRestartStream() + close(stream.Events) + close(stream.Errors) } -func (stream *Stream) receiveEvents(r io.ReadCloser) { - dec := NewDecoder(r) +func (stream *Stream) stream(reader io.ReadCloser) { + // If we fail to stream for some reason make sure it gets closed + // when we're done. + defer reader.Close() + + dec := NewDecoder(reader) for { - ev, err := dec.Decode() - if stream.isStreamClosed() { + // Check if the stream was closed before every event read just in case + // we get stuck in a bad decode loop and never reach the point of + // actually sending an event. + select { + case <-stream.closeChan: return + default: } + + ev, err := dec.Decode() if err != nil { - stream.Errors <- err - return + stream.writeError(err) + continue } pub := ev.(*publication) if pub.Retry() > 0 { stream.retry = time.Duration(pub.Retry()) * time.Millisecond } - if len(pub.Id()) > 0 { + if pub.Id() != "" { stream.lastEventId = pub.Id() } - stream.Events <- ev - } -} -func (stream *Stream) retryRestartStream() { - backoff := stream.retry - for { - if stream.Logger != nil { - stream.Logger.Printf("Reconnecting in %0.4f secs\n", backoff.Seconds()) - } - time.Sleep(backoff) - if stream.isStreamClosed() { + // Send the event but also watch for a close in case nobody + // reads from the event channel and we get blocked writing. + select { + case <-stream.closeChan: + reader.Close() return + case stream.Events <- ev: } - // NOTE: because of the defer we're opening the new connection - // before closing the old one. Shouldn't be a problem in practice, - // but something to be aware of. - r, err := stream.connect() - if err == nil { - go stream.stream(r) - return + } + +} + +func (stream *Stream) writeError(err error) { + // Start dropping old errors if nobody is reading from the + // other end so we don't end up blocking on writing an error + // but still keep a short history of errors. + if len(stream.Errors) == streamErrChanLen { + <-stream.Errors + } + stream.Errors <- err +} + +func capStreamBackoff(backoff time.Duration) time.Duration { + if backoff > 10*time.Second { + return 10 * time.Second + } + return backoff +} + +// Go's http package doesn't copy headers across when it encounters +// redirects so we need to do that manually. +func checkRedirect(req *http.Request, via []*http.Request) error { + if len(via) >= 10 { + return errors.New("stopped after 10 redirects") + } + for k, vv := range via[0].Header { + for _, v := range vv { + req.Header.Add(k, v) } - stream.Errors <- err - backoff *= 2 } + return nil } diff --git a/stream_test.go b/stream_test.go index 019c527..1417611 100644 --- a/stream_test.go +++ b/stream_test.go @@ -1,7 +1,9 @@ package eventsource import ( + "bytes" "io" + "net/http" "net/http/httptest" "reflect" "testing" @@ -13,23 +15,60 @@ const ( timeToWaitForEvent = 100 * time.Millisecond ) +type testStream struct { + bytes.Buffer +} + +func (ts testStream) Close() error { + return nil +} + func TestStreamSubscribeEventsChan(t *testing.T) { - server := NewServer() - httpServer := httptest.NewServer(server.Handler(eventChannelName)) - // The server has to be closed before the httpServer is closed. - // Otherwise the httpServer has still an open connection and it can not close. + buff := &testStream{} + stream, err := subscribe("", func(lastEvtID string) (io.ReadCloser, error) { + return buff, nil + }) + if err != nil { + t.Fatalf("Could not create event stream: %s", err) + } + + expectedEvent := &publication{id: "123"} + buff.WriteString("" + + "id: 123\n" + + "data:\n" + + "\n", + ) + + select { + case receivedEvent := <-stream.Events: + if !reflect.DeepEqual(receivedEvent, expectedEvent) { + t.Errorf("got event %+v, want %+v", receivedEvent, expectedEvent) + } + case <-time.After(timeToWaitForEvent): + t.Error("Timed out waiting for event") + } +} + +func TestStreamSubscribeEventsChanHTTP(t *testing.T) { + httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Write the same event to the stream for every request + w.Write([]byte("" + + "id: 123\n" + + "data:\n" + + "\n", + )) + })) defer httpServer.Close() - defer server.Close() stream := mustSubscribe(t, httpServer.URL, "") + defer stream.Close() - publishedEvent := &publication{id: "123"} - server.Publish([]string{eventChannelName}, publishedEvent) + expectedEvent := &publication{id: "123"} select { case receivedEvent := <-stream.Events: - if !reflect.DeepEqual(receivedEvent, publishedEvent) { - t.Errorf("got event %+v, want %+v", receivedEvent, publishedEvent) + if !reflect.DeepEqual(receivedEvent, expectedEvent) { + t.Errorf("got event %+v, want %+v", receivedEvent, expectedEvent) } case <-time.After(timeToWaitForEvent): t.Error("Timed out waiting for event")