1
0
mirror of https://github.com/Luzifer/vault-openvpn.git synced 2024-09-20 10:02:56 +00:00
vault-openvpn/vendor/github.com/sethgrid/pester/main_test.go

466 lines
9.7 KiB
Go
Raw Normal View History

package pester
import (
"fmt"
"log"
"net"
"runtime"
"strconv"
"strings"
"sync"
"testing"
"time"
"net/http"
"net/http/cookiejar"
"errors"
)
// TestConcurrentRequests issues a GET with Concurrency set to 3 against a
// localhost URL that is presumed to be unreachable, requires the call to fail,
// and then checks that the number of logged errors equals
// Concurrency * MaxRetries.
func TestConcurrentRequests(t *testing.T) {
	t.Parallel()

	client := New()
	client.KeepLog = true
	client.Concurrency = 3

	// Presumably nothing listens on this port, so every attempt should error.
	const target = "http://localhost:9000/foo"
	if _, err := client.Get(target); err == nil {
		t.Fatal("expected to get an error")
	}
	client.Wait()

	// Dump the collected log so a failure below can be diagnosed.
	t.Log("\n", client.LogString())

	want := client.Concurrency * client.MaxRetries
	if got := client.LogErrCount(); got != want {
		t.Errorf("got %d attempts, want %d", got, want)
	}
}
// TestConcurrent2Retry0 configures Concurrency = 2 and MaxRetries = 0, issues
// a GET against a localhost URL that is presumed to be unreachable, and checks
// that the number of logged errors equals the configured concurrency.
func TestConcurrent2Retry0(t *testing.T) {
	t.Parallel()

	client := New()
	client.KeepLog = true
	client.MaxRetries = 0
	client.Concurrency = 2

	// Presumably nothing listens on this port, so the request should error.
	const target = "http://localhost:9000/foo"
	if _, err := client.Get(target); err == nil {
		t.Fatal("expected to get an error")
	}
	client.Wait()

	// Dump the collected log so a failure below can be diagnosed.
	t.Log("\n", client.LogString())

	want := client.Concurrency
	if got := client.LogErrCount(); got != want {
		t.Errorf("got %d attempts, want %d", got, want)
	}
}
// TestDefaultBackoff runs a failing GET on a client with default settings and
// verifies three things: the default concurrency is 1, the number of logged
// errors equals MaxRetries, and each logged entry after the first carries a
// Unix timestamp exactly i seconds later than the first entry, where i is the
// entry's index in the log.
func TestDefaultBackoff(t *testing.T) {
	t.Parallel()

	client := New()
	client.KeepLog = true

	// Presumably nothing listens on this port, so every attempt should error.
	if _, err := client.Get("http://localhost:9000/foo"); err == nil {
		t.Fatal("expected to get an error")
	}
	client.Wait()

	// Dump the collected log so a failure below can be diagnosed.
	t.Log("\n", client.LogString())

	if client.Concurrency != 1 {
		t.Errorf("got %d, want %d for concurrency", client.Concurrency, 1)
	}
	if errCount := client.LogErrCount(); errCount != client.MaxRetries {
		// Fatal: the per-entry timing checks below are meaningless with an
		// unexpected number of entries.
		t.Fatalf("got %d errors, want %d", errCount, client.MaxRetries)
	}

	// Timestamps are compared at whole-second resolution.
	var first int64
	for idx, entry := range client.ErrLog {
		stamp := entry.Time.Unix()
		if idx == 0 {
			first = stamp
			continue
		}
		if expected := first + int64(idx); stamp != expected {
			t.Errorf("got time %d, want %d (%d greater than start time %d)", stamp, expected, idx, first)
		}
	}
}
// TestFormatError builds a fixed ErrEntry and checks that the client's
// FormatError renders it as the expected single log line (timestamp, method,
// verb, URL, request and retry counters, error text, trailing newline).
func TestFormatError(t *testing.T) {
	t.Parallel()

	err := errors.New("Get http://localhost:9000/foo: dial tcp 127.0.0.1:9000: getsockopt: connection refused")
	expected := "1491271979 Get [GET] http://localhost:9000/foo request-0 retry-2 error: " + err.Error() + "\n"

	e := ErrEntry{
		Time:    time.Unix(1491271979, 0),
		Method:  "Get",
		URL:     "http://localhost:9000/foo",
		Verb:    http.MethodGet,
		Request: 0,
		Retry:   2,
		Attempt: 1,
		Err:     err,
	}

	c := New()
	// Plain string inequality is clearer than strings.Compare(...) != 0 and is
	// what the strings package documentation recommends.
	if formatted := c.FormatError(e); formatted != expected {
		t.Errorf("\nExpected:\n%s\nGot:\n%s", expected, formatted)
	}
}
// TestCustomLogHook installs a LogHook that records every ErrEntry it is
// handed, performs a failing GET with MaxRetries = 5 and a very short backoff,
// and checks that the hook was invoked once per retry.
func TestCustomLogHook(t *testing.T) {
	t.Parallel()

	expectedRetries := 5
	var errorLines []ErrEntry

	c := New()
	// KeepLog is intentionally left unset here; only the hook is exercised.
	c.MaxRetries = expectedRetries
	c.Backoff = func(_ int) time.Duration {
		return 10 * time.Microsecond
	}
	c.LogHook = func(e ErrEntry) {
		errorLines = append(errorLines, e)
	}

	// Presumably nothing listens on this port, so every attempt should error.
	nonExistantURL := "http://localhost:9000/foo"
	_, err := c.Get(nonExistantURL)
	if err == nil {
		t.Fatal("expected to get an error")
	}
	c.Wait()

	// Report the number of captured entries; the slice itself is not a valid
	// argument for the %d verb.
	if got := len(errorLines); got != expectedRetries {
		t.Errorf("Expected %d lines to be emitted. Got %d", expectedRetries, got)
	}
}
// TestDefaultLogHook performs a failing GET on a client that has neither
// KeepLog enabled nor a LogHook installed, using MaxRetries = 5 and a very
// short backoff.
//
// NOTE(review): the emitted counter below is never modified, so the final
// check cannot fail; as written the test only demonstrates that the request
// errors and completes. Confirm what observable output was meant to be
// asserted here.
func TestDefaultLogHook(t *testing.T) {
	t.Parallel()

	emitted := 0

	client := New()
	// KeepLog is intentionally left unset and no LogHook is assigned.
	client.Backoff = func(int) time.Duration { return 10 * time.Microsecond }
	client.MaxRetries = 5

	// Presumably nothing listens on this port, so every attempt should error.
	const target = "http://localhost:9000/foo"
	if _, err := client.Get(target); err == nil {
		t.Fatal("expected to get an error")
	}
	client.Wait()

	if emitted != 0 {
		t.Errorf("Expected 0 lines to be emitted. Got %d", emitted)
	}
}
// TestLinearJitterBackoff performs a failing GET with the LinearJitterBackoff
// strategy and walks the error log, accumulating an expected offset of 1, 3
// and 6 seconds for entries 1, 2 and 3 relative to the first entry.
//
// NOTE(review): an error is reported when withinEpsilon returns true, which
// reads as inverted, and withinEpsilon as defined in this file returns false
// for every input, so this assertion never fires. Fixing it requires choosing
// a tolerance that matches LinearJitterBackoff's jitter; confirm against that
// implementation.
func TestLinearJitterBackoff(t *testing.T) {
	t.Parallel()

	client := New()
	client.KeepLog = true
	client.Backoff = LinearJitterBackoff

	// Presumably nothing listens on this port, so every attempt should error.
	if _, err := client.Get("http://localhost:9000/foo"); err == nil {
		t.Fatal("expected to get an error")
	}
	client.Wait()

	// Dump the collected log so a failure below can be diagnosed.
	t.Log("\n", client.LogString())

	var (
		first  int64 // Unix timestamp of the first logged error
		offset int64 // cumulative expected delay in seconds
	)
	for idx, entry := range client.ErrLog {
		switch {
		case idx == 0:
			first = entry.Time.Unix()
		case idx <= 3:
			// Entries 1..3 add 1, 2 and 3 seconds respectively.
			offset += int64(idx)
		}

		got, want := entry.Time.Unix(), first+offset
		if withinEpsilon(got, want, 0.0) {
			t.Errorf("got time %d, want %d (within epsilon of start time %d)", got, want, first)
		}
	}
}
// TestExponentialBackoff performs a failing GET with MaxRetries = 4 and the
// ExponentialBackoff strategy, checks that MaxRetries errors were logged, and
// then verifies that entries 1, 2 and 3 carry Unix timestamps 2, 6 and 14
// seconds after the first entry (cumulative delays of 2, 4 and 8 seconds).
func TestExponentialBackoff(t *testing.T) {
	t.Parallel()

	client := New()
	client.KeepLog = true
	client.Backoff = ExponentialBackoff
	client.MaxRetries = 4

	// Presumably nothing listens on this port, so every attempt should error.
	if _, err := client.Get("http://localhost:9000/foo"); err == nil {
		t.Fatal("expected to get an error")
	}
	client.Wait()

	// Dump the collected log so a failure below can be diagnosed.
	t.Log("\n", client.LogString())

	if errCount := client.LogErrCount(); errCount != client.MaxRetries {
		// Fatal: the timing checks below assume exactly MaxRetries entries.
		t.Fatalf("got %d errors, want %d", errCount, client.MaxRetries)
	}

	var (
		first  int64 // Unix timestamp of the first logged error
		offset int64 // cumulative expected delay in seconds
	)
	for idx, entry := range client.ErrLog {
		switch {
		case idx == 0:
			first = entry.Time.Unix()
		case idx <= 3:
			// 2, 4, 8 seconds for entries 1, 2, 3.
			offset += int64(1) << uint(idx)
		}

		stamp, expected := entry.Time.Unix(), first+offset
		if stamp != expected {
			t.Errorf("got time %d, want %d (%d greater than start time %d)", stamp, expected, offset, first)
		}
	}
}
// TestCookiesJarPersistence makes sure that client properties such as Jar are
// carried through a request: it assigns a cookie jar to the client, performs a
// GET against the local cookie server, and then looks for the served cookie in
// the jar's formatted representation.
func TestCookiesJarPersistence(t *testing.T) {
	port, err := cookieServer()
	if err != nil {
		t.Fatal("unable to start cookie server", err)
	}

	jar, err := cookiejar.New(nil)
	if err != nil {
		t.Fatal("Cannot create cookiejar", err)
	}

	client := New()
	client.Jar = jar

	target := fmt.Sprintf("http://localhost:%d", port)
	resp, err := client.Get(target)
	if err != nil {
		t.Fatal("unable to GET", err)
	}
	client.Wait()
	resp.Body.Close()

	// The jar is inspected through its %v dump rather than through its API.
	dump := fmt.Sprintf("%v", jar)
	if !strings.Contains(dump, "mah-cookie nomnomnom") {
		t.Error("unable to find expected cookie")
	}
}
// TestEmbeddedClientTimeout checks that the Timeout of the *http.Client passed
// to NewExtendedClient is honored: the server delays its response for twice
// the client timeout, so the GET is expected to return an error.
func TestEmbeddedClientTimeout(t *testing.T) {
	// set up a server that will timeout
	clientTimeout := 1000 * time.Millisecond
	port, err := timeoutServer(2 * clientTimeout)
	if err != nil {
		t.Fatal("unable to start timeout server", err)
	}

	// Use a dedicated client. Setting Timeout on http.DefaultClient would
	// mutate process-wide state and leak the short timeout into every other
	// user of the default client in this test binary.
	hc := &http.Client{Timeout: clientTimeout}
	c := NewExtendedClient(hc)

	resp, err := c.Get(fmt.Sprintf("http://localhost:%d/", port))
	if err == nil {
		// Unexpected success: release the response before reporting.
		if resp != nil {
			resp.Body.Close()
		}
		t.Error("expected a timeout error, did not get it")
	}
}
// TestConcurrentRequestsNotRacyAndDontLeak_FailedRequest shares one client
// between several goroutines that each GET the local cookie server, then
// checks that the goroutine count has not grown once they are done.
//
// Despite the _FailedRequest suffix, the requests in this test target a live
// server and are expected to succeed; any error returned by Get fails the test.
func TestConcurrentRequestsNotRacyAndDontLeak_FailedRequest(t *testing.T) {
	goroStart := runtime.NumGoroutine()

	c := New()
	port, err := cookieServer()
	if err != nil {
		t.Fatalf("unable to start server %v", err)
	}
	goodURL := fmt.Sprintf("http://localhost:%d", port)

	conc := 5
	// Buffered for one error per worker so a send never blocks.
	errCh := make(chan error, conc)
	var wg sync.WaitGroup
	// block holds all workers back until it is closed so they start together.
	block := make(chan struct{})

	for i := 0; i < conc; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			<-block
			resp, err := c.Get(goodURL)
			if err != nil {
				errCh <- fmt.Errorf("got unexpected error getting %s, %v", goodURL, err)
				return
			}
			if resp != nil {
				resp.Body.Close()
			}
		}()
	}
	close(block)
	wg.Wait()

	// Report worker errors from the test goroutine: t.Fatal must not be called
	// from other goroutines, and draining after wg.Wait guarantees no error is
	// missed regardless of how long the requests took.
	close(errCh)
	for workerErr := range errCh {
		t.Error(workerErr)
	}

	// give background goroutines time to clean up
	time.Sleep(1000 * time.Millisecond)

	goroEnd := runtime.NumGoroutine()
	if goroStart < goroEnd {
		t.Errorf("got %d running goroutines, want %d", goroEnd, goroStart)
	}
}
// TestConcurrentRequestsNotRacyAndDontLeak_SuccessfulRequest shares one client
// between several goroutines that each GET a localhost URL presumed to be
// unreachable, then checks that the goroutine count has not grown once they
// are done.
//
// Despite the _SuccessfulRequest suffix, every request in this test is
// expected to return an error; a nil error from Get fails the test.
func TestConcurrentRequestsNotRacyAndDontLeak_SuccessfulRequest(t *testing.T) {
	goroStart := runtime.NumGoroutine()

	c := New()
	nonExistantURL := "http://localhost:9000/foo"

	conc := 5
	// Buffered for one error per worker so a send never blocks.
	errCh := make(chan error, conc)
	var wg sync.WaitGroup
	// block holds all workers back until it is closed so they start together.
	block := make(chan struct{})

	for i := 0; i < conc; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			<-block
			resp, err := c.Get(nonExistantURL)
			if err == nil {
				// Release the unexpected response before reporting.
				if resp != nil {
					resp.Body.Close()
				}
				errCh <- fmt.Errorf("should have had an error getting %s", nonExistantURL)
				return
			}
			if resp != nil {
				resp.Body.Close()
			}
		}()
	}
	close(block)
	wg.Wait()

	// Report worker errors from the test goroutine: t.Fatal must not be called
	// from other goroutines, and draining after wg.Wait guarantees no error is
	// missed even though the retries outlast any short watcher timeout.
	close(errCh)
	for workerErr := range errCh {
		t.Error(workerErr)
	}

	// give background goroutines time to clean up
	time.Sleep(1000 * time.Millisecond)

	goroEnd := runtime.NumGoroutine()
	if goroStart < goroEnd {
		t.Errorf("got %d running goroutines, want %d", goroEnd, goroStart)
	}
}
// withinEpsilon compares want against a single bound computed as
// int64(epslion * float64(got)) and reports whether want is both strictly
// greater and strictly less than that bound.
//
// NOTE(review): both comparisons use the same bound, so they can never hold
// together and the function returns false for every input. The name suggests
// a tolerance check between got and want was intended; changing it will
// affect callers that currently rely on the constant result.
func withinEpsilon(got, want int64, epslion float64) bool {
	bound := int64(epslion * float64(got))
	return want > bound && want < bound
}
// cookieServer starts an HTTP server on a free local port whose root handler
// sets the cookie "mah-cookie=nomnomnom" and writes "OK". It returns the port
// the server listens on, or -1 and an error if the server could not be set up.
func cookieServer() (int, error) {
	mux := http.NewServeMux()
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		http.SetCookie(w, &http.Cookie{Name: "mah-cookie", Value: "nomnomnom"})
		w.Write([]byte("OK"))
	})
	return serveOnFreePort("cookie-server", mux)
}

// timeoutServer starts an HTTP server on a free local port whose root handler
// waits for the given timeout before writing "OK". It returns the port the
// server listens on, or -1 and an error if the server could not be set up.
func timeoutServer(timeout time.Duration) (int, error) {
	mux := http.NewServeMux()
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		time.Sleep(timeout)
		w.Write([]byte("OK"))
	})
	return serveOnFreePort("slow-server", mux)
}

// serveOnFreePort listens on an OS-assigned TCP port, serves handler on it in
// a background goroutine, and returns the chosen port. name identifies the
// server in the fatal log message emitted if serving stops with an error.
func serveOnFreePort(name string, handler http.Handler) (int, error) {
	l, err := net.Listen("tcp", ":0")
	if err != nil {
		return -1, fmt.Errorf("unable to secure listener %v", err)
	}

	// Resolve the port before serving so that a failure here can release the
	// listener instead of leaving it open behind a running goroutine.
	var port int
	_, sport, err := net.SplitHostPort(l.Addr().String())
	if err == nil {
		port, err = strconv.Atoi(sport)
	}
	if err != nil {
		l.Close()
		return -1, fmt.Errorf("unable to determine port %v", err)
	}

	// The listener is never closed on the success path, so the server runs
	// for the remaining lifetime of the test binary.
	go func() {
		if err := http.Serve(l, handler); err != nil {
			log.Fatalf("%s error %v", name, err)
		}
	}()

	return port, nil
}