1
0
Fork 0
mirror of https://github.com/Luzifer/go_helpers.git synced 2024-10-18 06:14:21 +00:00

Add ThrottledReader as io-helper

Signed-off-by: Knut Ahlers <knut@ahlers.me>
This commit is contained in:
Knut Ahlers 2024-03-06 17:16:34 +01:00
parent b1fa066a12
commit 7179a1859b
Signed by: luzifer
SSH key fingerprint: SHA256:/xtE5lCgiRDQr8SLxHMS92ZBlACmATUmF1crK16Ks4E
2 changed files with 97 additions and 0 deletions

64
io/throttledReader.go Normal file
View file

@ -0,0 +1,64 @@
// Package io contains helpers for I/O tasks
package io
import (
"errors"
"fmt"
"io"
"time"
)
// ThrottledReader implements a reader imposing a rate limit to the
// reading side to i.e. limit downloads, limit I/O on a filesystem, …
// The reads will burst and then wait until the rate "calmed" to the
// desired rate.
type ThrottledReader struct {
startRead time.Time
totalReadBytes uint64
readRateBpns float64
next io.Reader
}
// NewThrottledReader creates a reader with next as its underlying reader and
// rate as its throttle rate in Bytes / Second
func NewThrottledReader(next io.Reader, rate float64) *ThrottledReader {
return &ThrottledReader{next: next, readRateBpns: rate / float64(time.Second)}
}
// Read implements the io.Reader interface
func (t *ThrottledReader) Read(p []byte) (n int, err error) {
if t.startRead.IsZero() {
t.startRead = time.Now()
}
// First read is for free
n, err = t.next.Read(p)
if err != nil {
if errors.Is(err, io.EOF) {
return n, io.EOF
}
return n, fmt.Errorf("reading from next: %w", err)
}
// Count the data
t.totalReadBytes += uint64(n)
// Now lets see how long we need to wait
var (
currentRate float64
timePassedNS = int64(time.Since(t.startRead))
)
if timePassedNS > 0 {
currentRate = float64(t.totalReadBytes) / float64(timePassedNS)
}
if currentRate > t.readRateBpns {
timeToWait := int64(float64(t.totalReadBytes)/t.readRateBpns - float64(timePassedNS))
time.Sleep(time.Duration(timeToWait))
}
// Waited long enough, rate is fine again, return
return n, nil
}

View file

@ -0,0 +1,33 @@
package io
import (
"crypto/rand"
"io"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestThrottledReader(t *testing.T) {
var (
testSize = 10 * 1024 * 1024 // 10 Mi
testLimit = 20 * 1024 * 1024 // 20 Mi/s
// 20Mi/s on 10M = 500ms exec time
expectedTimeMillisecs = float64(testSize) / float64(testLimit) * 1000
tolerance = 50 // Millisecs
)
lr := io.LimitReader(rand.Reader, int64(testSize))
var tr io.Reader = NewThrottledReader(lr, float64(testLimit))
start := time.Now()
n, err := io.Copy(io.Discard, tr)
require.NoError(t, err)
assert.Equal(t, int64(testSize), n)
assert.Greater(t, time.Since(start)/time.Millisecond, time.Duration(expectedTimeMillisecs-float64(tolerance)))
assert.Less(t, time.Since(start)/time.Millisecond, time.Duration(expectedTimeMillisecs+float64(tolerance)))
}