From 7179a1859b84a4a8cd9a1ddc58e1a3ed7d6368a9 Mon Sep 17 00:00:00 2001 From: Knut Ahlers Date: Wed, 6 Mar 2024 17:16:34 +0100 Subject: [PATCH] Add `ThrottledReader` as io-helper Signed-off-by: Knut Ahlers --- io/throttledReader.go | 64 ++++++++++++++++++++++++++++++++++++++ io/throttledReader_test.go | 33 ++++++++++++++++++++ 2 files changed, 97 insertions(+) create mode 100644 io/throttledReader.go create mode 100644 io/throttledReader_test.go diff --git a/io/throttledReader.go b/io/throttledReader.go new file mode 100644 index 0000000..3a34077 --- /dev/null +++ b/io/throttledReader.go @@ -0,0 +1,64 @@ +// Package io contains helpers for I/O tasks +package io + +import ( + "errors" + "fmt" + "io" + "time" +) + +// ThrottledReader implements a reader imposing a rate limit to the +// reading side to i.e. limit downloads, limit I/O on a filesystem, … +// The reads will burst and then wait until the rate "calmed" to the +// desired rate. +type ThrottledReader struct { + startRead time.Time + totalReadBytes uint64 + readRateBpns float64 + + next io.Reader +} + +// NewThrottledReader creates a reader with next as its underlying reader and +// rate as its throttle rate in Bytes / Second +func NewThrottledReader(next io.Reader, rate float64) *ThrottledReader { + return &ThrottledReader{next: next, readRateBpns: rate / float64(time.Second)} +} + +// Read implements the io.Reader interface +func (t *ThrottledReader) Read(p []byte) (n int, err error) { + if t.startRead.IsZero() { + t.startRead = time.Now() + } + + // First read is for free + n, err = t.next.Read(p) + if err != nil { + if errors.Is(err, io.EOF) { + return n, io.EOF + } + return n, fmt.Errorf("reading from next: %w", err) + } + + // Count the data + t.totalReadBytes += uint64(n) + + // Now lets see how long we need to wait + var ( + currentRate float64 + timePassedNS = int64(time.Since(t.startRead)) + ) + + if timePassedNS > 0 { + currentRate = float64(t.totalReadBytes) / float64(timePassedNS) + } + + if currentRate > t.readRateBpns { + timeToWait := int64(float64(t.totalReadBytes)/t.readRateBpns - float64(timePassedNS)) + time.Sleep(time.Duration(timeToWait)) + } + + // Waited long enough, rate is fine again, return + return n, nil +} diff --git a/io/throttledReader_test.go b/io/throttledReader_test.go new file mode 100644 index 0000000..5490a02 --- /dev/null +++ b/io/throttledReader_test.go @@ -0,0 +1,33 @@ +package io + +import ( + "crypto/rand" + "io" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestThrottledReader(t *testing.T) { + var ( + testSize = 10 * 1024 * 1024 // 10 Mi + testLimit = 20 * 1024 * 1024 // 20 Mi/s + + // 20Mi/s on 10M = 500ms exec time + expectedTimeMillisecs = float64(testSize) / float64(testLimit) * 1000 + tolerance = 50 // Millisecs + ) + + lr := io.LimitReader(rand.Reader, int64(testSize)) + var tr io.Reader = NewThrottledReader(lr, float64(testLimit)) + + start := time.Now() + n, err := io.Copy(io.Discard, tr) + require.NoError(t, err) + + assert.Equal(t, int64(testSize), n) + assert.Greater(t, time.Since(start)/time.Millisecond, time.Duration(expectedTimeMillisecs-float64(tolerance))) + assert.Less(t, time.Since(start)/time.Millisecond, time.Duration(expectedTimeMillisecs+float64(tolerance))) +}