Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions trickle/local_publisher.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,23 @@ func (c *TrickleLocalPublisher) Write(data io.Reader) error {
stream := c.server.getOrCreateStream(c.channelName, c.mimeType, true)
c.mu.Lock()
seq := c.seq
segment, exists := stream.getForWrite(seq)
segment, exists, closed := stream.getForWrite(seq)
if closed {
c.mu.Unlock()
return errors.New("stream closed")
}
if exists {
c.mu.Unlock()
return errors.New("Entry already exists for this sequence")
}

// before we begin - let's pre-create the next segment
nextSeq := c.seq + 1
if _, exists = stream.getForWrite(nextSeq); exists {
if _, exists, closed = stream.getForWrite(nextSeq); exists || closed {
c.mu.Unlock()
if closed {
return errors.New("Stream closed")
}
return errors.New("Next entry already exists in this sequence")
}
c.seq = nextSeq
Expand Down
57 changes: 45 additions & 12 deletions trickle/trickle_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ import (

const CHANGEFEED = "_changes"

// Stream exists but segment doesn't
const StatusNoSegment = 470

// Stream exists but is closed
const StatusStreamClosed = 471

type TrickleServerConfig struct {
// Base HTTP path for the server
BasePath string
Expand All @@ -36,6 +42,14 @@ type TrickleServerConfig struct {

// How often to sweep for idle channels (default 1 minute)
SweepInterval time.Duration

// Return the current time. Used mostly for testing.
Now func() time.Time

// Whether to delay cleanup of channel contents after closing.
// Writes are not allowed after closing, but reads are.
// Closed channels are cleared after IdleTimeout.
DelayCleanup bool
}

type Server struct {
Expand All @@ -50,6 +64,7 @@ type Server struct {

type Stream struct {
mutex sync.RWMutex
config TrickleServerConfig
segments []*Segment
name string
mimeType string
Expand Down Expand Up @@ -97,6 +112,9 @@ func applyDefaults(config *TrickleServerConfig) {
if config.SweepInterval == 0 {
config.SweepInterval = time.Minute
}
if config.Now == nil {
config.Now = time.Now
}
}

func ConfigureServer(config TrickleServerConfig) *Server {
Expand Down Expand Up @@ -160,9 +178,10 @@ func (sm *Server) getOrCreateStream(streamName, mimeType string, isLocal bool) *
if !exists && (isLocal || sm.config.Autocreate) {
stream = &Stream{
segments: make([]*Segment, maxSegmentsPerStream),
config: sm.config,
name: streamName,
mimeType: mimeType,
writeTime: time.Now(),
writeTime: sm.config.Now(),
canReset: !isLocal,
}
sm.streams[streamName] = stream
Expand Down Expand Up @@ -201,7 +220,7 @@ func (sm *Server) sweepIdleChannels() {
sm.mutex.Lock()
streams := slices.Collect(maps.Values(sm.streams))
sm.mutex.Unlock()
now := time.Now()
now := sm.config.Now()
for _, s := range streams {
// skip internal channels for now, eg changefeed
if strings.HasPrefix(s.name, "_") {
Expand All @@ -226,7 +245,9 @@ func (s *Stream) close() {
for _, segment := range s.segments {
segment.close()
}
s.segments = make([]*Segment, maxSegmentsPerStream)
if !s.config.DelayCleanup {
s.segments = make([]*Segment, maxSegmentsPerStream)
}
s.closed = true
}

Expand All @@ -240,7 +261,9 @@ func (sm *Server) closeStream(streamName string) error {

stream.close()
sm.mutex.Lock()
delete(sm.streams, streamName)
if !sm.config.DelayCleanup || sm.config.Now().Sub(stream.writeTime) > sm.config.IdleTimeout {
delete(sm.streams, streamName)
}
sm.mutex.Unlock()
slog.Info("Deleted stream", "streamName", streamName)

Expand Down Expand Up @@ -373,7 +396,14 @@ func (tr *timeoutReader) Close() error {

// Handle post requests for a given index
func (s *Stream) handlePost(w http.ResponseWriter, r *http.Request, idx int) {
segment, _ := s.getForWrite(idx)
segment, _, closed := s.getForWrite(idx)

if closed {
w.Header().Set("Connection", "close") // wakes up gotrickle preconnects
w.Header().Set("Lp-Trickle-Closed", "terminated")
http.Error(w, "Stream closed", http.StatusNotFound)
return
}

// Wrap the request body with the custom timeoutReader so we can send
// provisional headers (keepalives) until receiving the first byte
Expand All @@ -394,7 +424,7 @@ func (s *Stream) handlePost(w http.ResponseWriter, r *http.Request, idx int) {
if totalRead == 0 {
s.mutex.Lock()
s.nextWrite = idx + 1
s.writeTime = time.Now()
s.writeTime = s.config.Now()
s.mutex.Unlock()
}
segment.writeData(buf[:n])
Expand Down Expand Up @@ -438,9 +468,12 @@ func (s *Stream) handlePost(w http.ResponseWriter, r *http.Request, idx int) {
segment.close()
}

func (s *Stream) getForWrite(idx int) (*Segment, bool) {
func (s *Stream) getForWrite(idx int) (*Segment, bool, bool) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.closed {
return nil, false, true
}
if idx == -1 {
idx = s.nextWrite
}
Expand All @@ -453,17 +486,17 @@ func (s *Stream) getForWrite(idx int) (*Segment, bool) {
if reset > 0 {
slog.Warn("Reset an existing segment", "stream", s.name, "idx", idx, "bytes", reset)
}
return segment, reset > 0
return segment, reset > 0, false
}
return segment, !segment.isFresh()
return segment, !segment.isFresh(), false
}
// something exists here but its not the expected segment
// probably an old segment so overwrite it
segment.close()
}
segment := newSegment(idx)
s.segments[segmentPos] = segment
return segment, false
return segment, false, false
}

func (s *Stream) getForRead(idx int) (*Segment, int, bool, bool) {
Expand Down Expand Up @@ -518,7 +551,7 @@ func (s *Stream) handleGet(w http.ResponseWriter, r *http.Request, idx int) {
w.Header().Set("Lp-Trickle-Closed", "terminated")
} else {
// Special status to indicate "stream exists but segment doesn't"
w.WriteHeader(470)
w.WriteHeader(StatusNoSegment)
}
w.Write([]byte("Entry not found"))
return
Expand Down Expand Up @@ -577,7 +610,7 @@ func (s *Stream) handleGet(w http.ResponseWriter, r *http.Request, idx int) {
// other times, the subscriber is slow and the segment falls out of the live window
// send over latest seq so slow clients can grab leading edge
w.Header().Set("Lp-Trickle-Latest", strconv.Itoa(latestSeq))
w.WriteHeader(470)
w.WriteHeader(StatusNoSegment)
}
}
return totalWrites, nil
Expand Down
4 changes: 2 additions & 2 deletions trickle/trickle_subscriber.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func (c *TrickleSubscriber) connect(ctx context.Context) (*http.Response, error)
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
resp.Body.Close() // Ensure we close the body to avoid leaking connections
if resp.StatusCode == http.StatusNotFound || resp.StatusCode == 470 {
if resp.StatusCode == http.StatusNotFound || resp.StatusCode == StatusNoSegment {
return resp, nil
}
return nil, fmt.Errorf("failed GET segment, status code: %d, msg: %s", resp.StatusCode, string(body))
Expand Down Expand Up @@ -247,7 +247,7 @@ func (c *TrickleSubscriber) Read() (*http.Response, error) {
return nil, StreamNotFoundErr
}

if conn.StatusCode == 470 {
if conn.StatusCode == StatusNoSegment {
// stream exists but segment dosn't
return nil, &SequenceNonexistent{Seq: GetSeq(conn), Latest: GetLatest(conn)}
}
Expand Down
97 changes: 97 additions & 0 deletions trickle/trickle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,87 @@ func TestTrickle_Close(t *testing.T) {
require.Error(StreamNotFoundErr, pub2.Write(bytes.NewReader([]byte("bad post"))))
}

func TestTrickle_DelayedClose(t *testing.T) {
require := require.New(t)
mux := http.NewServeMux()
clock := MockClock{}
server := ConfigureServer(TrickleServerConfig{
Mux: mux,
DelayCleanup: true,
Now: clock.Now,
SweepInterval: 5 * time.Millisecond,
})

// Check:
// 1. reading after close
// 2. writing after close

stop := server.Start()
ts := httptest.NewServer(mux)
defer ts.Close()
defer stop()
lp := NewLocalPublisher(server, "testest", "text/plain")
lp.CreateChannel()

// Write to a channel and close it
channelURL := ts.URL + "/testest"
segs := []string{"012", "345", "678", "901"}
pub, err := NewTricklePublisher(channelURL)
require.Nil(err)
for _, c := range segs {
require.Nil(pub.Write(bytes.NewReader([]byte(c))), "pub.Write")
}

// Now close the stream
pub.Close()

//
// NB: we set the "Connection: close" header on the server after close
// If the client is under load and slow this could lead to premature
// termination of the preconnect before the response is finished reading.
// Easily triggered with eg, `go test -race -count 100`
//
// Check repeatedly to run out preconnects, and allow for a couple of
// preconnect failures, but most re-connect attempts should succeed.
gotEOS := 0
for i := 0; i < 10; i++ {
err = pub.Write(bytes.NewReader([]byte("234")))
if errors.Is(err, EOS) {
gotEOS++
}
}
require.GreaterOrEqual(gotEOS, 8, "did not have enough EOS writes")

// Check reads
sub, err := NewTrickleSubscriber(subConfig(channelURL))
require.Nil(err)
sub.SetSeq(0)
for _, s := range segs {
resp, err := sub.Read()
require.Nil(err, "sub.Read")
buf, err := io.ReadAll(resp.Body)
require.Nil(err)
require.Equal(s, string(buf))
resp.Body.Close()
}
// check for EOS multiple times to flush out read preconnects
for i := 0; i < 5; i++ {
_, err = sub.Read()
require.ErrorIs(err, EOS)
}

// check for sweep
clock.Set(time.Now())
time.Sleep(20 * time.Millisecond)
pub, err = NewTricklePublisher(channelURL)
require.Nil(err)
require.ErrorIs(pub.Write(bytes.NewReader([]byte("invalid"))), StreamNotFoundErr)
sub, err = NewTrickleSubscriber(subConfig(channelURL))
require.Nil(err)
_, err = sub.Read()
require.ErrorIs(err, StreamNotFoundErr)
}

func TestTrickle_SetSeq(t *testing.T) {
require := require.New(t)
mux := http.NewServeMux()
Expand Down Expand Up @@ -415,3 +496,19 @@ func makeServer(t *testing.T) (*require.Assertions, string) {
func subConfig(url string) TrickleSubscriberConfig {
return TrickleSubscriberConfig{URL: url}
}

type MockClock struct {
mu sync.Mutex
now time.Time
}

func (mc *MockClock) Now() time.Time {
mc.mu.Lock()
defer mc.mu.Unlock()
return mc.now
}
func (mc *MockClock) Set(now time.Time) {
mc.mu.Lock()
defer mc.mu.Unlock()
mc.now = now
}
Loading