mirror of
https://github.com/Luzifer/go_helpers.git
synced 2024-12-25 05:21:20 +00:00
Add ThrottledReader
as io-helper
Signed-off-by: Knut Ahlers <knut@ahlers.me>
This commit is contained in:
parent
b1fa066a12
commit
7179a1859b
2 changed files with 97 additions and 0 deletions
64
io/throttledReader.go
Normal file
64
io/throttledReader.go
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
// Package io contains helpers for I/O tasks
|
||||||
|
package io
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ThrottledReader implements a reader imposing a rate limit to the
|
||||||
|
// reading side to i.e. limit downloads, limit I/O on a filesystem, …
|
||||||
|
// The reads will burst and then wait until the rate "calmed" to the
|
||||||
|
// desired rate.
|
||||||
|
type ThrottledReader struct {
|
||||||
|
startRead time.Time
|
||||||
|
totalReadBytes uint64
|
||||||
|
readRateBpns float64
|
||||||
|
|
||||||
|
next io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewThrottledReader creates a reader with next as its underlying reader and
|
||||||
|
// rate as its throttle rate in Bytes / Second
|
||||||
|
func NewThrottledReader(next io.Reader, rate float64) *ThrottledReader {
|
||||||
|
return &ThrottledReader{next: next, readRateBpns: rate / float64(time.Second)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read implements the io.Reader interface
|
||||||
|
func (t *ThrottledReader) Read(p []byte) (n int, err error) {
|
||||||
|
if t.startRead.IsZero() {
|
||||||
|
t.startRead = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
// First read is for free
|
||||||
|
n, err = t.next.Read(p)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
return n, io.EOF
|
||||||
|
}
|
||||||
|
return n, fmt.Errorf("reading from next: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count the data
|
||||||
|
t.totalReadBytes += uint64(n)
|
||||||
|
|
||||||
|
// Now lets see how long we need to wait
|
||||||
|
var (
|
||||||
|
currentRate float64
|
||||||
|
timePassedNS = int64(time.Since(t.startRead))
|
||||||
|
)
|
||||||
|
|
||||||
|
if timePassedNS > 0 {
|
||||||
|
currentRate = float64(t.totalReadBytes) / float64(timePassedNS)
|
||||||
|
}
|
||||||
|
|
||||||
|
if currentRate > t.readRateBpns {
|
||||||
|
timeToWait := int64(float64(t.totalReadBytes)/t.readRateBpns - float64(timePassedNS))
|
||||||
|
time.Sleep(time.Duration(timeToWait))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Waited long enough, rate is fine again, return
|
||||||
|
return n, nil
|
||||||
|
}
|
33
io/throttledReader_test.go
Normal file
33
io/throttledReader_test.go
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
package io
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestThrottledReader(t *testing.T) {
|
||||||
|
var (
|
||||||
|
testSize = 10 * 1024 * 1024 // 10 Mi
|
||||||
|
testLimit = 20 * 1024 * 1024 // 20 Mi/s
|
||||||
|
|
||||||
|
// 20Mi/s on 10M = 500ms exec time
|
||||||
|
expectedTimeMillisecs = float64(testSize) / float64(testLimit) * 1000
|
||||||
|
tolerance = 50 // Millisecs
|
||||||
|
)
|
||||||
|
|
||||||
|
lr := io.LimitReader(rand.Reader, int64(testSize))
|
||||||
|
var tr io.Reader = NewThrottledReader(lr, float64(testLimit))
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
n, err := io.Copy(io.Discard, tr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, int64(testSize), n)
|
||||||
|
assert.Greater(t, time.Since(start)/time.Millisecond, time.Duration(expectedTimeMillisecs-float64(tolerance)))
|
||||||
|
assert.Less(t, time.Since(start)/time.Millisecond, time.Duration(expectedTimeMillisecs+float64(tolerance)))
|
||||||
|
}
|
Loading…
Reference in a new issue