Check if a stream is closed just before writing to the stream channels #33
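In short, this change routes shutdown through a dedicated closeChan: Close() only signals that channel, the streaming code checks it with a non-blocking select just before each write to the Events and Errors channels, and the run goroutine closes those channels itself once everything has stopped. A minimal sketch of the send guard, using the names introduced in the diff below (simplified from the full change):

// Simplified guard from the new stream loop: bail out if Close() has
// signalled closeChan, otherwise hand the event to the consumer.
select {
case <-stream.closeChan:
	// Close() was called: stop instead of writing to a channel
	// that may never be drained again.
	reader.Close()
	return
case stream.Events <- ev:
	// Event delivered to the consumer; keep decoding.
}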

Open · wants to merge 2 commits into base: master
stream.go (231 changes: 135 additions & 96 deletions)
@@ -7,18 +7,28 @@ import (
"io/ioutil"
"log"
"net/http"
"sync"
"time"
)

const (
streamErrChanLen = 10
minStreamBackoff = 1 * time.Millisecond
)

type streamConnectFunc func(string) (io.ReadCloser, error)

// Stream handles a connection for receiving Server Sent Events.
// It will try to reconnect if the connection is lost, respecting both
// received retry delays and event IDs.
type Stream struct {
c *http.Client
req *http.Request
// connectFunc is the function used to connect to the event stream
connectFunc streamConnectFunc

lastEventId string
retry time.Duration
// closeChan is used to notify the streamer goroutine that the stream has been closed
closeChan chan struct{}

// Events emits the events received by the stream
Events chan Event
// Errors emits any errors encountered while reading events from the stream.
@@ -28,10 +38,6 @@ type Stream struct {
Errors chan error
// Logger is a logger that, when set, will be used for logging debug messages
Logger *log.Logger
// isClosed is a marker that the stream is/should be closed
isClosed bool
// isClosedMutex is a mutex protecting concurrent read/write access of isClosed
isClosedMutex sync.RWMutex
}

type SubscriptionError struct {
@@ -62,135 +68,168 @@ func SubscribeWithRequest(lastEventId string, request *http.Request) (*Stream, e
// SubscribeWith takes a http client and request providing customization over both headers and
// control over the http client settings (timeouts, tls, etc)
func SubscribeWith(lastEventId string, client *http.Client, request *http.Request) (*Stream, error) {
return subscribe(lastEventId, connectHTTP(lastEventId, client, request))
}

func subscribe(lastEventId string, connectFunc streamConnectFunc) (*Stream, error) {
stream := &Stream{
c: client,
req: request,
connectFunc: connectFunc,

lastEventId: lastEventId,
retry: time.Millisecond * 3000,
Events: make(chan Event),
Errors: make(chan error),
}
stream.c.CheckRedirect = checkRedirect
retry: 3000 * time.Millisecond,
closeChan: make(chan struct{}),

r, err := stream.connect()
if err != nil {
return nil, err
Events: make(chan Event),
Errors: make(chan error, streamErrChanLen),
}
go stream.stream(r)

go stream.run()

return stream, nil
}

// Close will close the stream. It is safe for concurrent access and can be called multiple times.
func (stream *Stream) Close() {
if stream.isStreamClosed() {
return
select {
case <-stream.closeChan:
default:
close(stream.closeChan)
}

stream.markStreamClosed()
close(stream.Errors)
close(stream.Events)
}

func (stream *Stream) isStreamClosed() bool {
stream.isClosedMutex.RLock()
defer stream.isClosedMutex.RUnlock()
return stream.isClosed
}
// connectHTTP connects to an event stream using the provided http request and client
func connectHTTP(lastEventID string, client *http.Client, request *http.Request) func(string) (io.ReadCloser, error) {
return func(lastEvtID string) (io.ReadCloser, error) {
client.CheckRedirect = checkRedirect

func (stream *Stream) markStreamClosed() {
stream.isClosedMutex.Lock()
defer stream.isClosedMutex.Unlock()
stream.isClosed = true
}
request.Header.Set("Cache-Control", "no-cache")
request.Header.Set("Accept", "text/event-stream")

// Go's http package doesn't copy headers across when it encounters
// redirects so we need to do that manually.
func checkRedirect(req *http.Request, via []*http.Request) error {
if len(via) >= 10 {
return errors.New("stopped after 10 redirects")
}
for k, vv := range via[0].Header {
for _, v := range vv {
req.Header.Add(k, v)
if lastEventID != "" {
request.Header.Set("Last-Event-ID", lastEventID)
}
}
return nil
}

func (stream *Stream) connect() (r io.ReadCloser, err error) {
var resp *http.Response
stream.req.Header.Set("Cache-Control", "no-cache")
stream.req.Header.Set("Accept", "text/event-stream")
if len(stream.lastEventId) > 0 {
stream.req.Header.Set("Last-Event-ID", stream.lastEventId)
}
if resp, err = stream.c.Do(stream.req); err != nil {
return
}
if resp.StatusCode != 200 {
message, _ := ioutil.ReadAll(resp.Body)
err = SubscriptionError{
Code: resp.StatusCode,
Message: string(message),
res, err := client.Do(request)
if err != nil {
return nil, err
}

if res.StatusCode != 200 {
message, _ := ioutil.ReadAll(res.Body)
res.Body.Close()

return nil, SubscriptionError{
Code: res.StatusCode,
Message: string(message),
}
}

return res.Body, nil
}
r = resp.Body
return
}

func (stream *Stream) stream(r io.ReadCloser) {
defer r.Close()
func (stream *Stream) run() {
backoff := minStreamBackoff

runLoop:
for {
reader, err := stream.connectFunc(stream.lastEventId)
if err != nil {
if stream.Logger != nil {
stream.Logger.Printf("Reconnecting in %0.4f secs\n", backoff.Seconds())
}
time.Sleep(backoff)
backoff = capStreamBackoff(backoff * 2)
continue
}

// receives events until an error is encountered
stream.receiveEvents(r)
// We connected successfully so reset the backoff
backoff = minStreamBackoff
stream.stream(reader)

// Check if we're supposed to close after the stream finishes. If not
// just make a new connection and start all over!
select {
case <-stream.closeChan:
break runLoop
default:
}
}

// tries to reconnect and start the stream again
stream.retryRestartStream()
close(stream.Events)
close(stream.Errors)
}

func (stream *Stream) receiveEvents(r io.ReadCloser) {
dec := NewDecoder(r)
func (stream *Stream) stream(reader io.ReadCloser) {
// If we fail to stream for some reason make sure it gets closed
// when we're done.
defer reader.Close()

dec := NewDecoder(reader)

for {
ev, err := dec.Decode()
if stream.isStreamClosed() {
// Check if the stream was closed before every event read just in case
// we get stuck in a bad decode loop and never reach the point of
// actually sending an event.
select {
case <-stream.closeChan:
return
default:
}

ev, err := dec.Decode()
if err != nil {
stream.Errors <- err
return
stream.writeError(err)
continue
}

pub := ev.(*publication)
if pub.Retry() > 0 {
stream.retry = time.Duration(pub.Retry()) * time.Millisecond
}
if len(pub.Id()) > 0 {
if pub.Id() != "" {
stream.lastEventId = pub.Id()
}
stream.Events <- ev
}
}

func (stream *Stream) retryRestartStream() {
backoff := stream.retry
for {
if stream.Logger != nil {
stream.Logger.Printf("Reconnecting in %0.4f secs\n", backoff.Seconds())
}
time.Sleep(backoff)
if stream.isStreamClosed() {
// Send the event but also watch for a close in case nobody
// reads from the event channel and we get blocked writing.
select {
case <-stream.closeChan:
reader.Close()
return
case stream.Events <- ev:
}
// NOTE: because of the defer we're opening the new connection
// before closing the old one. Shouldn't be a problem in practice,
// but something to be aware of.
r, err := stream.connect()
if err == nil {
go stream.stream(r)
return
}

}

func (stream *Stream) writeError(err error) {
// Start dropping old errors if nobody is reading from the
// other end so we don't end up blocking on writing an error
// but still keep a short history of errors.
if len(stream.Errors) == streamErrChanLen {
<-stream.Errors
}
stream.Errors <- err
}

func capStreamBackoff(backoff time.Duration) time.Duration {
if backoff > 10*time.Second {
return 10 * time.Second
}
return backoff
}

// Go's http package doesn't copy headers across when it encounters
// redirects so we need to do that manually.
func checkRedirect(req *http.Request, via []*http.Request) error {
if len(via) >= 10 {
return errors.New("stopped after 10 redirects")
}
for k, vv := range via[0].Header {
for _, v := range vv {
req.Header.Add(k, v)
}
stream.Errors <- err
backoff *= 2
}
return nil
}
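For orientation, here is a hedged usage sketch of the reworked stream from a consumer's point of view. It assumes the package's Subscribe(url, lastEventId) constructor and the Id()/Data() accessors on Event, neither of which appears in this diff, and a placeholder import path. With the new code, Close() only signals closeChan and the run goroutine closes Events and Errors on its way out, so a select loop over the channels terminates cleanly:

package main

import (
	"log"
	"time"

	"github.com/example/eventsource" // placeholder import path for this package
)

func main() {
	stream, err := eventsource.Subscribe("http://localhost:8080/events", "")
	if err != nil {
		log.Fatalf("subscribe failed: %v", err)
	}

	// Stop the stream after a while; with the closeChan-based shutdown,
	// Close() is safe to call from another goroutine and more than once.
	time.AfterFunc(30*time.Second, stream.Close)

	for {
		select {
		case ev, ok := <-stream.Events:
			if !ok {
				return // run() closed the channel after Close()
			}
			log.Printf("event id=%q data=%q", ev.Id(), ev.Data())
		case err, ok := <-stream.Errors:
			if !ok {
				return
			}
			log.Printf("stream error: %v", err)
		}
	}
}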
stream_test.go (57 changes: 48 additions & 9 deletions)
@@ -1,7 +1,9 @@
package eventsource

import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"reflect"
"testing"
@@ -13,23 +15,60 @@ const (
timeToWaitForEvent = 100 * time.Millisecond
)

type testStream struct {
bytes.Buffer
}

func (ts testStream) Close() error {
return nil
}

func TestStreamSubscribeEventsChan(t *testing.T) {
server := NewServer()
httpServer := httptest.NewServer(server.Handler(eventChannelName))
// The server has to be closed before the httpServer is closed.
// Otherwise the httpServer still has an open connection and cannot close.
buff := &testStream{}
stream, err := subscribe("", func(lastEvtID string) (io.ReadCloser, error) {
return buff, nil
})
if err != nil {
t.Fatalf("Could not create event stream: %s", err)
}

expectedEvent := &publication{id: "123"}
buff.WriteString("" +
"id: 123\n" +
"data:\n" +
"\n",
)

select {
case receivedEvent := <-stream.Events:
if !reflect.DeepEqual(receivedEvent, expectedEvent) {
t.Errorf("got event %+v, want %+v", receivedEvent, expectedEvent)
}
case <-time.After(timeToWaitForEvent):
t.Error("Timed out waiting for event")
}
}

func TestStreamSubscribeEventsChanHTTP(t *testing.T) {
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Write the same event to the stream for every request
w.Write([]byte("" +
"id: 123\n" +
"data:\n" +
"\n",
))
}))
defer httpServer.Close()
defer server.Close()

stream := mustSubscribe(t, httpServer.URL, "")
defer stream.Close()

publishedEvent := &publication{id: "123"}
server.Publish([]string{eventChannelName}, publishedEvent)
expectedEvent := &publication{id: "123"}

select {
case receivedEvent := <-stream.Events:
if !reflect.DeepEqual(receivedEvent, publishedEvent) {
t.Errorf("got event %+v, want %+v", receivedEvent, publishedEvent)
if !reflect.DeepEqual(receivedEvent, expectedEvent) {
t.Errorf("got event %+v, want %+v", receivedEvent, expectedEvent)
}
case <-time.After(timeToWaitForEvent):
t.Error("Timed out waiting for event")
Expand Down