diff --git a/speedometer/speedometer.go b/speedometer/speedometer.go new file mode 100644 index 0000000..0e820af --- /dev/null +++ b/speedometer/speedometer.go @@ -0,0 +1,244 @@ +package util + +import ( + "errors" + "fmt" + "io" + "sync" + "sync/atomic" + "time" +) + +var ErrLimitReached = errors.New("limit reached") + +// Speedometer is an io.Writer wrapper that will limit the rate at which data is written to the underlying target. +// +// It is safe for concurrent use, but writers will block when slowed down. +// +// Optionally, it can be given; +// +// - a capacity, which will cause it to return an error if the capacity is exceeded. +// +// - a speed limit, causing slow downs of data written to the underlying writer if the speed limit is exceeded. +type Speedometer struct { + ceiling int64 + speedLimit *SpeedLimit + internal atomics + w io.Writer +} + +type atomics struct { + count *atomic.Int64 + closed *atomic.Bool + start *sync.Once + stop *sync.Once + birth *atomic.Pointer[time.Time] + duration *atomic.Pointer[time.Duration] + slow *atomic.Bool +} + +func newAtomics() atomics { + manhattan := atomics{ + count: new(atomic.Int64), + closed: new(atomic.Bool), + start: new(sync.Once), + stop: new(sync.Once), + birth: new(atomic.Pointer[time.Time]), + duration: new(atomic.Pointer[time.Duration]), + slow: new(atomic.Bool), + } + manhattan.birth.Store(&time.Time{}) + manhattan.closed.Store(false) + manhattan.count.Store(0) + return manhattan +} + +// SpeedLimit is used to limit the rate at which data is written to the underlying writer. +type SpeedLimit struct { + // Burst is the number of bytes that can be written to the underlying writer per Frame. + Burst int64 + // Frame is the duration of the frame in which Burst can be written to the underlying writer. + Frame time.Duration + // CheckEveryBytes is the number of bytes written before checking if the speed limit has been exceeded. + CheckEveryBytes int64 + // Delay is the duration to delay writing if the speed limit has been exceeded during a Write call. (blocking) + Delay time.Duration +} + +func NewBytesPerSecondLimit(bytes int64) *SpeedLimit { + return &SpeedLimit{ + Burst: bytes, + Frame: time.Second, + CheckEveryBytes: 1, + Delay: 100 * time.Millisecond, + } +} + +const fallbackDelay = 100 + +func regulateSpeedLimit(speedLimit *SpeedLimit) (*SpeedLimit, error) { + if speedLimit.Burst <= 0 || speedLimit.Frame <= 0 { + return nil, errors.New("invalid speed limit") + } + if speedLimit.CheckEveryBytes <= 0 { + speedLimit.CheckEveryBytes = speedLimit.Burst + } + if speedLimit.Delay <= 0 { + speedLimit.Delay = fallbackDelay * time.Millisecond + } + return speedLimit, nil +} + +func newSpeedometer(w io.Writer, speedLimit *SpeedLimit, ceiling int64) (*Speedometer, error) { + if w == nil { + return nil, errors.New("writer cannot be nil") + } + var err error + if speedLimit != nil { + if speedLimit, err = regulateSpeedLimit(speedLimit); err != nil { + return nil, err + } + } + + return &Speedometer{ + w: w, + ceiling: ceiling, + speedLimit: speedLimit, + internal: newAtomics(), + }, nil +} + +// NewSpeedometer creates a new Speedometer that wraps the given io.Writer. +// It will not limit the rate at which data is written to the underlying writer, it only measures it. +func NewSpeedometer(w io.Writer) (*Speedometer, error) { + return newSpeedometer(w, nil, -1) +} + +// NewLimitedSpeedometer creates a new Speedometer that wraps the given io.Writer. +// If the speed limit is exceeded, writes to the underlying writer will be limited. +// See SpeedLimit for more information. +func NewLimitedSpeedometer(w io.Writer, speedLimit *SpeedLimit) (*Speedometer, error) { + return newSpeedometer(w, speedLimit, -1) +} + +// NewCappedSpeedometer creates a new Speedometer that wraps the given io.Writer. +// If len(written) bytes exceeds cap, writes to the underlying writer will be ceased permanently for the Speedometer. +func NewCappedSpeedometer(w io.Writer, capacity int64) (*Speedometer, error) { + return newSpeedometer(w, nil, capacity) +} + +// NewCappedLimitedSpeedometer creates a new Speedometer that wraps the given io.Writer. +// It is a combination of NewLimitedSpeedometer and NewCappedSpeedometer. +func NewCappedLimitedSpeedometer(w io.Writer, speedLimit *SpeedLimit, capacity int64) (*Speedometer, error) { + return newSpeedometer(w, speedLimit, capacity) +} + +func (s *Speedometer) increment(inc int64) (int, error) { + if s.internal.closed.Load() { + return 0, io.ErrClosedPipe + } + var err error + if s.ceiling > 0 && s.Total()+inc > s.ceiling { + _ = s.Close() + err = ErrLimitReached + inc = s.ceiling - s.Total() + } + s.internal.count.Add(inc) + return int(inc), err +} + +// Running returns true if the Speedometer is still running. +func (s *Speedometer) Running() bool { + return !s.internal.closed.Load() +} + +// Total returns the total number of bytes written to the underlying writer. +func (s *Speedometer) Total() int64 { + return s.internal.count.Load() +} + +// Close stops the Speedometer. No additional writes will be accepted. +func (s *Speedometer) Close() error { + if s.internal.closed.Load() { + return io.ErrClosedPipe + } + s.internal.stop.Do(func() { + s.internal.closed.Store(true) + stopped := time.Now() + birth := s.internal.birth.Load() + duration := stopped.Sub(*birth) + s.internal.duration.Store(&duration) + }) + return nil +} + +/*func (s *Speedometer) IsSlow() bool { + return s.internal.slow.Load() +}*/ + +// Rate returns the rate at which data is being written to the underlying writer per second. +func (s *Speedometer) Rate() float64 { + if s.internal.closed.Load() { + return float64(s.Total()) / s.internal.duration.Load().Seconds() + } + return float64(s.Total()) / time.Since(*s.internal.birth.Load()).Seconds() +} + +func (s *Speedometer) slowDown() error { + switch { + case s.speedLimit == nil: + return nil + case s.speedLimit.Burst <= 0 || s.speedLimit.Frame <= 0, + s.speedLimit.CheckEveryBytes <= 0, s.speedLimit.Delay <= 0: + return errors.New("invalid speed limit") + default: + // + } + if s.Total()%int64(s.speedLimit.CheckEveryBytes) != 0 { + return nil + } + s.internal.slow.Store(true) + for s.Rate() > float64(s.speedLimit.Burst)/s.speedLimit.Frame.Seconds() { + time.Sleep(s.speedLimit.Delay) + } + s.internal.slow.Store(false) + return nil +} + +// Write writes p to the underlying writer, following all defined speed limits. +func (s *Speedometer) Write(p []byte) (n int, err error) { + if s.internal.closed.Load() { + return 0, io.ErrClosedPipe + } + s.internal.start.Do(func() { + now := time.Now() + s.internal.birth.Store(&now) + }) + + // if no speed limit, just write and record + if s.speedLimit == nil { + n, err = s.w.Write(p) + if err != nil { + return n, fmt.Errorf("error writing to underlying writer: %w", err) + } + return s.increment(int64(len(p))) + } + + var ( + wErr error + accepted int + ) + accepted, wErr = s.increment(int64(len(p))) + + if wErr != nil { + return 0, fmt.Errorf("error incrementing: %w", wErr) + } + + _ = s.slowDown() + + var iErr error + if n, iErr = s.w.Write(p[:accepted]); iErr != nil { + return n, fmt.Errorf("error writing to underlying writer: %w", iErr) + } + return +} diff --git a/speedometer/speedometer_test.go b/speedometer/speedometer_test.go new file mode 100644 index 0000000..cc5e4d3 --- /dev/null +++ b/speedometer/speedometer_test.go @@ -0,0 +1,452 @@ +package util + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "sync" + "sync/atomic" + "testing" + "time" +) + +type testWriter struct { + t *testing.T + total int64 +} + +func (w *testWriter) Write(p []byte) (n int, err error) { + atomic.AddInt64(&w.total, int64(len(p))) + return len(p), nil +} + +func writeStuff(t *testing.T, target io.Writer, count int) error { + t.Helper() + write := func() error { + _, err := target.Write([]byte("a")) + if err != nil { + return fmt.Errorf("error writing: %w", err) + } + return nil + } + + if count < 0 { + var err error + for err = write(); err == nil; err = write() { + time.Sleep(5 * time.Millisecond) + } + return err + } + for i := 0; i < count; i++ { + if err := write(); err != nil { + return err + } + } + return nil +} + +//nolint:funlen +func Test_Speedometer(t *testing.T) { + t.Parallel() + type results struct { + total int64 + written int + rate float64 + err error + } + + isIt := func(want, have results) { + t.Helper() + if have.total != want.total { + t.Errorf("total: want %d, have %d", want.total, have.total) + } + if have.written != want.written { + t.Errorf("written: want %d, have %d", want.written, have.written) + } + if have.rate != want.rate { + t.Errorf("rate: want %f, have %f", want.rate, have.rate) + } + if !errors.Is(have.err, want.err) { + t.Errorf("wantErr: want %v, have %v", want.err, have.err) + } + } + + var ( + errChan = make(chan error, 10) + ) + + t.Run("EarlyClose", func(t *testing.T) { + var ( + err error + cnt int + ) + t.Parallel() + sp, nerr := NewSpeedometer(&testWriter{t: t}) + if nerr != nil { + t.Errorf("unexpected error: %v", nerr) + } + go func() { + errChan <- writeStuff(t, sp, -1) + }() + time.Sleep(1 * time.Second) + if closeErr := sp.Close(); closeErr != nil { + t.Errorf("wantErr: want %v, have %v", nil, closeErr) + } + err = <-errChan + if !errors.Is(err, io.ErrClosedPipe) { + t.Errorf("wantErr: want %v, have %v", io.ErrClosedPipe, err) + } + cnt, err = sp.Write([]byte("a")) + isIt(results{err: io.ErrClosedPipe, written: 0}, results{err: err, written: cnt}) + }) + + t.Run("Basic", func(t *testing.T) { + var ( + err error + cnt int + ) + t.Parallel() + sp, nerr := NewSpeedometer(&testWriter{t: t}) + if nerr != nil { + t.Errorf("unexpected error: %v", nerr) + } + cnt, err = sp.Write([]byte("a")) + isIt(results{err: nil, written: 1, total: 1}, results{err: err, written: cnt, total: sp.Total()}) + cnt, err = sp.Write([]byte("aa")) + isIt(results{err: nil, written: 2, total: 3}, results{err: err, written: cnt, total: sp.Total()}) + cnt, err = sp.Write([]byte("a")) + isIt(results{err: nil, written: 1, total: 4}, results{err: err, written: cnt, total: sp.Total()}) + cnt, err = sp.Write([]byte("a")) + isIt(results{err: nil, written: 1, total: 5}, results{err: err, written: cnt, total: sp.Total()}) + }) + + t.Run("ConcurrentWrites", func(t *testing.T) { + var ( + err error + ) + + count := int64(0) + sp, nerr := NewSpeedometer(&testWriter{t: t}) + if nerr != nil { + t.Errorf("unexpected error: %v", nerr) + } + wg := &sync.WaitGroup{} + wg.Add(100) + for i := 0; i < 100; i++ { + go func() { + var counted int + var gerr error + counted, gerr = sp.Write([]byte("a")) + if gerr != nil { + t.Errorf("unexpected error: %v", err) + } + atomic.AddInt64(&count, int64(counted)) + wg.Done() + }() + } + wg.Wait() + isIt(results{err: nil, written: 100, total: 100}, + results{err: err, written: int(atomic.LoadInt64(&count)), total: sp.Total()}) + }) + + t.Run("GottaGoFast", func(t *testing.T) { + t.Parallel() + var ( + err error + ) + sp, nerr := NewSpeedometer(&testWriter{t: t}) + if nerr != nil { + t.Errorf("unexpected error: %v", nerr) + } + go func() { + errChan <- writeStuff(t, sp, -1) + }() + var count = 0 + for sp.Running() { + select { + case err = <-errChan: + if !errors.Is(err, io.ErrClosedPipe) { + t.Errorf("unexpected error: %v", err) + } else { + if count < 5 { + t.Errorf("too few iterations: %d", count) + } + t.Logf("final rate: %v per second", sp.Rate()) + } + default: + if count > 5 { + _ = sp.Close() + } + time.Sleep(100 * time.Millisecond) + t.Logf("rate: %v per second", sp.Rate()) + count++ + } + } + }) + + // test limiter with speedlimit + t.Run("CantGoFast", func(t *testing.T) { + t.Parallel() + t.Run("10BytesASecond", func(t *testing.T) { + t.Parallel() + var ( + err error + ) + sp, nerr := NewLimitedSpeedometer(&testWriter{t: t}, &SpeedLimit{ + Burst: 10, + Frame: time.Second, + CheckEveryBytes: 1, + Delay: 100 * time.Millisecond, + }) + if nerr != nil { + t.Errorf("unexpected error: %v", nerr) + } + for i := 0; i < 15; i++ { + if _, err = sp.Write([]byte("a")); err != nil { + t.Errorf("unexpected error: %v", err) + } + /*if sp.IsSlow() { + t.Errorf("unexpected slow state") + }*/ + t.Logf("rate: %v per second", sp.Rate()) + if sp.Rate() > 10 { + t.Errorf("speeding in a school zone (expected under %d): %v", sp.speedLimit.Burst, sp.Rate()) + } + } + }) + + t.Run("1000BytesPer5SecondsMeasuredEvery5000Bytes", func(t *testing.T) { + t.Parallel() + var ( + err error + ) + sp, nerr := NewLimitedSpeedometer(&testWriter{t: t}, &SpeedLimit{ + Burst: 1000, + Frame: 2 * time.Second, + CheckEveryBytes: 5000, + Delay: 500 * time.Millisecond, + }) + + if nerr != nil { + t.Errorf("unexpected error: %v", nerr) + } + + for i := 0; i < 4999; i++ { + if _, err = sp.Write([]byte("a")); err != nil { + t.Errorf("unexpected error: %v", err) + } + if i%1000 == 0 { + t.Logf("rate: %v per second", sp.Rate()) + } + if sp.Rate() < 1000 { + t.Errorf("shouldn't have slowed down yet (expected over %d): %v", sp.speedLimit.Burst, sp.Rate()) + } + } + if _, err = sp.Write([]byte("a")); err != nil { + t.Errorf("unexpected error: %v", err) + } + for i := 0; i < 10; i++ { + if _, err = sp.Write([]byte("a")); err != nil { + t.Errorf("unexpected error: %v", err) + } + t.Logf("rate: %v per second", sp.Rate()) + if sp.Rate() > 1000 { + t.Errorf("speeding in a school zone (expected under %d): %v", sp.speedLimit.Burst, sp.Rate()) + } + } + }) + }) + + // test capped speedometer + t.Run("OnlyALittle", func(t *testing.T) { + t.Parallel() + var ( + err error + ) + sp, nerr := NewCappedSpeedometer(&testWriter{t: t}, 1024) + if nerr != nil { + t.Errorf("unexpected error: %v", nerr) + } + for i := 0; i < 1024; i++ { + if _, err = sp.Write([]byte("a")); err != nil { + t.Errorf("unexpected error: %v", err) + } + if sp.Total() > 1024 { + t.Errorf("shouldn't have written more than 1024 bytes") + } + } + if _, err = sp.Write([]byte("a")); err == nil { + t.Errorf("expected error when writing over capacity") + } + }) + + t.Run("SynSynAckAck", func(t *testing.T) { + t.Parallel() + var ( + server net.Listener + err error + ) + //goland:noinspection GoCommentLeadingSpace + if server, err = net.Listen("tcp", ":8080"); err != nil { // #nosec:G102 - this is a unit test. + t.Fatalf("Failed to start server: %v", err) + } + defer func(server net.Listener) { + if cErr := server.Close(); cErr != nil { + t.Errorf("Failed to close server: %v", err) + } + }(server) + + go func() { + var ( + conn net.Conn + aErr error + ) + if conn, aErr = server.Accept(); aErr != nil { + t.Errorf("Failed to accept connection: %v", err) + } + + t.Logf("Accepted connection from %s", conn.RemoteAddr().String()) + + defer func(conn net.Conn) { + if cErr := conn.Close(); cErr != nil { + t.Errorf("Failed to close connection: %v", err) + } + }(conn) + + speedLimit := &SpeedLimit{ + Burst: 512, + Frame: time.Second, + CheckEveryBytes: 1, + Delay: 10 * time.Millisecond, + } + + var ( + speedometer *Speedometer + sErr error + ) + if speedometer, sErr = NewCappedLimitedSpeedometer(conn, speedLimit, 4096); sErr != nil { + t.Errorf("Failed to create speedometer: %v", sErr) + } + + buf := make([]byte, 1024) + for i := range buf { + targ := byte('E') + if i%2 == 0 { + targ = byte('e') + } + buf[i] = targ + } + for { + n, wErr := speedometer.Write(buf) + switch { + case errors.Is(wErr, io.EOF), errors.Is(wErr, ErrLimitReached): + return + case wErr != nil: + t.Errorf("Failed to write: %v", wErr) + case n != len(buf): + t.Errorf("Failed to write all bytes: %d", n) + default: + t.Logf("Wrote %d bytes", n) + } + } + }() + + var ( + client net.Conn + aErr error + ) + + if client, aErr = net.Dial("tcp", "localhost:8080"); aErr != nil { + t.Fatalf("Failed to connect to server: %v", err) + } + + defer func(client net.Conn) { + if clErr := client.Close(); clErr != nil { + t.Errorf("Failed to close client: %v", err) + } + }(client) + + buf := &bytes.Buffer{} + startTime := time.Now() + n, cpErr := io.Copy(buf, client) + if cpErr != nil { + t.Errorf("Failed to copy: %v", cpErr) + } + + duration := time.Since(startTime) + if buf.Len() == 0 || n == 0 { + t.Fatalf("No data received") + } + + rate := measureRate(t, n, duration) + + if rate > 512.0 { + t.Fatalf("Rate exceeded: got %f, expected <= 100.0", rate) + } + }) +} + +type badWrites struct { +} + +func (bw badWrites) Write(_ []byte) (int, error) { + return 0, io.EOF +} + +func TestImprobableEdgeCasesForCoverage(t *testing.T) { + t.Parallel() + sp, _ := NewSpeedometer(io.Discard) + sp.speedLimit = nil + if sp.slowDown() != nil { + t.Fatal("should have received no error when running slowdown with nil speedlimit") + } + sp, _ = NewLimitedSpeedometer(io.Discard, NewBytesPerSecondLimit(5)) + sp.speedLimit.Burst = 0 + if sp.slowDown() == nil { + t.Fatal("should have received error when running slowdown with invalid speedlimit") + } + sp, _ = NewLimitedSpeedometer(badWrites{}, NewBytesPerSecondLimit(5)) + if _, e := sp.Write([]byte("yeet")); !errors.Is(e, io.EOF) { + t.Errorf("wrong error from underlying writer err passdown: %v", e) + } + sp.speedLimit = nil + if _, e := sp.Write([]byte("yeet")); !errors.Is(e, io.EOF) { + t.Errorf("wrong error from underlying writer err passdown: %v", e) + } + if e := sp.Close(); e != nil { + t.Fatal("close err not nil") + } + if e := sp.Close(); !errors.Is(e, io.ErrClosedPipe) { + t.Errorf("wrong error from already closed speedo: %v", e) + } + if _, e := sp.increment(1); !errors.Is(e, io.ErrClosedPipe) { + t.Errorf("wrong error from already closed speedo: %v", e) + } + if _, err := NewLimitedSpeedometer(nil, nil); err == nil { + t.Fatal("should have received error when creating invalid speedo") + } + if _, err := NewLimitedSpeedometer(io.Discard, &SpeedLimit{}); err == nil { + t.Fatal("should have received error when creating invalid speedo") + } + sl := NewBytesPerSecondLimit(5) + sl.CheckEveryBytes = 0 + sl.Delay = 0 + var err error + if sp, err = NewLimitedSpeedometer(io.Discard, sl); err != nil { + t.Fatal("should have received no error when creating iffy speedo") + } + if sp.speedLimit.CheckEveryBytes != 5 { + t.Fatal("speed limit regularization failed") + } + if sp.speedLimit.Delay != time.Duration(100)*time.Millisecond { + t.Fatal("speed limit regularization failed") + } + +} + +func measureRate(t *testing.T, received int64, duration time.Duration) float64 { + t.Helper() + return float64(received) / duration.Seconds() +}