mirror of
https://github.com/Luzifer/go_helpers.git
synced 2024-12-24 13:01:21 +00:00
Add ThrottledReader
as io-helper
Signed-off-by: Knut Ahlers <knut@ahlers.me>
This commit is contained in:
parent
b1fa066a12
commit
7179a1859b
2 changed files with 97 additions and 0 deletions
64
io/throttledReader.go
Normal file
64
io/throttledReader.go
Normal file
|
@ -0,0 +1,64 @@
|
|||
// Package io contains helpers for I/O tasks
|
||||
package io
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ThrottledReader implements a reader imposing a rate limit to the
|
||||
// reading side to i.e. limit downloads, limit I/O on a filesystem, …
|
||||
// The reads will burst and then wait until the rate "calmed" to the
|
||||
// desired rate.
|
||||
type ThrottledReader struct {
|
||||
startRead time.Time
|
||||
totalReadBytes uint64
|
||||
readRateBpns float64
|
||||
|
||||
next io.Reader
|
||||
}
|
||||
|
||||
// NewThrottledReader creates a reader with next as its underlying reader and
|
||||
// rate as its throttle rate in Bytes / Second
|
||||
func NewThrottledReader(next io.Reader, rate float64) *ThrottledReader {
|
||||
return &ThrottledReader{next: next, readRateBpns: rate / float64(time.Second)}
|
||||
}
|
||||
|
||||
// Read implements the io.Reader interface
|
||||
func (t *ThrottledReader) Read(p []byte) (n int, err error) {
|
||||
if t.startRead.IsZero() {
|
||||
t.startRead = time.Now()
|
||||
}
|
||||
|
||||
// First read is for free
|
||||
n, err = t.next.Read(p)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return n, io.EOF
|
||||
}
|
||||
return n, fmt.Errorf("reading from next: %w", err)
|
||||
}
|
||||
|
||||
// Count the data
|
||||
t.totalReadBytes += uint64(n)
|
||||
|
||||
// Now lets see how long we need to wait
|
||||
var (
|
||||
currentRate float64
|
||||
timePassedNS = int64(time.Since(t.startRead))
|
||||
)
|
||||
|
||||
if timePassedNS > 0 {
|
||||
currentRate = float64(t.totalReadBytes) / float64(timePassedNS)
|
||||
}
|
||||
|
||||
if currentRate > t.readRateBpns {
|
||||
timeToWait := int64(float64(t.totalReadBytes)/t.readRateBpns - float64(timePassedNS))
|
||||
time.Sleep(time.Duration(timeToWait))
|
||||
}
|
||||
|
||||
// Waited long enough, rate is fine again, return
|
||||
return n, nil
|
||||
}
|
33
io/throttledReader_test.go
Normal file
33
io/throttledReader_test.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package io
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestThrottledReader(t *testing.T) {
|
||||
var (
|
||||
testSize = 10 * 1024 * 1024 // 10 Mi
|
||||
testLimit = 20 * 1024 * 1024 // 20 Mi/s
|
||||
|
||||
// 20Mi/s on 10M = 500ms exec time
|
||||
expectedTimeMillisecs = float64(testSize) / float64(testLimit) * 1000
|
||||
tolerance = 50 // Millisecs
|
||||
)
|
||||
|
||||
lr := io.LimitReader(rand.Reader, int64(testSize))
|
||||
var tr io.Reader = NewThrottledReader(lr, float64(testLimit))
|
||||
|
||||
start := time.Now()
|
||||
n, err := io.Copy(io.Discard, tr)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, int64(testSize), n)
|
||||
assert.Greater(t, time.Since(start)/time.Millisecond, time.Duration(expectedTimeMillisecs-float64(tolerance)))
|
||||
assert.Less(t, time.Since(start)/time.Millisecond, time.Duration(expectedTimeMillisecs+float64(tolerance)))
|
||||
}
|
Loading…
Reference in a new issue