1
0
Fork 0
mirror of https://github.com/Luzifer/nginx-sso.git synced 2025-01-04 12:06:03 +00:00
nginx-sso/vendor/cloud.google.com/go/httpreplay/cmd/httpr/integration_test.go

229 lines
5.4 KiB
Go
Raw Normal View History

// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main_test
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"os"
"os/exec"
"strings"
"testing"
"time"
"cloud.google.com/go/internal/testutil"
"cloud.google.com/go/storage"
"golang.org/x/oauth2"
"google.golang.org/api/option"
)
// initial is the payload that runRecord POSTs to the proxy's /initial control
// endpoint and that runReplay expects to read back from the same endpoint.
const initial = "initial state"
// TestIntegration_HTTPR builds the proxy binary with "go build", runs it once
// in record mode and once in replay mode against the same file, and checks
// that both runs report identical bucket information.
//
// The test is skipped in short mode and fails outright when no test project
// is configured.
func TestIntegration_HTTPR(t *testing.T) {
	if testing.Short() {
		t.Skip("Integration tests skipped in short mode")
	}
	if projID := testutil.ProjID(); projID == "" {
		t.Fatal("set GCLOUD_TESTS_GOLANG_PROJECT_ID and GCLOUD_TESTS_GOLANG_KEY")
	}

	// The temp file is created only to reserve a unique name; it is closed
	// right away and its path is handed to the proxy.
	tmp, err := ioutil.TempFile("", "httpreplay")
	if err != nil {
		t.Fatal(err)
	}
	replayFilename := tmp.Name()
	err = tmp.Close()
	if err != nil {
		t.Fatal(err)
	}
	defer os.Remove(replayFilename)

	build := exec.Command("go", "build")
	if err := build.Run(); err != nil {
		t.Fatalf("running 'go build': %v", err)
	}
	defer os.Remove("./httpr")

	recorded := runRecord(t, replayFilename)
	replayed := runReplay(t, replayFilename)
	if replayed != recorded {
		t.Fatalf("got %q, want %q", replayed, recorded)
	}
}
// runRecord starts the proxy in record mode writing to filename, stores the
// initial state via the control port, and then fetches bucket information
// through the proxy using authenticated credentials.
//
// It returns the bucket information string produced by getBucketInfo. Any
// failure aborts the test via t.Fatal.
func runRecord(t *testing.T, filename string) string {
	cmd, tr, cport, err := start("-record", filename)
	if err != nil {
		t.Fatal(err)
	}
	defer stop(t, cmd)
	ctx := context.Background()
	hc := &http.Client{
		Transport: &oauth2.Transport{
			Base:   tr,
			Source: testutil.TokenSource(ctx, storage.ScopeFullControl),
		},
	}
	// The POST goes straight to the control port, not through the proxy.
	res, err := http.Post(
		fmt.Sprintf("http://localhost:%s/initial", cport),
		"text/plain",
		strings.NewReader(initial))
	if err != nil {
		t.Fatal(err)
	}
	// Close the body on every path (t.Fatal runs deferred calls) so the
	// underlying connection is not leaked.
	defer res.Body.Close()
	if res.StatusCode != 200 {
		t.Fatalf("from POST: %s", res.Status)
	}
	info, err := getBucketInfo(ctx, hc)
	if err != nil {
		t.Fatal(err)
	}
	return info
}
// runReplay starts the proxy in replay mode reading from filename, verifies
// that the control port returns the initial state stored during recording,
// and then fetches bucket information through the proxy without credentials.
//
// It returns the bucket information string produced by getBucketInfo. A
// mismatch in the initial state is reported with t.Errorf; every other failure
// aborts the test via t.Fatal.
func runReplay(t *testing.T, filename string) string {
	cmd, tr, cport, err := start("-replay", filename)
	if err != nil {
		t.Fatal(err)
	}
	defer stop(t, cmd)
	hc := &http.Client{Transport: tr}
	// The GET goes straight to the control port, not through the proxy.
	res, err := http.Get(fmt.Sprintf("http://localhost:%s/initial", cport))
	if err != nil {
		t.Fatal(err)
	}
	// Deferred so the body is also closed when the status check below
	// aborts the test.
	defer res.Body.Close()
	if res.StatusCode != 200 {
		t.Fatalf("from GET: %s", res.Status)
	}
	body, err := ioutil.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	if got, want := string(body), initial; got != want {
		t.Errorf("initial: got %q, want %q", got, want)
	}
	info, err := getBucketInfo(context.Background(), hc)
	if err != nil {
		t.Fatal(err)
	}
	return info
}
// start launches the ./httpr proxy binary and waits for it to come up.
//
// modeFlag must be either "-record" or "-replay"; filename is passed to the
// binary as that flag's argument. On success it returns the running command,
// a transport that talks to the proxy, and the control port.
//
// If the proxy was started but cannot be used (its control port never accepts
// a connection, or the transport cannot be built), the child process is killed
// and reaped before the error is returned, so no process is left behind.
func start(modeFlag, filename string) (*exec.Cmd, *http.Transport, string, error) {
	pport, err := pickPort()
	if err != nil {
		return nil, nil, "", err
	}
	cport, err := pickPort()
	if err != nil {
		return nil, nil, "", err
	}
	cmd := exec.Command("./httpr", "-port", pport, "-control-port", cport, modeFlag, filename, "-debug-headers")
	if err := cmd.Start(); err != nil {
		return nil, nil, "", err
	}
	// abort tears down the already-running child before reporting err.
	// Kill/Wait errors are ignored on purpose: this is best-effort cleanup
	// and the original failure is the one worth reporting.
	abort := func(err error) (*exec.Cmd, *http.Transport, string, error) {
		_ = cmd.Process.Kill()
		_ = cmd.Wait()
		return nil, nil, "", err
	}
	// Wait for the server to come up: probe the control port up to ten
	// times, one second apart.
	serverUp := false
	for i := 0; i < 10; i++ {
		if conn, err := net.Dial("tcp", "localhost:"+cport); err == nil {
			conn.Close()
			serverUp = true
			break
		}
		time.Sleep(time.Second)
	}
	if !serverUp {
		return abort(errors.New("server never came up"))
	}
	tr, err := proxyTransport(pport, cport)
	if err != nil {
		return abort(err)
	}
	return cmd, tr, cport, nil
}
// stop asks the process behind cmd to shut down by sending it os.Interrupt
// and then waits for it to exit.
//
// Waiting reaps the child and guarantees the process has finished before the
// caller moves on. A failure to deliver the signal aborts the test; a non-zero
// or signal-induced exit is only logged, since the process is gone either way.
func stop(t *testing.T, cmd *exec.Cmd) {
	if err := cmd.Process.Signal(os.Interrupt); err != nil {
		t.Fatal(err)
	}
	if err := cmd.Wait(); err != nil {
		t.Logf("waiting for process to exit: %v", err)
	}
}
// pickPort picks an unused TCP port.
//
// It asks the OS for a free port by listening on ":0", records the port that
// was assigned, and releases the listener again. The port is returned as a
// decimal string. Another process could claim the port between this call
// returning and the caller binding it.
func pickPort() (string, error) {
	l, err := net.Listen("tcp", ":0")
	if err != nil {
		return "", err
	}
	// Deferred so the listener is released on the error path as well.
	defer l.Close()
	_, port, err := net.SplitHostPort(l.Addr().String())
	if err != nil {
		return "", err
	}
	return port, nil
}
// proxyTransport builds an HTTP transport that routes requests through the
// proxy listening on localhost:pport.
//
// The CA certificate is downloaded from /authority.cer on the control port
// (cport) and installed as the transport's only root CA. An error is returned
// if the download fails or the response contains no usable PEM certificate.
func proxyTransport(pport, cport string) (*http.Transport, error) {
	certURL := fmt.Sprintf("http://localhost:%s/authority.cer", cport)
	pemData, err := getBody(certURL)
	if err != nil {
		return nil, err
	}

	roots := x509.NewCertPool()
	if ok := roots.AppendCertsFromPEM(pemData); !ok {
		return nil, errors.New("bad CA Cert")
	}

	proxyAddr := &url.URL{Host: "localhost:" + pport}
	tr := &http.Transport{
		Proxy:           http.ProxyURL(proxyAddr),
		TLSClientConfig: &tls.Config{RootCAs: roots},
	}
	return tr, nil
}
// getBucketInfo creates a storage client on top of hc, looks up the
// attributes of the bucket named after the test project ID, and formats the
// name, requester-pays flag, location and storage class into one string.
//
// The client is closed before returning. Errors from creating the client or
// fetching the attributes are returned unchanged.
func getBucketInfo(ctx context.Context, hc *http.Client) (string, error) {
	client, err := storage.NewClient(ctx, option.WithHTTPClient(hc))
	if err != nil {
		return "", err
	}
	defer client.Close()

	attrs, err := client.Bucket(testutil.ProjID()).Attrs(ctx)
	if err != nil {
		return "", err
	}

	summary := fmt.Sprintf(
		"name:%s reqpays:%v location:%s sclass:%s",
		attrs.Name,
		attrs.RequesterPays,
		attrs.Location,
		attrs.StorageClass,
	)
	return summary, nil
}
// getBody issues a GET request for url and returns the full response body.
//
// A transport error is returned as is. Any status other than 200 is reported
// as an error containing the response status. The response body is closed on
// every path, including the non-200 one.
func getBody(url string) ([]byte, error) {
	res, err := http.Get(url)
	if err != nil {
		return nil, err
	}
	// Registered before the status check so an error response does not
	// leak its body.
	defer res.Body.Close()
	if res.StatusCode != 200 {
		return nil, fmt.Errorf("response: %s", res.Status)
	}
	return ioutil.ReadAll(res.Body)
}