
Migrate and update Godeps

Signed-off-by: Knut Ahlers <knut@ahlers.me>
Author: Knut Ahlers, 2017-05-29 11:16:47 +02:00
Parent: 8c9fc0acc1
Commit: ffdaf5c63f
Signed by: luzifer
GPG key ID: DC2729FDD34BE99E
253 changed files with 44614 additions and 26945 deletions

Godeps/Godeps.json (generated): 148 changed lines

@@ -1,62 +1,152 @@
 {
   "ImportPath": "github.com/Luzifer/s3sync",
-  "GoVersion": "go1.4.2",
+  "GoVersion": "go1.8",
+  "GodepVersion": "v79",
+  "Packages": [
+    "./..."
+  ],
   "Deps": [
     {
       "ImportPath": "github.com/aws/aws-sdk-go/aws",
-      "Comment": "v0.7.1-4-g1c75632",
-      "Rev": "1c75632eb9b77bb3d8e6416fbd8f3e98fc1367db"
+      "Comment": "v1.8.30-1-g26d4122",
+      "Rev": "26d4122e877bd86f43d528ad7fbde237dbbd3fa9"
     },
     {
-      "ImportPath": "github.com/aws/aws-sdk-go/internal/endpoints",
-      "Comment": "v0.7.1-4-g1c75632",
-      "Rev": "1c75632eb9b77bb3d8e6416fbd8f3e98fc1367db"
+      "ImportPath": "github.com/aws/aws-sdk-go/aws/awserr",
+      "Comment": "v1.8.30-1-g26d4122",
+      "Rev": "26d4122e877bd86f43d528ad7fbde237dbbd3fa9"
     },
     {
-      "ImportPath": "github.com/aws/aws-sdk-go/internal/protocol/query",
-      "Comment": "v0.7.1-4-g1c75632",
-      "Rev": "1c75632eb9b77bb3d8e6416fbd8f3e98fc1367db"
+      "ImportPath": "github.com/aws/aws-sdk-go/aws/awsutil",
+      "Comment": "v1.8.30-1-g26d4122",
+      "Rev": "26d4122e877bd86f43d528ad7fbde237dbbd3fa9"
     },
     {
-      "ImportPath": "github.com/aws/aws-sdk-go/internal/protocol/rest",
-      "Comment": "v0.7.1-4-g1c75632",
-      "Rev": "1c75632eb9b77bb3d8e6416fbd8f3e98fc1367db"
+      "ImportPath": "github.com/aws/aws-sdk-go/aws/client",
+      "Comment": "v1.8.30-1-g26d4122",
+      "Rev": "26d4122e877bd86f43d528ad7fbde237dbbd3fa9"
     },
     {
-      "ImportPath": "github.com/aws/aws-sdk-go/internal/protocol/restxml",
-      "Comment": "v0.7.1-4-g1c75632",
-      "Rev": "1c75632eb9b77bb3d8e6416fbd8f3e98fc1367db"
+      "ImportPath": "github.com/aws/aws-sdk-go/aws/client/metadata",
+      "Comment": "v1.8.30-1-g26d4122",
+      "Rev": "26d4122e877bd86f43d528ad7fbde237dbbd3fa9"
     },
     {
-      "ImportPath": "github.com/aws/aws-sdk-go/internal/protocol/xml/xmlutil",
-      "Comment": "v0.7.1-4-g1c75632",
-      "Rev": "1c75632eb9b77bb3d8e6416fbd8f3e98fc1367db"
+      "ImportPath": "github.com/aws/aws-sdk-go/aws/corehandlers",
+      "Comment": "v1.8.30-1-g26d4122",
+      "Rev": "26d4122e877bd86f43d528ad7fbde237dbbd3fa9"
     },
     {
-      "ImportPath": "github.com/aws/aws-sdk-go/internal/signer/v4",
-      "Comment": "v0.7.1-4-g1c75632",
-      "Rev": "1c75632eb9b77bb3d8e6416fbd8f3e98fc1367db"
+      "ImportPath": "github.com/aws/aws-sdk-go/aws/credentials",
+      "Comment": "v1.8.30-1-g26d4122",
+      "Rev": "26d4122e877bd86f43d528ad7fbde237dbbd3fa9"
+    },
+    {
+      "ImportPath": "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds",
+      "Comment": "v1.8.30-1-g26d4122",
+      "Rev": "26d4122e877bd86f43d528ad7fbde237dbbd3fa9"
+    },
+    {
+      "ImportPath": "github.com/aws/aws-sdk-go/aws/credentials/endpointcreds",
+      "Comment": "v1.8.30-1-g26d4122",
+      "Rev": "26d4122e877bd86f43d528ad7fbde237dbbd3fa9"
+    },
+    {
+      "ImportPath": "github.com/aws/aws-sdk-go/aws/credentials/stscreds",
+      "Comment": "v1.8.30-1-g26d4122",
+      "Rev": "26d4122e877bd86f43d528ad7fbde237dbbd3fa9"
+    },
+    {
+      "ImportPath": "github.com/aws/aws-sdk-go/aws/defaults",
+      "Comment": "v1.8.30-1-g26d4122",
+      "Rev": "26d4122e877bd86f43d528ad7fbde237dbbd3fa9"
+    },
+    {
+      "ImportPath": "github.com/aws/aws-sdk-go/aws/ec2metadata",
+      "Comment": "v1.8.30-1-g26d4122",
+      "Rev": "26d4122e877bd86f43d528ad7fbde237dbbd3fa9"
+    },
+    {
+      "ImportPath": "github.com/aws/aws-sdk-go/aws/endpoints",
+      "Comment": "v1.8.30-1-g26d4122",
+      "Rev": "26d4122e877bd86f43d528ad7fbde237dbbd3fa9"
+    },
+    {
+      "ImportPath": "github.com/aws/aws-sdk-go/aws/request",
+      "Comment": "v1.8.30-1-g26d4122",
+      "Rev": "26d4122e877bd86f43d528ad7fbde237dbbd3fa9"
+    },
+    {
+      "ImportPath": "github.com/aws/aws-sdk-go/aws/session",
+      "Comment": "v1.8.30-1-g26d4122",
+      "Rev": "26d4122e877bd86f43d528ad7fbde237dbbd3fa9"
+    },
+    {
+      "ImportPath": "github.com/aws/aws-sdk-go/aws/signer/v4",
+      "Comment": "v1.8.30-1-g26d4122",
+      "Rev": "26d4122e877bd86f43d528ad7fbde237dbbd3fa9"
+    },
+    {
+      "ImportPath": "github.com/aws/aws-sdk-go/private/protocol",
+      "Comment": "v1.8.30-1-g26d4122",
+      "Rev": "26d4122e877bd86f43d528ad7fbde237dbbd3fa9"
+    },
+    {
+      "ImportPath": "github.com/aws/aws-sdk-go/private/protocol/query",
+      "Comment": "v1.8.30-1-g26d4122",
+      "Rev": "26d4122e877bd86f43d528ad7fbde237dbbd3fa9"
+    },
+    {
+      "ImportPath": "github.com/aws/aws-sdk-go/private/protocol/query/queryutil",
+      "Comment": "v1.8.30-1-g26d4122",
+      "Rev": "26d4122e877bd86f43d528ad7fbde237dbbd3fa9"
+    },
+    {
+      "ImportPath": "github.com/aws/aws-sdk-go/private/protocol/rest",
+      "Comment": "v1.8.30-1-g26d4122",
+      "Rev": "26d4122e877bd86f43d528ad7fbde237dbbd3fa9"
+    },
+    {
+      "ImportPath": "github.com/aws/aws-sdk-go/private/protocol/restxml",
+      "Comment": "v1.8.30-1-g26d4122",
+      "Rev": "26d4122e877bd86f43d528ad7fbde237dbbd3fa9"
+    },
+    {
+      "ImportPath": "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil",
+      "Comment": "v1.8.30-1-g26d4122",
+      "Rev": "26d4122e877bd86f43d528ad7fbde237dbbd3fa9"
     },
     {
       "ImportPath": "github.com/aws/aws-sdk-go/service/s3",
-      "Comment": "v0.7.1-4-g1c75632",
-      "Rev": "1c75632eb9b77bb3d8e6416fbd8f3e98fc1367db"
+      "Comment": "v1.8.30-1-g26d4122",
+      "Rev": "26d4122e877bd86f43d528ad7fbde237dbbd3fa9"
+    },
+    {
+      "ImportPath": "github.com/aws/aws-sdk-go/service/sts",
+      "Comment": "v1.8.30-1-g26d4122",
+      "Rev": "26d4122e877bd86f43d528ad7fbde237dbbd3fa9"
+    },
+    {
+      "ImportPath": "github.com/go-ini/ini",
+      "Comment": "v1.24.0",
+      "Rev": "e3c2d47c61e5333f9aa2974695dd94396eb69c75"
     },
     {
       "ImportPath": "github.com/inconshreveable/mousetrap",
       "Rev": "76626ae9c91c4f2a10f34cad8ce83ea42c93bb75"
     },
+    {
+      "ImportPath": "github.com/jmespath/go-jmespath",
+      "Comment": "0.2.2-14-gbd40a43",
+      "Rev": "bd40a432e4c76585ef6b72d3fd96fb9b6dc7b68d"
+    },
     {
       "ImportPath": "github.com/spf13/cobra",
-      "Rev": "385fc87e4343efec233811d3d933509e8975d11a"
+      "Rev": "9c28e4bbd74e5c3ed7aacbc552b2cab7cfdfe744"
     },
     {
       "ImportPath": "github.com/spf13/pflag",
-      "Rev": "67cbc198fd11dab704b214c1e629a97af392c085"
-    },
-    {
-      "ImportPath": "github.com/vaughan0/go-ini",
-      "Rev": "a98ad7ee00ec53921f08832bc06ecf7fd600e6a1"
+      "Rev": "c7e63cf4530bcd3ba943729cee0efeff2ebea63f"
     }
   ]
 }
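
For orientation, a minimal sketch (not code from s3sync; the region and bucket name are placeholders) of how the newly pinned v1.x SDK packages such as aws/session and service/s3 are typically wired together, in contrast to the global-config style of the removed v0.7.x vendor tree:

package main

import (
	"fmt"
	"log"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
)

func main() {
	// session.NewSession replaces the global aws.DefaultConfig wiring used by
	// the old v0.7.x SDK that this commit drops from Godeps/_workspace.
	sess, err := session.NewSession(&aws.Config{Region: aws.String("us-east-1")})
	if err != nil {
		log.Fatal(err)
	}

	svc := s3.New(sess)
	out, err := svc.ListObjects(&s3.ListObjectsInput{Bucket: aws.String("example-bucket")})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(len(out.Contents), "objects")
}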

Godeps/_workspace/.gitignore (generated, vendored): 2 changed lines

@@ -1,2 +0,0 @@
/pkg
/bin


@@ -1,201 +0,0 @@
package awsutil_test
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"testing"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/stretchr/testify/assert"
)
func ExampleCopy() {
type Foo struct {
A int
B []*string
}
// Create the initial value
str1 := "hello"
str2 := "bye bye"
f1 := &Foo{A: 1, B: []*string{&str1, &str2}}
// Do the copy
var f2 Foo
awsutil.Copy(&f2, f1)
// Print the result
fmt.Println(awsutil.Prettify(f2))
// Output:
// {
// A: 1,
// B: ["hello","bye bye"]
// }
}
func TestCopy(t *testing.T) {
type Foo struct {
A int
B []*string
C map[string]*int
}
// Create the initial value
str1 := "hello"
str2 := "bye bye"
int1 := 1
int2 := 2
f1 := &Foo{
A: 1,
B: []*string{&str1, &str2},
C: map[string]*int{
"A": &int1,
"B": &int2,
},
}
// Do the copy
var f2 Foo
awsutil.Copy(&f2, f1)
// Values are equal
assert.Equal(t, f2.A, f1.A)
assert.Equal(t, f2.B, f1.B)
assert.Equal(t, f2.C, f1.C)
// But pointers are not!
str3 := "nothello"
int3 := 57
f2.A = 100
f2.B[0] = &str3
f2.C["B"] = &int3
assert.NotEqual(t, f2.A, f1.A)
assert.NotEqual(t, f2.B, f1.B)
assert.NotEqual(t, f2.C, f1.C)
}
func TestCopyIgnoreNilMembers(t *testing.T) {
type Foo struct {
A *string
B []string
C map[string]string
}
f := &Foo{}
assert.Nil(t, f.A)
assert.Nil(t, f.B)
assert.Nil(t, f.C)
var f2 Foo
awsutil.Copy(&f2, f)
assert.Nil(t, f2.A)
assert.Nil(t, f2.B)
assert.Nil(t, f2.C)
fcopy := awsutil.CopyOf(f)
f3 := fcopy.(*Foo)
assert.Nil(t, f3.A)
assert.Nil(t, f3.B)
assert.Nil(t, f3.C)
}
func TestCopyPrimitive(t *testing.T) {
str := "hello"
var s string
awsutil.Copy(&s, &str)
assert.Equal(t, "hello", s)
}
func TestCopyNil(t *testing.T) {
var s string
awsutil.Copy(&s, nil)
assert.Equal(t, "", s)
}
func TestCopyReader(t *testing.T) {
var buf io.Reader = bytes.NewReader([]byte("hello world"))
var r io.Reader
awsutil.Copy(&r, buf)
b, err := ioutil.ReadAll(r)
assert.NoError(t, err)
assert.Equal(t, []byte("hello world"), b)
// empty bytes because this is not a deep copy
b, err = ioutil.ReadAll(buf)
assert.NoError(t, err)
assert.Equal(t, []byte(""), b)
}
func TestCopyDifferentStructs(t *testing.T) {
type SrcFoo struct {
A int
B []*string
C map[string]*int
SrcUnique string
SameNameDiffType int
}
type DstFoo struct {
A int
B []*string
C map[string]*int
DstUnique int
SameNameDiffType string
}
// Create the initial value
str1 := "hello"
str2 := "bye bye"
int1 := 1
int2 := 2
f1 := &SrcFoo{
A: 1,
B: []*string{&str1, &str2},
C: map[string]*int{
"A": &int1,
"B": &int2,
},
SrcUnique: "unique",
SameNameDiffType: 1,
}
// Do the copy
var f2 DstFoo
awsutil.Copy(&f2, f1)
// Values are equal
assert.Equal(t, f2.A, f1.A)
assert.Equal(t, f2.B, f1.B)
assert.Equal(t, f2.C, f1.C)
assert.Equal(t, "unique", f1.SrcUnique)
assert.Equal(t, 1, f1.SameNameDiffType)
assert.Equal(t, 0, f2.DstUnique)
assert.Equal(t, "", f2.SameNameDiffType)
}
func ExampleCopyOf() {
type Foo struct {
A int
B []*string
}
// Create the initial value
str1 := "hello"
str2 := "bye bye"
f1 := &Foo{A: 1, B: []*string{&str1, &str2}}
// Do the copy
v := awsutil.CopyOf(f1)
var f2 *Foo = v.(*Foo)
// Print the result
fmt.Println(awsutil.Prettify(f2))
// Output:
// {
// A: 1,
// B: ["hello","bye bye"]
// }
}


@@ -1,68 +0,0 @@
package awsutil_test
import (
"testing"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/stretchr/testify/assert"
)
type Struct struct {
A []Struct
z []Struct
B *Struct
D *Struct
C string
}
var data = Struct{
A: []Struct{{C: "value1"}, {C: "value2"}, {C: "value3"}},
z: []Struct{{C: "value1"}, {C: "value2"}, {C: "value3"}},
B: &Struct{B: &Struct{C: "terminal"}, D: &Struct{C: "terminal2"}},
C: "initial",
}
func TestValueAtPathSuccess(t *testing.T) {
assert.Equal(t, []interface{}{"initial"}, awsutil.ValuesAtPath(data, "C"))
assert.Equal(t, []interface{}{"value1"}, awsutil.ValuesAtPath(data, "A[0].C"))
assert.Equal(t, []interface{}{"value2"}, awsutil.ValuesAtPath(data, "A[1].C"))
assert.Equal(t, []interface{}{"value3"}, awsutil.ValuesAtPath(data, "A[2].C"))
assert.Equal(t, []interface{}{"value3"}, awsutil.ValuesAtAnyPath(data, "a[2].c"))
assert.Equal(t, []interface{}{"value3"}, awsutil.ValuesAtPath(data, "A[-1].C"))
assert.Equal(t, []interface{}{"value1", "value2", "value3"}, awsutil.ValuesAtPath(data, "A[].C"))
assert.Equal(t, []interface{}{"terminal"}, awsutil.ValuesAtPath(data, "B . B . C"))
assert.Equal(t, []interface{}{"terminal", "terminal2"}, awsutil.ValuesAtPath(data, "B.*.C"))
assert.Equal(t, []interface{}{"initial"}, awsutil.ValuesAtPath(data, "A.D.X || C"))
}
func TestValueAtPathFailure(t *testing.T) {
assert.Equal(t, []interface{}(nil), awsutil.ValuesAtPath(data, "C.x"))
assert.Equal(t, []interface{}(nil), awsutil.ValuesAtPath(data, ".x"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(data, "X.Y.Z"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(data, "A[100].C"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(data, "A[3].C"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(data, "B.B.C.Z"))
assert.Equal(t, []interface{}(nil), awsutil.ValuesAtPath(data, "z[-1].C"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(nil, "A.B.C"))
assert.Equal(t, []interface{}{}, awsutil.ValuesAtPath(Struct{}, "A"))
}
func TestSetValueAtPathSuccess(t *testing.T) {
var s Struct
awsutil.SetValueAtPath(&s, "C", "test1")
awsutil.SetValueAtPath(&s, "B.B.C", "test2")
awsutil.SetValueAtPath(&s, "B.D.C", "test3")
assert.Equal(t, "test1", s.C)
assert.Equal(t, "test2", s.B.B.C)
assert.Equal(t, "test3", s.B.D.C)
awsutil.SetValueAtPath(&s, "B.*.C", "test0")
assert.Equal(t, "test0", s.B.B.C)
assert.Equal(t, "test0", s.B.D.C)
var s2 Struct
awsutil.SetValueAtAnyPath(&s2, "b.b.c", "test0")
assert.Equal(t, "test0", s2.B.B.C)
awsutil.SetValueAtAnyPath(&s2, "A", []Struct{{}})
assert.Equal(t, []Struct{{}}, s2.A)
}


@@ -1,254 +0,0 @@
package aws
import (
"net/http"
"os"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
)
// DefaultChainCredentials is a Credentials which will find the first available
// credentials Value from the list of Providers.
//
// This should be used in the default case. Once the type of credentials are
// known switching to the specific Credentials will be more efficient.
var DefaultChainCredentials = credentials.NewChainCredentials(
[]credentials.Provider{
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{Filename: "", Profile: ""},
&credentials.EC2RoleProvider{ExpiryWindow: 5 * time.Minute},
})
// The default number of retries for a service. The value of -1 indicates that
// the service specific retry default will be used.
const DefaultRetries = -1
// DefaultConfig is the default all service configuration will be based off of.
// By default, all clients use this structure for initialization options unless
// a custom configuration object is passed in.
//
// You may modify this global structure to change all default configuration
// in the SDK. Note that configuration options are copied by value, so any
// modifications must happen before constructing a client.
var DefaultConfig = NewConfig().
WithCredentials(DefaultChainCredentials).
WithRegion(os.Getenv("AWS_REGION")).
WithHTTPClient(http.DefaultClient).
WithMaxRetries(DefaultRetries).
WithLogger(NewDefaultLogger()).
WithLogLevel(LogOff)
// A Config provides service configuration for service clients. By default,
// all clients will use the {DefaultConfig} structure.
type Config struct {
// The credentials object to use when signing requests. Defaults to
// {DefaultChainCredentials}.
Credentials *credentials.Credentials
// An optional endpoint URL (hostname only or fully qualified URI)
// that overrides the default generated endpoint for a client. Set this
// to `""` to use the default generated endpoint.
//
// @note You must still provide a `Region` value when specifying an
// endpoint for a client.
Endpoint *string
// The region to send requests to. This parameter is required and must
// be configured globally or on a per-client basis unless otherwise
// noted. A full list of regions is found in the "Regions and Endpoints"
// document.
//
// @see http://docs.aws.amazon.com/general/latest/gr/rande.html
// AWS Regions and Endpoints
Region *string
// Set this to `true` to disable SSL when sending requests. Defaults
// to `false`.
DisableSSL *bool
// The HTTP client to use when sending requests. Defaults to
// `http.DefaultClient`.
HTTPClient *http.Client
// An integer value representing the logging level. The default log level
// is zero (LogOff), which represents no logging. To enable logging set
// to a LogLevel Value.
LogLevel *LogLevelType
// The logger writer interface to write logging messages to. Defaults to
// standard out.
Logger Logger
// The maximum number of times that a request will be retried for failures.
// Defaults to -1, which defers the max retry setting to the service specific
// configuration.
MaxRetries *int
// Disables semantic parameter validation, which validates input for missing
// required fields and/or other semantic request input errors.
DisableParamValidation *bool
// Disables the computation of request and response checksums, e.g.,
// CRC32 checksums in Amazon DynamoDB.
DisableComputeChecksums *bool
// Set this to `true` to force the request to use path-style addressing,
// i.e., `http://s3.amazonaws.com/BUCKET/KEY`. By default, the S3 client will
// use virtual hosted bucket addressing when possible
// (`http://BUCKET.s3.amazonaws.com/KEY`).
//
// @note This configuration option is specific to the Amazon S3 service.
// @see http://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html
// Amazon S3: Virtual Hosting of Buckets
S3ForcePathStyle *bool
}
// NewConfig returns a new Config pointer that can be chained with builder methods to
// set multiple configuration values inline without using pointers.
//
// svc := s3.New(aws.NewConfig().WithRegion("us-west-2").WithMaxRetries(10))
//
func NewConfig() *Config {
return &Config{}
}
// WithCredentials sets a config Credentials value returning a Config pointer
// for chaining.
func (c *Config) WithCredentials(creds *credentials.Credentials) *Config {
c.Credentials = creds
return c
}
// WithEndpoint sets a config Endpoint value returning a Config pointer for
// chaining.
func (c *Config) WithEndpoint(endpoint string) *Config {
c.Endpoint = &endpoint
return c
}
// WithRegion sets a config Region value returning a Config pointer for
// chaining.
func (c *Config) WithRegion(region string) *Config {
c.Region = &region
return c
}
// WithDisableSSL sets a config DisableSSL value returning a Config pointer
// for chaining.
func (c *Config) WithDisableSSL(disable bool) *Config {
c.DisableSSL = &disable
return c
}
// WithHTTPClient sets a config HTTPClient value returning a Config pointer
// for chaining.
func (c *Config) WithHTTPClient(client *http.Client) *Config {
c.HTTPClient = client
return c
}
// WithMaxRetries sets a config MaxRetries value returning a Config pointer
// for chaining.
func (c *Config) WithMaxRetries(max int) *Config {
c.MaxRetries = &max
return c
}
// WithDisableParamValidation sets a config DisableParamValidation value
// returning a Config pointer for chaining.
func (c *Config) WithDisableParamValidation(disable bool) *Config {
c.DisableParamValidation = &disable
return c
}
// WithDisableComputeChecksums sets a config DisableComputeChecksums value
// returning a Config pointer for chaining.
func (c *Config) WithDisableComputeChecksums(disable bool) *Config {
c.DisableComputeChecksums = &disable
return c
}
// WithLogLevel sets a config LogLevel value returning a Config pointer for
// chaining.
func (c *Config) WithLogLevel(level LogLevelType) *Config {
c.LogLevel = &level
return c
}
// WithLogger sets a config Logger value returning a Config pointer for
// chaining.
func (c *Config) WithLogger(logger Logger) *Config {
c.Logger = logger
return c
}
// WithS3ForcePathStyle sets a config S3ForcePathStyle value returning a Config
// pointer for chaining.
func (c *Config) WithS3ForcePathStyle(force bool) *Config {
c.S3ForcePathStyle = &force
return c
}
// Merge returns a new Config with the other Config's attribute values merged into
// this Config. If the other Config's attribute is nil it will not be merged into
// the new Config to be returned.
func (c Config) Merge(other *Config) *Config {
if other == nil {
return &c
}
dst := c
if other.Credentials != nil {
dst.Credentials = other.Credentials
}
if other.Endpoint != nil {
dst.Endpoint = other.Endpoint
}
if other.Region != nil {
dst.Region = other.Region
}
if other.DisableSSL != nil {
dst.DisableSSL = other.DisableSSL
}
if other.HTTPClient != nil {
dst.HTTPClient = other.HTTPClient
}
if other.LogLevel != nil {
dst.LogLevel = other.LogLevel
}
if other.Logger != nil {
dst.Logger = other.Logger
}
if other.MaxRetries != nil {
dst.MaxRetries = other.MaxRetries
}
if other.DisableParamValidation != nil {
dst.DisableParamValidation = other.DisableParamValidation
}
if other.DisableComputeChecksums != nil {
dst.DisableComputeChecksums = other.DisableComputeChecksums
}
if other.S3ForcePathStyle != nil {
dst.S3ForcePathStyle = other.S3ForcePathStyle
}
return &dst
}
// Copy will return a shallow copy of the Config object.
func (c Config) Copy() *Config {
dst := c
return &dst
}
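
The removed config.go above documents a builder-style API; a brief sketch, using only identifiers defined in that file (it follows the old pre-1.x API shown above, e.g. Merge, so it is a sketch against this removed code rather than the current SDK):

package main

import "github.com/aws/aws-sdk-go/aws"

func main() {
	// Chained construction, as shown in the NewConfig doc comment.
	base := aws.NewConfig().WithRegion("us-west-2").WithMaxRetries(10)

	// Merge copies non-nil fields of the argument over a copy of the receiver;
	// the receiver itself is left unchanged.
	merged := base.Merge(aws.NewConfig().WithDisableSSL(true))

	// Copy returns a shallow copy of the Config.
	_ = merged.Copy()
}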


@@ -1,87 +0,0 @@
package aws
import (
"net/http"
"reflect"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
)
var testCredentials = credentials.NewChainCredentials([]credentials.Provider{
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{
Filename: "TestFilename",
Profile: "TestProfile"},
&credentials.EC2RoleProvider{ExpiryWindow: 5 * time.Minute},
})
var copyTestConfig = Config{
Credentials: testCredentials,
Endpoint: String("CopyTestEndpoint"),
Region: String("COPY_TEST_AWS_REGION"),
DisableSSL: Bool(true),
HTTPClient: http.DefaultClient,
LogLevel: LogLevel(LogDebug),
Logger: NewDefaultLogger(),
MaxRetries: Int(DefaultRetries),
DisableParamValidation: Bool(true),
DisableComputeChecksums: Bool(true),
S3ForcePathStyle: Bool(true),
}
func TestCopy(t *testing.T) {
want := copyTestConfig
got := copyTestConfig.Copy()
if !reflect.DeepEqual(*got, want) {
t.Errorf("Copy() = %+v", got)
t.Errorf(" want %+v", want)
}
}
func TestCopyReturnsNewInstance(t *testing.T) {
want := copyTestConfig
got := copyTestConfig.Copy()
if got == &want {
t.Errorf("Copy() = %p; want different instance as source %p", got, &want)
}
}
var mergeTestZeroValueConfig = Config{}
var mergeTestConfig = Config{
Credentials: testCredentials,
Endpoint: String("MergeTestEndpoint"),
Region: String("MERGE_TEST_AWS_REGION"),
DisableSSL: Bool(true),
HTTPClient: http.DefaultClient,
LogLevel: LogLevel(LogDebug),
Logger: NewDefaultLogger(),
MaxRetries: Int(10),
DisableParamValidation: Bool(true),
DisableComputeChecksums: Bool(true),
S3ForcePathStyle: Bool(true),
}
var mergeTests = []struct {
cfg *Config
in *Config
want *Config
}{
{&Config{}, nil, &Config{}},
{&Config{}, &mergeTestZeroValueConfig, &Config{}},
{&Config{}, &mergeTestConfig, &mergeTestConfig},
}
func TestMerge(t *testing.T) {
for i, tt := range mergeTests {
got := tt.cfg.Merge(tt.in)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Config %d %+v", i, tt.cfg)
t.Errorf(" Merge(%+v)", tt.in)
t.Errorf(" got %+v", got)
t.Errorf(" want %+v", tt.want)
}
}
}


@@ -1,438 +0,0 @@
package aws_test
import (
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/stretchr/testify/assert"
)
var testCasesStringSlice = [][]string{
{"a", "b", "c", "d", "e"},
{"a", "b", "", "", "e"},
}
func TestStringSlice(t *testing.T) {
for idx, in := range testCasesStringSlice {
if in == nil {
continue
}
out := aws.StringSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.StringValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesStringValueSlice = [][]*string{
{aws.String("a"), aws.String("b"), nil, aws.String("c")},
}
func TestStringValueSlice(t *testing.T) {
for idx, in := range testCasesStringValueSlice {
if in == nil {
continue
}
out := aws.StringValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
}
}
out2 := aws.StringSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
}
}
}
}
var testCasesStringMap = []map[string]string{
{"a": "1", "b": "2", "c": "3"},
}
func TestStringMap(t *testing.T) {
for idx, in := range testCasesStringMap {
if in == nil {
continue
}
out := aws.StringMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.StringValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesBoolSlice = [][]bool{
{true, true, false, false},
}
func TestBoolSlice(t *testing.T) {
for idx, in := range testCasesBoolSlice {
if in == nil {
continue
}
out := aws.BoolSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.BoolValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesBoolValueSlice = [][]*bool{}
func TestBoolValueSlice(t *testing.T) {
for idx, in := range testCasesBoolValueSlice {
if in == nil {
continue
}
out := aws.BoolValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
}
}
out2 := aws.BoolSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
}
}
}
}
var testCasesBoolMap = []map[string]bool{
{"a": true, "b": false, "c": true},
}
func TestBoolMap(t *testing.T) {
for idx, in := range testCasesBoolMap {
if in == nil {
continue
}
out := aws.BoolMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.BoolValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesIntSlice = [][]int{
{1, 2, 3, 4},
}
func TestIntSlice(t *testing.T) {
for idx, in := range testCasesIntSlice {
if in == nil {
continue
}
out := aws.IntSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.IntValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesIntValueSlice = [][]*int{}
func TestIntValueSlice(t *testing.T) {
for idx, in := range testCasesIntValueSlice {
if in == nil {
continue
}
out := aws.IntValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
}
}
out2 := aws.IntSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
}
}
}
}
var testCasesIntMap = []map[string]int{
{"a": 3, "b": 2, "c": 1},
}
func TestIntMap(t *testing.T) {
for idx, in := range testCasesIntMap {
if in == nil {
continue
}
out := aws.IntMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.IntValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesInt64Slice = [][]int64{
{1, 2, 3, 4},
}
func TestInt64Slice(t *testing.T) {
for idx, in := range testCasesInt64Slice {
if in == nil {
continue
}
out := aws.Int64Slice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.Int64ValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesInt64ValueSlice = [][]*int64{}
func TestInt64ValueSlice(t *testing.T) {
for idx, in := range testCasesInt64ValueSlice {
if in == nil {
continue
}
out := aws.Int64ValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
}
}
out2 := aws.Int64Slice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
}
}
}
}
var testCasesInt64Map = []map[string]int64{
{"a": 3, "b": 2, "c": 1},
}
func TestInt64Map(t *testing.T) {
for idx, in := range testCasesInt64Map {
if in == nil {
continue
}
out := aws.Int64Map(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.Int64ValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesFloat64Slice = [][]float64{
{1, 2, 3, 4},
}
func TestFloat64Slice(t *testing.T) {
for idx, in := range testCasesFloat64Slice {
if in == nil {
continue
}
out := aws.Float64Slice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.Float64ValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesFloat64ValueSlice = [][]*float64{}
func TestFloat64ValueSlice(t *testing.T) {
for idx, in := range testCasesFloat64ValueSlice {
if in == nil {
continue
}
out := aws.Float64ValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
}
}
out2 := aws.Float64Slice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
}
}
}
}
var testCasesFloat64Map = []map[string]float64{
{"a": 3, "b": 2, "c": 1},
}
func TestFloat64Map(t *testing.T) {
for idx, in := range testCasesFloat64Map {
if in == nil {
continue
}
out := aws.Float64Map(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.Float64ValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesTimeSlice = [][]time.Time{
{time.Now(), time.Now().AddDate(100, 0, 0)},
}
func TestTimeSlice(t *testing.T) {
for idx, in := range testCasesTimeSlice {
if in == nil {
continue
}
out := aws.TimeSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.TimeValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
var testCasesTimeValueSlice = [][]*time.Time{}
func TestTimeValueSlice(t *testing.T) {
for idx, in := range testCasesTimeValueSlice {
if in == nil {
continue
}
out := aws.TimeValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
}
}
out2 := aws.TimeSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
}
}
}
}
var testCasesTimeMap = []map[string]time.Time{
{"a": time.Now().AddDate(-100, 0, 0), "b": time.Now()},
}
func TestTimeMap(t *testing.T) {
for idx, in := range testCasesTimeMap {
if in == nil {
continue
}
out := aws.TimeMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
}
out2 := aws.TimeValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
}
}
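
The conversion helpers exercised by these tests exist because the SDK's request and response structs use pointer fields throughout; a brief illustrative sketch (the bucket name and values are placeholders, not s3sync code):

package main

import (
	"fmt"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/s3"
)

func main() {
	// aws.String and aws.Int64 turn literals into the pointer fields the SDK expects.
	params := &s3.ListObjectsInput{
		Bucket:  aws.String("example-bucket"),
		MaxKeys: aws.Int64(100),
	}

	// The *Value helpers convert back, treating nil pointers as zero values.
	fmt.Println(aws.StringValue(params.Bucket), aws.Int64Value(params.MaxKeys))
}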


@@ -1,73 +0,0 @@
package credentials
import (
"testing"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/stretchr/testify/assert"
)
func TestChainProviderGet(t *testing.T) {
p := &ChainProvider{
Providers: []Provider{
&stubProvider{err: awserr.New("FirstError", "first provider error", nil)},
&stubProvider{err: awserr.New("SecondError", "second provider error", nil)},
&stubProvider{
creds: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
},
},
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "AKID", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "SECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect session token to be empty")
}
func TestChainProviderIsExpired(t *testing.T) {
stubProvider := &stubProvider{expired: true}
p := &ChainProvider{
Providers: []Provider{
stubProvider,
},
}
assert.True(t, p.IsExpired(), "Expect expired to be true before any Retrieve")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect not expired after retrieve")
stubProvider.expired = true
assert.True(t, p.IsExpired(), "Expect return of expired provider")
_, err = p.Retrieve()
assert.False(t, p.IsExpired(), "Expect not expired after retrieve")
}
func TestChainProviderWithNoProvider(t *testing.T) {
p := &ChainProvider{
Providers: []Provider{},
}
assert.True(t, p.IsExpired(), "Expect expired with no providers")
_, err := p.Retrieve()
assert.Equal(t, ErrNoValidProvidersFoundInChain, err, "Expect no providers error returned")
}
func TestChainProviderWithNoValidProvider(t *testing.T) {
p := &ChainProvider{
Providers: []Provider{
&stubProvider{err: awserr.New("FirstError", "first provider error", nil)},
&stubProvider{err: awserr.New("SecondError", "second provider error", nil)},
},
}
assert.True(t, p.IsExpired(), "Expect expired with no providers")
_, err := p.Retrieve()
assert.Equal(t, ErrNoValidProvidersFoundInChain, err, "Expect no providers error returned")
}


@@ -1,62 +0,0 @@
package credentials
import (
"testing"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/stretchr/testify/assert"
)
type stubProvider struct {
creds Value
expired bool
err error
}
func (s *stubProvider) Retrieve() (Value, error) {
s.expired = false
return s.creds, s.err
}
func (s *stubProvider) IsExpired() bool {
return s.expired
}
func TestCredentialsGet(t *testing.T) {
c := NewCredentials(&stubProvider{
creds: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
expired: true,
})
creds, err := c.Get()
assert.Nil(t, err, "Expected no error")
assert.Equal(t, "AKID", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "SECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect session token to be empty")
}
func TestCredentialsGetWithError(t *testing.T) {
c := NewCredentials(&stubProvider{err: awserr.New("provider error", "", nil), expired: true})
_, err := c.Get()
assert.Equal(t, "provider error", err.(awserr.Error).Code(), "Expected provider error")
}
func TestCredentialsExpire(t *testing.T) {
stub := &stubProvider{}
c := NewCredentials(stub)
stub.expired = false
assert.True(t, c.IsExpired(), "Expected to start out expired")
c.Expire()
assert.True(t, c.IsExpired(), "Expected to be expired")
c.forceRefresh = false
assert.False(t, c.IsExpired(), "Expected not to be expired")
stub.expired = true
assert.True(t, c.IsExpired(), "Expected to be expired")
}


@@ -1,162 +0,0 @@
package credentials
import (
"bufio"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
)
const metadataCredentialsEndpoint = "http://169.254.169.254/latest/meta-data/iam/security-credentials/"
// An EC2RoleProvider retrieves credentials from the EC2 service, and keeps track
// of whether those credentials are expired.
//
// Example how to configure the EC2RoleProvider with custom http Client, Endpoint
// or ExpiryWindow
//
// p := &credentials.EC2RoleProvider{
// // Pass in a custom timeout to be used when requesting
// // IAM EC2 Role credentials.
// Client: &http.Client{
// Timeout: 10 * time.Second,
// },
// // Use default EC2 Role metadata endpoint, Alternate endpoints can be
// // specified setting Endpoint to something else.
// Endpoint: "",
// // Do not use early expiry of credentials. If a non zero value is
// // specified the credentials will be expired early
// ExpiryWindow: 0,
// }
type EC2RoleProvider struct {
Expiry
// Endpoint must be a fully qualified URL
Endpoint string
// HTTP client to use when connecting to EC2 service
Client *http.Client
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
// with expiring credentials do not cause request to fail unexpectedly
// due to ExpiredTokenException exceptions.
//
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
// 10 seconds before the credentials are actually expired.
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration
}
// NewEC2RoleCredentials returns a pointer to a new Credentials object
// wrapping the EC2RoleProvider.
//
// Takes a custom http.Client which can be configured for custom handling of
// things such as timeout.
//
// Endpoint is the URL that the EC2RoleProvider will connect to when retrieving
// role and credentials.
//
// Window is the expiry window that will be subtracted from the expiry returned
// by the role credential request. This is done so that the credentials will
// expire sooner than their actual lifespan.
func NewEC2RoleCredentials(client *http.Client, endpoint string, window time.Duration) *Credentials {
return NewCredentials(&EC2RoleProvider{
Endpoint: endpoint,
Client: client,
ExpiryWindow: window,
})
}
// Retrieve retrieves credentials from the EC2 service.
// Error will be returned if the request fails, or unable to extract
// the desired credentials.
func (m *EC2RoleProvider) Retrieve() (Value, error) {
if m.Client == nil {
m.Client = http.DefaultClient
}
if m.Endpoint == "" {
m.Endpoint = metadataCredentialsEndpoint
}
credsList, err := requestCredList(m.Client, m.Endpoint)
if err != nil {
return Value{}, err
}
if len(credsList) == 0 {
return Value{}, awserr.New("EmptyEC2RoleList", "empty EC2 Role list", nil)
}
credsName := credsList[0]
roleCreds, err := requestCred(m.Client, m.Endpoint, credsName)
if err != nil {
return Value{}, err
}
m.SetExpiration(roleCreds.Expiration, m.ExpiryWindow)
return Value{
AccessKeyID: roleCreds.AccessKeyID,
SecretAccessKey: roleCreds.SecretAccessKey,
SessionToken: roleCreds.Token,
}, nil
}
// A ec2RoleCredRespBody provides the shape for deserializing credential
// request responses.
type ec2RoleCredRespBody struct {
Expiration time.Time
AccessKeyID string
SecretAccessKey string
Token string
}
// requestCredList requests a list of credentials from the EC2 service.
// If there are no credentials, or there is an error making or receiving the request,
// an error is returned.
func requestCredList(client *http.Client, endpoint string) ([]string, error) {
resp, err := client.Get(endpoint)
if err != nil {
return nil, awserr.New("ListEC2Role", "failed to list EC2 Roles", err)
}
defer resp.Body.Close()
credsList := []string{}
s := bufio.NewScanner(resp.Body)
for s.Scan() {
credsList = append(credsList, s.Text())
}
if err := s.Err(); err != nil {
return nil, awserr.New("ReadEC2Role", "failed to read list of EC2 Roles", err)
}
return credsList, nil
}
// requestCred requests the credentials for a specific credentials from the EC2 service.
//
// If the credentials cannot be found, or there is an error reading the response
// an error will be returned.
func requestCred(client *http.Client, endpoint, credsName string) (*ec2RoleCredRespBody, error) {
resp, err := client.Get(endpoint + credsName)
if err != nil {
return nil, awserr.New("GetEC2RoleCredentials",
fmt.Sprintf("failed to get %s EC2 Role credentials", credsName),
err)
}
defer resp.Body.Close()
respCreds := &ec2RoleCredRespBody{}
if err := json.NewDecoder(resp.Body).Decode(respCreds); err != nil {
return nil, awserr.New("DecodeEC2RoleCredentials",
fmt.Sprintf("failed to decode %s EC2 Role credentials", credsName),
err)
}
return respCreds, nil
}
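
A short sketch of how the (now removed) EC2 role provider above was wired up, using only identifiers defined in this credentials package; on a machine that is not an EC2 instance the retrieval is expected to fail:

package main

import (
	"fmt"
	"net/http"
	"time"

	"github.com/aws/aws-sdk-go/aws/credentials"
)

func main() {
	creds := credentials.NewEC2RoleCredentials(
		&http.Client{Timeout: 10 * time.Second}, // custom client, as in the doc comment above
		"",            // empty Endpoint falls back to the default metadata endpoint
		5*time.Minute, // ExpiryWindow: treat credentials as expired 5 minutes early
	)

	// Get calls Retrieve on first use and again once the expiry window is reached.
	if _, err := creds.Get(); err != nil {
		fmt.Println("no EC2 role credentials available:", err)
	}
}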


@@ -1,108 +0,0 @@
package credentials
import (
"fmt"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func initTestServer(expireOn string) *httptest.Server {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.RequestURI == "/" {
fmt.Fprintln(w, "/creds")
} else {
fmt.Fprintf(w, `{
"AccessKeyId" : "accessKey",
"SecretAccessKey" : "secret",
"Token" : "token",
"Expiration" : "%s"
}`, expireOn)
}
}))
return server
}
func TestEC2RoleProvider(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z")
defer server.Close()
p := &EC2RoleProvider{Client: http.DefaultClient, Endpoint: server.URL}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestEC2RoleProviderIsExpired(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z")
defer server.Close()
p := &EC2RoleProvider{Client: http.DefaultClient, Endpoint: server.URL}
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 15, 21, 26, 0, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.")
p.CurrentTime = func() time.Time {
return time.Date(3014, 12, 15, 21, 26, 0, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired.")
}
func TestEC2RoleProviderExpiryWindowIsExpired(t *testing.T) {
server := initTestServer("2014-12-16T01:51:37Z")
defer server.Close()
p := &EC2RoleProvider{Client: http.DefaultClient, Endpoint: server.URL, ExpiryWindow: time.Hour * 1}
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 15, 0, 51, 37, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.")
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 16, 0, 55, 37, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired.")
}
func BenchmarkEC2RoleProvider(b *testing.B) {
server := initTestServer("2014-12-16T01:51:37Z")
defer server.Close()
p := &EC2RoleProvider{Client: http.DefaultClient, Endpoint: server.URL}
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
}
})
}


@@ -1,70 +0,0 @@
package credentials
import (
"github.com/stretchr/testify/assert"
"os"
"testing"
)
func TestEnvProviderRetrieve(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_ACCESS_KEY_ID", "access")
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
os.Setenv("AWS_SESSION_TOKEN", "token")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "access", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestEnvProviderIsExpired(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_ACCESS_KEY_ID", "access")
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
os.Setenv("AWS_SESSION_TOKEN", "token")
e := EnvProvider{}
assert.True(t, e.IsExpired(), "Expect creds to be expired before retrieve.")
_, err := e.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, e.IsExpired(), "Expect creds to not be expired after retrieve.")
}
func TestEnvProviderNoAccessKeyID(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Equal(t, ErrAccessKeyIDNotFound, err, "ErrAccessKeyIDNotFound expected, but was %#v error: %#v", creds, err)
}
func TestEnvProviderNoSecretAccessKey(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_ACCESS_KEY_ID", "access")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Equal(t, ErrSecretAccessKeyNotFound, err, "ErrSecretAccessKeyNotFound expected, but was %#v error: %#v", creds, err)
}
func TestEnvProviderAlternateNames(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_ACCESS_KEY", "access")
os.Setenv("AWS_SECRET_KEY", "secret")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "access", creds.AccessKeyID, "Expected access key ID")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expected secret access key")
assert.Empty(t, creds.SessionToken, "Expected no token")
}


@@ -1,77 +0,0 @@
package credentials
import (
"github.com/stretchr/testify/assert"
"os"
"testing"
)
func TestSharedCredentialsProvider(t *testing.T) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestSharedCredentialsProviderIsExpired(t *testing.T) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve")
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve")
}
func TestSharedCredentialsProviderWithAWS_PROFILE(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_PROFILE", "no_token")
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no token")
}
func TestSharedCredentialsProviderWithoutTokenFromProfile(t *testing.T) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: "no_token"}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no token")
}
func BenchmarkSharedCredentialsProvider(b *testing.B) {
os.Clearenv()
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
}
})
}


@@ -1,34 +0,0 @@
package credentials
import (
"github.com/stretchr/testify/assert"
"testing"
)
func TestStaticProviderGet(t *testing.T) {
s := StaticProvider{
Value: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
}
creds, err := s.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "AKID", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "SECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no session token")
}
func TestStaticProviderIsExpired(t *testing.T) {
s := StaticProvider{
Value: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
}
assert.False(t, s.IsExpired(), "Expect static credentials to never expire")
}


@@ -1,120 +0,0 @@
// Package stscreds provides credential Providers for retrieving STS AWS credentials.
//
// STS provides multiple ways to retrieve credentials which can be used when making
// future AWS service API operation calls.
package stscreds
import (
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
"time"
)
// AssumeRoler represents the minimal subset of the STS client API used by this provider.
type AssumeRoler interface {
AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error)
}
// AssumeRoleProvider retrieves temporary credentials from the STS service, and
// keeps track of their expiration time. This provider must be used explicitly,
// as it is not included in the credentials chain.
//
// Example how to configure a service to use this provider:
//
// config := &aws.Config{
// Credentials: stscreds.NewCredentials(nil, "arn-of-the-role-to-assume", 10*time.Second),
// })
// // Use config for creating your AWS service.
//
// Example how to obtain customised credentials:
//
// provider := &stscreds.Provider{
// // Extend the duration to 1 hour.
// Duration: time.Hour,
// // Custom role name.
// RoleSessionName: "custom-session-name",
// }
// creds := credentials.NewCredentials(provider)
//
type AssumeRoleProvider struct {
credentials.Expiry
// Custom STS client. If not set the default STS client will be used.
Client AssumeRoler
// Role to be assumed.
RoleARN string
// Session name, if you wish to reuse the credentials elsewhere.
RoleSessionName string
// Expiry duration of the STS credentials. Defaults to 15 minutes if not set.
Duration time.Duration
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
// with expiring credentials do not cause request to fail unexpectedly
// due to ExpiredTokenException exceptions.
//
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
// 10 seconds before the credentials are actually expired.
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration
}
// NewCredentials returns a pointer to a new Credentials object wrapping the
// AssumeRoleProvider. The credentials will expire every 15 minutes and the
// role will be named after a nanosecond timestamp of this operation.
//
// The sts and roleARN parameters are used for building the "AssumeRole" call.
// Pass nil as sts to use the default client.
//
// Window is the expiry window that will be subtracted from the expiry returned
// by the role credential request. This is done so that the credentials will
// expire sooner than their actual lifespan.
func NewCredentials(client AssumeRoler, roleARN string, window time.Duration) *credentials.Credentials {
return credentials.NewCredentials(&AssumeRoleProvider{
Client: client,
RoleARN: roleARN,
ExpiryWindow: window,
})
}
// Retrieve generates a new set of temporary credentials using STS.
func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
// Apply defaults where parameters are not set.
if p.Client == nil {
p.Client = sts.New(nil)
}
if p.RoleSessionName == "" {
// Try to work out a role name that will hopefully end up unique.
p.RoleSessionName = fmt.Sprintf("%d", time.Now().UTC().UnixNano())
}
if p.Duration == 0 {
// Expire as often as AWS permits.
p.Duration = 15 * time.Minute
}
roleOutput, err := p.Client.AssumeRole(&sts.AssumeRoleInput{
DurationSeconds: aws.Int64(int64(p.Duration / time.Second)),
RoleARN: aws.String(p.RoleARN),
RoleSessionName: aws.String(p.RoleSessionName),
})
if err != nil {
return credentials.Value{}, err
}
// We will proactively generate new credentials before they expire.
p.SetExpiration(*roleOutput.Credentials.Expiration, p.ExpiryWindow)
return credentials.Value{
AccessKeyID: *roleOutput.Credentials.AccessKeyID,
SecretAccessKey: *roleOutput.Credentials.SecretAccessKey,
SessionToken: *roleOutput.Credentials.SessionToken,
}, nil
}
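
The package doc comment above already shows the intended use; a compact sketch of the same pattern against this old API (the role ARN is the placeholder from that comment):

package main

import (
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
	"github.com/aws/aws-sdk-go/service/s3"
)

func main() {
	// Passing nil uses the default STS client; 10s is the expiry window
	// subtracted from the lifetime of the assumed-role credentials.
	creds := stscreds.NewCredentials(nil, "arn-of-the-role-to-assume", 10*time.Second)

	// Any service client can then be built on top of these credentials.
	svc := s3.New(&aws.Config{Credentials: creds})
	_ = svc
}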


@@ -1,58 +0,0 @@
package stscreds
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
type stubSTS struct {
}
func (s *stubSTS) AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error) {
expiry := time.Now().Add(60 * time.Minute)
return &sts.AssumeRoleOutput{
Credentials: &sts.Credentials{
// Just reflect the role arn to the provider.
AccessKeyID: input.RoleARN,
SecretAccessKey: aws.String("assumedSecretAccessKey"),
SessionToken: aws.String("assumedSessionToken"),
Expiration: &expiry,
},
}, nil
}
func TestAssumeRoleProvider(t *testing.T) {
stub := &stubSTS{}
p := &AssumeRoleProvider{
Client: stub,
RoleARN: "roleARN",
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "roleARN", creds.AccessKeyID, "Expect access key ID to be reflected role ARN")
assert.Equal(t, "assumedSecretAccessKey", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "assumedSessionToken", creds.SessionToken, "Expect session token to match")
}
func BenchmarkAssumeRoleProvider(b *testing.B) {
stub := &stubSTS{}
p := &AssumeRoleProvider{
Client: stub,
RoleARN: "roleARN",
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
}
})
}


@@ -1,157 +0,0 @@
package aws
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"regexp"
"strconv"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
)
var sleepDelay = func(delay time.Duration) {
time.Sleep(delay)
}
// Interface for matching types which also have a Len method.
type lener interface {
Len() int
}
// BuildContentLength builds the content length of a request based on the body,
// or will use the HTTPRequest.Header's "Content-Length" if defined. If unable
// to determine request body length and no "Content-Length" was specified it will panic.
func BuildContentLength(r *Request) {
if slength := r.HTTPRequest.Header.Get("Content-Length"); slength != "" {
length, _ := strconv.ParseInt(slength, 10, 64)
r.HTTPRequest.ContentLength = length
return
}
var length int64
switch body := r.Body.(type) {
case nil:
length = 0
case lener:
length = int64(body.Len())
case io.Seeker:
r.bodyStart, _ = body.Seek(0, 1)
end, _ := body.Seek(0, 2)
body.Seek(r.bodyStart, 0) // make sure to seek back to original location
length = end - r.bodyStart
default:
panic("Cannot get length of body, must provide `ContentLength`")
}
r.HTTPRequest.ContentLength = length
r.HTTPRequest.Header.Set("Content-Length", fmt.Sprintf("%d", length))
}
// UserAgentHandler is a request handler for injecting User agent into requests.
func UserAgentHandler(r *Request) {
r.HTTPRequest.Header.Set("User-Agent", SDKName+"/"+SDKVersion)
}
var reStatusCode = regexp.MustCompile(`^(\d+)`)
// SendHandler is a request handler to send service request using HTTP client.
func SendHandler(r *Request) {
var err error
r.HTTPResponse, err = r.Service.Config.HTTPClient.Do(r.HTTPRequest)
if err != nil {
// Capture the case where url.Error is returned for error processing
// response. e.g. 301 without location header comes back as string
// error and r.HTTPResponse is nil. Other url redirect errors will
// come back in a similar way.
if e, ok := err.(*url.Error); ok {
if s := reStatusCode.FindStringSubmatch(e.Error()); s != nil {
code, _ := strconv.ParseInt(s[1], 10, 64)
r.HTTPResponse = &http.Response{
StatusCode: int(code),
Status: http.StatusText(int(code)),
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
return
}
}
if r.HTTPRequest == nil {
// Add a dummy request response object to ensure the HTTPResponse
// value is consistent.
r.HTTPResponse = &http.Response{
StatusCode: int(0),
Status: http.StatusText(int(0)),
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
}
// Catch all other request errors.
r.Error = awserr.New("RequestError", "send request failed", err)
r.Retryable = Bool(true) // network errors are retryable
}
}
// ValidateResponseHandler is a request handler to validate service response.
func ValidateResponseHandler(r *Request) {
if r.HTTPResponse.StatusCode == 0 || r.HTTPResponse.StatusCode >= 300 {
// this may be replaced by an UnmarshalError handler
r.Error = awserr.New("UnknownError", "unknown error", nil)
}
}
// AfterRetryHandler performs final checks to determine if the request should
// be retried and how long to delay.
func AfterRetryHandler(r *Request) {
// If one of the other handlers already set the retry state
// we don't want to override it based on the service's state
if r.Retryable == nil {
r.Retryable = Bool(r.Service.ShouldRetry(r))
}
if r.WillRetry() {
r.RetryDelay = r.Service.RetryRules(r)
sleepDelay(r.RetryDelay)
// when the expired token exception occurs the credentials
// need to be expired locally so that the next request to
// get credentials will trigger a credentials refresh.
if r.Error != nil {
if err, ok := r.Error.(awserr.Error); ok {
if isCodeExpiredCreds(err.Code()) {
r.Config.Credentials.Expire()
}
}
}
r.RetryCount++
r.Error = nil
}
}
var (
// ErrMissingRegion is an error that is returned if region configuration is
// not found.
//
// @readonly
ErrMissingRegion error = awserr.New("MissingRegion", "could not find region configuration", nil)
// ErrMissingEndpoint is an error that is returned if an endpoint cannot be
// resolved for a service.
//
// @readonly
ErrMissingEndpoint error = awserr.New("MissingEndpoint", "'Endpoint' configuration is required for this service", nil)
)
// ValidateEndpointHandler is a request handler to validate a request had the
// appropriate Region and Endpoint set. Will set r.Error if the endpoint or
// region is not valid.
func ValidateEndpointHandler(r *Request) {
if r.Service.SigningRegion == "" && StringValue(r.Service.Config.Region) == "" {
r.Error = ErrMissingRegion
} else if r.Service.Endpoint == "" {
r.Error = ErrMissingEndpoint
}
}
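
The BuildContentLength handler above learns the length of an io.Seeker body by seeking to the end and back. A minimal standalone sketch of that technique using only the standard library (the function name is illustrative, not SDK API):

package main

import (
	"bytes"
	"fmt"
	"io"
)

// seekerLength mirrors the io.Seeker branch of BuildContentLength:
// remember the current offset, seek to the end to learn the size, then
// seek back so the body can still be read from where it started.
func seekerLength(body io.Seeker) (int64, error) {
	start, err := body.Seek(0, io.SeekCurrent)
	if err != nil {
		return 0, err
	}
	end, err := body.Seek(0, io.SeekEnd)
	if err != nil {
		return 0, err
	}
	if _, err := body.Seek(start, io.SeekStart); err != nil {
		return 0, err
	}
	return end - start, nil
}

func main() {
	body := bytes.NewReader([]byte("hello, content length"))
	n, err := seekerLength(body)
	if err != nil {
		panic(err)
	}
	fmt.Println(n) // 21
}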

View file

@ -1,81 +0,0 @@
package aws
import (
"net/http"
"os"
"testing"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/stretchr/testify/assert"
)
func TestValidateEndpointHandler(t *testing.T) {
os.Clearenv()
svc := NewService(NewConfig().WithRegion("us-west-2"))
svc.Handlers.Clear()
svc.Handlers.Validate.PushBack(ValidateEndpointHandler)
req := NewRequest(svc, &Operation{Name: "Operation"}, nil, nil)
err := req.Build()
assert.NoError(t, err)
}
func TestValidateEndpointHandlerErrorRegion(t *testing.T) {
os.Clearenv()
svc := NewService(nil)
svc.Handlers.Clear()
svc.Handlers.Validate.PushBack(ValidateEndpointHandler)
req := NewRequest(svc, &Operation{Name: "Operation"}, nil, nil)
err := req.Build()
assert.Error(t, err)
assert.Equal(t, ErrMissingRegion, err)
}
type mockCredsProvider struct {
expired bool
retrieveCalled bool
}
func (m *mockCredsProvider) Retrieve() (credentials.Value, error) {
m.retrieveCalled = true
return credentials.Value{}, nil
}
func (m *mockCredsProvider) IsExpired() bool {
return m.expired
}
func TestAfterRetryRefreshCreds(t *testing.T) {
os.Clearenv()
credProvider := &mockCredsProvider{}
svc := NewService(&Config{Credentials: credentials.NewCredentials(credProvider), MaxRetries: Int(1)})
svc.Handlers.Clear()
svc.Handlers.ValidateResponse.PushBack(func(r *Request) {
r.Error = awserr.New("UnknownError", "", nil)
r.HTTPResponse = &http.Response{StatusCode: 400}
})
svc.Handlers.UnmarshalError.PushBack(func(r *Request) {
r.Error = awserr.New("ExpiredTokenException", "", nil)
})
svc.Handlers.AfterRetry.PushBack(func(r *Request) {
AfterRetryHandler(r)
})
assert.True(t, svc.Config.Credentials.IsExpired(), "Expect to start out expired")
assert.False(t, credProvider.retrieveCalled)
req := NewRequest(svc, &Operation{Name: "Operation"}, nil, nil)
req.Send()
assert.True(t, svc.Config.Credentials.IsExpired())
assert.False(t, credProvider.retrieveCalled)
_, err := svc.Config.Credentials.Get()
assert.NoError(t, err)
assert.True(t, credProvider.retrieveCalled)
}

View file

@ -1,85 +0,0 @@
package aws
// A Handlers provides a collection of request handlers for various
// stages of handling requests.
type Handlers struct {
Validate HandlerList
Build HandlerList
Sign HandlerList
Send HandlerList
ValidateResponse HandlerList
Unmarshal HandlerList
UnmarshalMeta HandlerList
UnmarshalError HandlerList
Retry HandlerList
AfterRetry HandlerList
}
// copy returns a copy of this handler's lists.
func (h *Handlers) copy() Handlers {
return Handlers{
Validate: h.Validate.copy(),
Build: h.Build.copy(),
Sign: h.Sign.copy(),
Send: h.Send.copy(),
ValidateResponse: h.ValidateResponse.copy(),
Unmarshal: h.Unmarshal.copy(),
UnmarshalError: h.UnmarshalError.copy(),
UnmarshalMeta: h.UnmarshalMeta.copy(),
Retry: h.Retry.copy(),
AfterRetry: h.AfterRetry.copy(),
}
}
// Clear removes callback functions for all handlers
func (h *Handlers) Clear() {
h.Validate.Clear()
h.Build.Clear()
h.Send.Clear()
h.Sign.Clear()
h.Unmarshal.Clear()
h.UnmarshalMeta.Clear()
h.UnmarshalError.Clear()
h.ValidateResponse.Clear()
h.Retry.Clear()
h.AfterRetry.Clear()
}
// A HandlerList manages zero or more handlers in a list.
type HandlerList struct {
list []func(*Request)
}
// copy creates a copy of the handler list.
func (l *HandlerList) copy() HandlerList {
var n HandlerList
n.list = append([]func(*Request){}, l.list...)
return n
}
// Clear clears the handler list.
func (l *HandlerList) Clear() {
l.list = []func(*Request){}
}
// Len returns the number of handlers in the list.
func (l *HandlerList) Len() int {
return len(l.list)
}
// PushBack pushes handlers f to the back of the handler list.
func (l *HandlerList) PushBack(f ...func(*Request)) {
l.list = append(l.list, f...)
}
// PushFront pushes handlers f to the front of the handler list.
func (l *HandlerList) PushFront(f ...func(*Request)) {
l.list = append(f, l.list...)
}
// Run executes all handlers in the list with a given request object.
func (l *HandlerList) Run(r *Request) {
for _, f := range l.list {
f(r)
}
}
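
HandlerList above is simply an ordered slice of func(*Request) callbacks. A self-contained sketch of the same pattern over a toy request type, showing how PushFront and PushBack affect execution order (the types here are stand-ins for illustration, not the SDK types):

package main

import "fmt"

type request struct{ trace []string }

// handlerList mirrors aws.HandlerList: an ordered list of callbacks that
// all receive the same request value.
type handlerList struct {
	list []func(*request)
}

func (l *handlerList) PushBack(f func(*request))  { l.list = append(l.list, f) }
func (l *handlerList) PushFront(f func(*request)) { l.list = append([]func(*request){f}, l.list...) }

func (l *handlerList) Run(r *request) {
	for _, f := range l.list {
		f(r)
	}
}

func main() {
	var l handlerList
	l.PushBack(func(r *request) { r.trace = append(r.trace, "build") })
	l.PushBack(func(r *request) { r.trace = append(r.trace, "send") })
	l.PushFront(func(r *request) { r.trace = append(r.trace, "validate") })

	r := &request{}
	l.Run(r)
	fmt.Println(r.trace) // [validate build send]
}

The SDK clones these lists per request (see Handlers.copy above and NewRequest later in this diff), so handlers added to a single request never leak back into the shared service defaults.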

View file

@ -1,31 +0,0 @@
package aws
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestHandlerList(t *testing.T) {
s := ""
r := &Request{}
l := HandlerList{}
l.PushBack(func(r *Request) {
s += "a"
r.Data = s
})
l.Run(r)
assert.Equal(t, "a", s)
assert.Equal(t, "a", r.Data)
}
func TestMultipleHandlers(t *testing.T) {
r := &Request{}
l := HandlerList{}
l.PushBack(func(r *Request) { r.Data = nil })
l.PushFront(func(r *Request) { r.Data = Bool(true) })
l.Run(r)
if r.Data != nil {
t.Error("Expected handler to execute")
}
}

View file

@ -1,89 +0,0 @@
package aws
import (
"fmt"
"reflect"
"strings"
"github.com/aws/aws-sdk-go/aws/awserr"
)
// ValidateParameters is a request handler to validate the input parameters.
// Validating parameters only has meaning if done prior to the request being sent.
func ValidateParameters(r *Request) {
if r.ParamsFilled() {
v := validator{errors: []string{}}
v.validateAny(reflect.ValueOf(r.Params), "")
if count := len(v.errors); count > 0 {
format := "%d validation errors:\n- %s"
msg := fmt.Sprintf(format, count, strings.Join(v.errors, "\n- "))
r.Error = awserr.New("InvalidParameter", msg, nil)
}
}
}
// A validator validates values. Collects validation errors which occur.
type validator struct {
errors []string
}
// validateAny will validate any struct, slice or map type. All validations
// are also performed recursively for nested types.
func (v *validator) validateAny(value reflect.Value, path string) {
value = reflect.Indirect(value)
if !value.IsValid() {
return
}
switch value.Kind() {
case reflect.Struct:
v.validateStruct(value, path)
case reflect.Slice:
for i := 0; i < value.Len(); i++ {
v.validateAny(value.Index(i), path+fmt.Sprintf("[%d]", i))
}
case reflect.Map:
for _, n := range value.MapKeys() {
v.validateAny(value.MapIndex(n), path+fmt.Sprintf("[%q]", n.String()))
}
}
}
// validateStruct will validate the struct value's fields. If the structure has
// nested types those types will be validated also.
func (v *validator) validateStruct(value reflect.Value, path string) {
prefix := "."
if path == "" {
prefix = ""
}
for i := 0; i < value.Type().NumField(); i++ {
f := value.Type().Field(i)
if strings.ToLower(f.Name[0:1]) == f.Name[0:1] {
continue
}
fvalue := value.FieldByName(f.Name)
notset := false
if f.Tag.Get("required") != "" {
switch fvalue.Kind() {
case reflect.Ptr, reflect.Slice, reflect.Map:
if fvalue.IsNil() {
notset = true
}
default:
if !fvalue.IsValid() {
notset = true
}
}
}
if notset {
msg := "missing required parameter: " + path + prefix + f.Name
v.errors = append(v.errors, msg)
} else {
v.validateAny(fvalue, path+prefix+f.Name)
}
}
}
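
The validator above walks a params struct with reflection and records any exported field tagged required:"true" whose pointer, slice, or map value is nil. A minimal self-contained sketch of that check (struct and field names are illustrative):

package main

import (
	"fmt"
	"reflect"
)

type input struct {
	Bucket *string `required:"true"`
	Key    *string
}

// missingRequired returns the names of fields tagged required:"true"
// whose pointer, slice, or map value is nil.
func missingRequired(v interface{}) []string {
	var missing []string
	rv := reflect.Indirect(reflect.ValueOf(v))
	rt := rv.Type()
	for i := 0; i < rt.NumField(); i++ {
		f := rt.Field(i)
		if f.Tag.Get("required") != "true" {
			continue
		}
		switch fv := rv.Field(i); fv.Kind() {
		case reflect.Ptr, reflect.Slice, reflect.Map:
			if fv.IsNil() {
				missing = append(missing, f.Name)
			}
		}
	}
	return missing
}

func main() {
	fmt.Println(missingRequired(&input{})) // [Bucket]
}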

View file

@ -1,84 +0,0 @@
package aws_test
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/stretchr/testify/assert"
)
var service = func() *aws.Service {
s := &aws.Service{
Config: &aws.Config{},
ServiceName: "mock-service",
APIVersion: "2015-01-01",
}
return s
}()
type StructShape struct {
RequiredList []*ConditionalStructShape `required:"true"`
RequiredMap map[string]*ConditionalStructShape `required:"true"`
RequiredBool *bool `required:"true"`
OptionalStruct *ConditionalStructShape
hiddenParameter *string
metadataStructureShape
}
type metadataStructureShape struct {
SDKShapeTraits bool
}
type ConditionalStructShape struct {
Name *string `required:"true"`
SDKShapeTraits bool
}
func TestNoErrors(t *testing.T) {
input := &StructShape{
RequiredList: []*ConditionalStructShape{},
RequiredMap: map[string]*ConditionalStructShape{
"key1": {Name: aws.String("Name")},
"key2": {Name: aws.String("Name")},
},
RequiredBool: aws.Bool(true),
OptionalStruct: &ConditionalStructShape{Name: aws.String("Name")},
}
req := aws.NewRequest(service, &aws.Operation{}, input, nil)
aws.ValidateParameters(req)
assert.NoError(t, req.Error)
}
func TestMissingRequiredParameters(t *testing.T) {
input := &StructShape{}
req := aws.NewRequest(service, &aws.Operation{}, input, nil)
aws.ValidateParameters(req)
assert.Error(t, req.Error)
assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code())
assert.Equal(t, "3 validation errors:\n- missing required parameter: RequiredList\n- missing required parameter: RequiredMap\n- missing required parameter: RequiredBool", req.Error.(awserr.Error).Message())
}
func TestNestedMissingRequiredParameters(t *testing.T) {
input := &StructShape{
RequiredList: []*ConditionalStructShape{{}},
RequiredMap: map[string]*ConditionalStructShape{
"key1": {Name: aws.String("Name")},
"key2": {},
},
RequiredBool: aws.Bool(true),
OptionalStruct: &ConditionalStructShape{},
}
req := aws.NewRequest(service, &aws.Operation{}, input, nil)
aws.ValidateParameters(req)
assert.Error(t, req.Error)
assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code())
assert.Equal(t, "3 validation errors:\n- missing required parameter: RequiredList[0].Name\n- missing required parameter: RequiredMap[\"key2\"].Name\n- missing required parameter: OptionalStruct.Name", req.Error.(awserr.Error).Message())
}

View file

@ -1,312 +0,0 @@
package aws
import (
"bytes"
"io"
"io/ioutil"
"net/http"
"net/url"
"reflect"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws/awsutil"
)
// A Request is the service request to be made.
type Request struct {
*Service
Handlers Handlers
Time time.Time
ExpireTime time.Duration
Operation *Operation
HTTPRequest *http.Request
HTTPResponse *http.Response
Body io.ReadSeeker
bodyStart int64 // offset from beginning of Body that the request body starts
Params interface{}
Error error
Data interface{}
RequestID string
RetryCount uint
Retryable *bool
RetryDelay time.Duration
built bool
}
// An Operation is the service API operation to be made.
type Operation struct {
Name string
HTTPMethod string
HTTPPath string
*Paginator
}
// Paginator keeps track of pagination configuration for an API operation.
type Paginator struct {
InputTokens []string
OutputTokens []string
LimitToken string
TruncationToken string
}
// NewRequest returns a new Request pointer for the service API
// operation and parameters.
//
// Params is any value of input parameters to be the request payload.
// Data is pointer value to an object which the request's response
// payload will be deserialized to.
func NewRequest(service *Service, operation *Operation, params interface{}, data interface{}) *Request {
method := operation.HTTPMethod
if method == "" {
method = "POST"
}
p := operation.HTTPPath
if p == "" {
p = "/"
}
httpReq, _ := http.NewRequest(method, "", nil)
httpReq.URL, _ = url.Parse(service.Endpoint + p)
r := &Request{
Service: service,
Handlers: service.Handlers.copy(),
Time: time.Now(),
ExpireTime: 0,
Operation: operation,
HTTPRequest: httpReq,
Body: nil,
Params: params,
Error: nil,
Data: data,
}
r.SetBufferBody([]byte{})
return r
}
// WillRetry returns if the request can be retried.
func (r *Request) WillRetry() bool {
return r.Error != nil && BoolValue(r.Retryable) && r.RetryCount < r.Service.MaxRetries()
}
// ParamsFilled returns if the request's parameters have been populated
// and the parameters are valid. False is returned if no parameters are
// provided or invalid.
func (r *Request) ParamsFilled() bool {
return r.Params != nil && reflect.ValueOf(r.Params).Elem().IsValid()
}
// DataFilled returns true if the request's data for response deserialization
// target has been set and is valid. False is returned if data is not
// set, or is invalid.
func (r *Request) DataFilled() bool {
return r.Data != nil && reflect.ValueOf(r.Data).Elem().IsValid()
}
// SetBufferBody will set the request's body bytes that will be sent to
// the service API.
func (r *Request) SetBufferBody(buf []byte) {
r.SetReaderBody(bytes.NewReader(buf))
}
// SetStringBody sets the body of the request to be backed by a string.
func (r *Request) SetStringBody(s string) {
r.SetReaderBody(strings.NewReader(s))
}
// SetReaderBody will set the request's body reader.
func (r *Request) SetReaderBody(reader io.ReadSeeker) {
r.HTTPRequest.Body = ioutil.NopCloser(reader)
r.Body = reader
}
// Presign returns the request's signed URL. Error will be returned
// if the signing fails.
func (r *Request) Presign(expireTime time.Duration) (string, error) {
r.ExpireTime = expireTime
r.Sign()
if r.Error != nil {
return "", r.Error
}
return r.HTTPRequest.URL.String(), nil
}
// Build will build the request's object so it can be signed and sent
// to the service. Build will also validate all the request's parameters.
// Any additional build Handlers set on this request will be run
// in the order they were set.
//
// The request will only be built once. Multiple calls to build will have
// no effect.
//
// If any Validate or Build errors occur the build will stop and the error
// which occurred will be returned.
func (r *Request) Build() error {
if !r.built {
r.Error = nil
r.Handlers.Validate.Run(r)
if r.Error != nil {
return r.Error
}
r.Handlers.Build.Run(r)
r.built = true
}
return r.Error
}
// Sign will sign the request, returning an error if errors are encountered.
//
// Send will build the request prior to signing. All Sign Handlers will
// be executed in the order they were set.
func (r *Request) Sign() error {
r.Build()
if r.Error != nil {
return r.Error
}
r.Handlers.Sign.Run(r)
return r.Error
}
// Send will send the request returning error if errors are encountered.
//
// Send will sign the request prior to sending. All Send Handlers will
// be executed in the order they were set.
func (r *Request) Send() error {
for {
r.Sign()
if r.Error != nil {
return r.Error
}
if BoolValue(r.Retryable) {
// Re-seek the body back to the original point for a retry so that
// send will send the body's contents again in the upcoming request.
r.Body.Seek(r.bodyStart, 0)
}
r.Retryable = nil
r.Handlers.Send.Run(r)
if r.Error != nil {
r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil {
return r.Error
}
continue
}
r.Handlers.UnmarshalMeta.Run(r)
r.Handlers.ValidateResponse.Run(r)
if r.Error != nil {
r.Handlers.UnmarshalError.Run(r)
r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil {
return r.Error
}
continue
}
r.Handlers.Unmarshal.Run(r)
if r.Error != nil {
r.Handlers.Retry.Run(r)
r.Handlers.AfterRetry.Run(r)
if r.Error != nil {
return r.Error
}
continue
}
break
}
return nil
}
// HasNextPage returns true if this request has more pages of data available.
func (r *Request) HasNextPage() bool {
return r.nextPageTokens() != nil
}
// nextPageTokens returns the tokens to use when asking for the next page of
// data.
func (r *Request) nextPageTokens() []interface{} {
if r.Operation.Paginator == nil {
return nil
}
if r.Operation.TruncationToken != "" {
tr := awsutil.ValuesAtAnyPath(r.Data, r.Operation.TruncationToken)
if tr == nil || len(tr) == 0 {
return nil
}
switch v := tr[0].(type) {
case bool:
if v == false {
return nil
}
}
}
found := false
tokens := make([]interface{}, len(r.Operation.OutputTokens))
for i, outtok := range r.Operation.OutputTokens {
v := awsutil.ValuesAtAnyPath(r.Data, outtok)
if v != nil && len(v) > 0 {
found = true
tokens[i] = v[0]
}
}
if found {
return tokens
}
return nil
}
// NextPage returns a new Request that can be executed to return the next
// page of result data. Call .Send() on this request to execute it.
func (r *Request) NextPage() *Request {
tokens := r.nextPageTokens()
if tokens == nil {
return nil
}
data := reflect.New(reflect.TypeOf(r.Data).Elem()).Interface()
nr := NewRequest(r.Service, r.Operation, awsutil.CopyOf(r.Params), data)
for i, intok := range nr.Operation.InputTokens {
awsutil.SetValueAtAnyPath(nr.Params, intok, tokens[i])
}
return nr
}
// EachPage iterates over each page of a paginated request object. The fn
// parameter should be a function with the following sample signature:
//
// func(page *T, lastPage bool) bool {
// return true // return false to stop iterating
// }
//
// Where "T" is the structure type matching the output structure of the given
// operation. For example, a request object generated by
// DynamoDB.ListTablesRequest() would expect to see dynamodb.ListTablesOutput
// as the structure "T". The lastPage value represents whether the page is
// the last page of data or not. The return value of this function should
// return true to keep iterating or false to stop.
func (r *Request) EachPage(fn func(data interface{}, isLastPage bool) (shouldContinue bool)) error {
for page := r; page != nil; page = page.NextPage() {
page.Send()
shouldContinue := fn(page.Data, !page.HasNextPage())
if page.Error != nil || !shouldContinue {
return page.Error
}
}
return nil
}
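
EachPage above is a loop over NextPage that hands every page to the callback together with a last-page flag and stops as soon as the callback returns false. A self-contained sketch of that contract against a fake slice of pages (the types are stand-ins, not the SDK output shapes):

package main

import "fmt"

// page stands in for one service response; an empty NextToken marks the
// final page.
type page struct {
	Items     []string
	NextToken string
}

// eachPage mirrors Request.EachPage: invoke fn with every page and a
// lastPage flag, stopping early when fn returns false.
func eachPage(all []page, fn func(p page, lastPage bool) bool) {
	for _, p := range all {
		last := p.NextToken == ""
		if !fn(p, last) || last {
			return
		}
	}
}

func main() {
	resps := []page{
		{Items: []string{"Table1", "Table2"}, NextToken: "Table2"},
		{Items: []string{"Table3"}},
	}
	eachPage(resps, func(p page, lastPage bool) bool {
		fmt.Println(p.Items, "last page:", lastPage)
		return true // return false here to stop iterating, as in the SDK callback
	})
}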

View file

@ -1,305 +0,0 @@
package aws_test
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/test/unit"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
)
var _ = unit.Imported
// Use DynamoDB methods for simplicity
func TestPagination(t *testing.T) {
db := dynamodb.New(nil)
tokens, pages, numPages, gotToEnd := []string{}, []string{}, 0, false
reqNum := 0
resps := []*dynamodb.ListTablesOutput{
{TableNames: []*string{aws.String("Table1"), aws.String("Table2")}, LastEvaluatedTableName: aws.String("Table2")},
{TableNames: []*string{aws.String("Table3"), aws.String("Table4")}, LastEvaluatedTableName: aws.String("Table4")},
{TableNames: []*string{aws.String("Table5")}},
}
db.Handlers.Send.Clear() // mock sending
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
db.Handlers.Build.PushBack(func(r *aws.Request) {
in := r.Params.(*dynamodb.ListTablesInput)
if in == nil {
tokens = append(tokens, "")
} else if in.ExclusiveStartTableName != nil {
tokens = append(tokens, *in.ExclusiveStartTableName)
}
})
db.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
r.Data = resps[reqNum]
reqNum++
})
params := &dynamodb.ListTablesInput{Limit: aws.Int64(2)}
err := db.ListTablesPages(params, func(p *dynamodb.ListTablesOutput, last bool) bool {
numPages++
for _, t := range p.TableNames {
pages = append(pages, *t)
}
if last {
if gotToEnd {
assert.Fail(t, "last=true happened twice")
}
gotToEnd = true
}
return true
})
assert.Equal(t, []string{"Table2", "Table4"}, tokens)
assert.Equal(t, []string{"Table1", "Table2", "Table3", "Table4", "Table5"}, pages)
assert.Equal(t, 3, numPages)
assert.True(t, gotToEnd)
assert.Nil(t, err)
assert.Nil(t, params.ExclusiveStartTableName)
}
// Use DynamoDB methods for simplicity
func TestPaginationEachPage(t *testing.T) {
db := dynamodb.New(nil)
tokens, pages, numPages, gotToEnd := []string{}, []string{}, 0, false
reqNum := 0
resps := []*dynamodb.ListTablesOutput{
{TableNames: []*string{aws.String("Table1"), aws.String("Table2")}, LastEvaluatedTableName: aws.String("Table2")},
{TableNames: []*string{aws.String("Table3"), aws.String("Table4")}, LastEvaluatedTableName: aws.String("Table4")},
{TableNames: []*string{aws.String("Table5")}},
}
db.Handlers.Send.Clear() // mock sending
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
db.Handlers.Build.PushBack(func(r *aws.Request) {
in := r.Params.(*dynamodb.ListTablesInput)
if in == nil {
tokens = append(tokens, "")
} else if in.ExclusiveStartTableName != nil {
tokens = append(tokens, *in.ExclusiveStartTableName)
}
})
db.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
r.Data = resps[reqNum]
reqNum++
})
params := &dynamodb.ListTablesInput{Limit: aws.Int64(2)}
req, _ := db.ListTablesRequest(params)
err := req.EachPage(func(p interface{}, last bool) bool {
numPages++
for _, t := range p.(*dynamodb.ListTablesOutput).TableNames {
pages = append(pages, *t)
}
if last {
if gotToEnd {
assert.Fail(t, "last=true happened twice")
}
gotToEnd = true
}
return true
})
assert.Equal(t, []string{"Table2", "Table4"}, tokens)
assert.Equal(t, []string{"Table1", "Table2", "Table3", "Table4", "Table5"}, pages)
assert.Equal(t, 3, numPages)
assert.True(t, gotToEnd)
assert.Nil(t, err)
}
// Use DynamoDB methods for simplicity
func TestPaginationEarlyExit(t *testing.T) {
db := dynamodb.New(nil)
numPages, gotToEnd := 0, false
reqNum := 0
resps := []*dynamodb.ListTablesOutput{
{TableNames: []*string{aws.String("Table1"), aws.String("Table2")}, LastEvaluatedTableName: aws.String("Table2")},
{TableNames: []*string{aws.String("Table3"), aws.String("Table4")}, LastEvaluatedTableName: aws.String("Table4")},
{TableNames: []*string{aws.String("Table5")}},
}
db.Handlers.Send.Clear() // mock sending
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
db.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
r.Data = resps[reqNum]
reqNum++
})
params := &dynamodb.ListTablesInput{Limit: aws.Int64(2)}
err := db.ListTablesPages(params, func(p *dynamodb.ListTablesOutput, last bool) bool {
numPages++
if numPages == 2 {
return false
}
if last {
if gotToEnd {
assert.Fail(t, "last=true happened twice")
}
gotToEnd = true
}
return true
})
assert.Equal(t, 2, numPages)
assert.False(t, gotToEnd)
assert.Nil(t, err)
}
func TestSkipPagination(t *testing.T) {
client := s3.New(nil)
client.Handlers.Send.Clear() // mock sending
client.Handlers.Unmarshal.Clear()
client.Handlers.UnmarshalMeta.Clear()
client.Handlers.ValidateResponse.Clear()
client.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
r.Data = &s3.HeadBucketOutput{}
})
req, _ := client.HeadBucketRequest(&s3.HeadBucketInput{Bucket: aws.String("bucket")})
numPages, gotToEnd := 0, false
req.EachPage(func(p interface{}, last bool) bool {
numPages++
if last {
gotToEnd = true
}
return true
})
assert.Equal(t, 1, numPages)
assert.True(t, gotToEnd)
}
// Use S3 for simplicity
func TestPaginationTruncation(t *testing.T) {
count := 0
client := s3.New(nil)
reqNum := &count
resps := []*s3.ListObjectsOutput{
{IsTruncated: aws.Bool(true), Contents: []*s3.Object{{Key: aws.String("Key1")}}},
{IsTruncated: aws.Bool(true), Contents: []*s3.Object{{Key: aws.String("Key2")}}},
{IsTruncated: aws.Bool(false), Contents: []*s3.Object{{Key: aws.String("Key3")}}},
{IsTruncated: aws.Bool(true), Contents: []*s3.Object{{Key: aws.String("Key4")}}},
}
client.Handlers.Send.Clear() // mock sending
client.Handlers.Unmarshal.Clear()
client.Handlers.UnmarshalMeta.Clear()
client.Handlers.ValidateResponse.Clear()
client.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
r.Data = resps[*reqNum]
*reqNum++
})
params := &s3.ListObjectsInput{Bucket: aws.String("bucket")}
results := []string{}
err := client.ListObjectsPages(params, func(p *s3.ListObjectsOutput, last bool) bool {
results = append(results, *p.Contents[0].Key)
return true
})
assert.Equal(t, []string{"Key1", "Key2", "Key3"}, results)
assert.Nil(t, err)
// Try again without truncation token at all
count = 0
resps[1].IsTruncated = nil
resps[2].IsTruncated = aws.Bool(true)
results = []string{}
err = client.ListObjectsPages(params, func(p *s3.ListObjectsOutput, last bool) bool {
results = append(results, *p.Contents[0].Key)
return true
})
assert.Equal(t, []string{"Key1", "Key2"}, results)
assert.Nil(t, err)
}
// Benchmarks
var benchResps = []*dynamodb.ListTablesOutput{
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE")}},
}
var benchDb = func() *dynamodb.DynamoDB {
db := dynamodb.New(nil)
db.Handlers.Send.Clear() // mock sending
db.Handlers.Unmarshal.Clear()
db.Handlers.UnmarshalMeta.Clear()
db.Handlers.ValidateResponse.Clear()
return db
}
func BenchmarkCodegenIterator(b *testing.B) {
reqNum := 0
db := benchDb()
db.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
r.Data = benchResps[reqNum]
reqNum++
})
input := &dynamodb.ListTablesInput{Limit: aws.Int64(2)}
iter := func(fn func(*dynamodb.ListTablesOutput, bool) bool) error {
page, _ := db.ListTablesRequest(input)
for ; page != nil; page = page.NextPage() {
page.Send()
out := page.Data.(*dynamodb.ListTablesOutput)
if result := fn(out, !page.HasNextPage()); page.Error != nil || !result {
return page.Error
}
}
return nil
}
for i := 0; i < b.N; i++ {
reqNum = 0
iter(func(p *dynamodb.ListTablesOutput, last bool) bool {
return true
})
}
}
func BenchmarkEachPageIterator(b *testing.B) {
reqNum := 0
db := benchDb()
db.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
r.Data = benchResps[reqNum]
reqNum++
})
input := &dynamodb.ListTablesInput{Limit: aws.Int64(2)}
for i := 0; i < b.N; i++ {
reqNum = 0
req, _ := db.ListTablesRequest(input)
req.EachPage(func(p interface{}, last bool) bool {
return true
})
}
}

View file

@ -1,225 +0,0 @@
package aws
import (
"bytes"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/stretchr/testify/assert"
)
type testData struct {
Data string
}
func body(str string) io.ReadCloser {
return ioutil.NopCloser(bytes.NewReader([]byte(str)))
}
func unmarshal(req *Request) {
defer req.HTTPResponse.Body.Close()
if req.Data != nil {
json.NewDecoder(req.HTTPResponse.Body).Decode(req.Data)
}
return
}
func unmarshalError(req *Request) {
bodyBytes, err := ioutil.ReadAll(req.HTTPResponse.Body)
if err != nil {
req.Error = awserr.New("UnmarshaleError", req.HTTPResponse.Status, err)
return
}
if len(bodyBytes) == 0 {
req.Error = awserr.NewRequestFailure(
awserr.New("UnmarshaleError", req.HTTPResponse.Status, fmt.Errorf("empty body")),
req.HTTPResponse.StatusCode,
"",
)
return
}
var jsonErr jsonErrorResponse
if err := json.Unmarshal(bodyBytes, &jsonErr); err != nil {
req.Error = awserr.New("UnmarshaleError", "JSON unmarshal", err)
return
}
req.Error = awserr.NewRequestFailure(
awserr.New(jsonErr.Code, jsonErr.Message, nil),
req.HTTPResponse.StatusCode,
"",
)
}
type jsonErrorResponse struct {
Code string `json:"__type"`
Message string `json:"message"`
}
// test that retries occur for 5xx status codes
func TestRequestRecoverRetry5xx(t *testing.T) {
reqNum := 0
reqs := []http.Response{
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 501, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
s := NewService(NewConfig().WithMaxRetries(10))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
out := &testData{}
r := NewRequest(s, &Operation{Name: "Operation"}, nil, out)
err := r.Send()
assert.Nil(t, err)
assert.Equal(t, 2, int(r.RetryCount))
assert.Equal(t, "valid", out.Data)
}
// test that retries occur for 4xx status codes with a response type that can be retried - see `shouldRetry`
func TestRequestRecoverRetry4xxRetryable(t *testing.T) {
reqNum := 0
reqs := []http.Response{
{StatusCode: 400, Body: body(`{"__type":"Throttling","message":"Rate exceeded."}`)},
{StatusCode: 429, Body: body(`{"__type":"ProvisionedThroughputExceededException","message":"Rate exceeded."}`)},
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
s := NewService(NewConfig().WithMaxRetries(10))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
out := &testData{}
r := NewRequest(s, &Operation{Name: "Operation"}, nil, out)
err := r.Send()
assert.Nil(t, err)
assert.Equal(t, 2, int(r.RetryCount))
assert.Equal(t, "valid", out.Data)
}
// test that retries don't occur for 4xx status codes with a response type that can't be retried
func TestRequest4xxUnretryable(t *testing.T) {
s := NewService(NewConfig().WithMaxRetries(10))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
r.HTTPResponse = &http.Response{StatusCode: 401, Body: body(`{"__type":"SignatureDoesNotMatch","message":"Signature does not match."}`)}
})
out := &testData{}
r := NewRequest(s, &Operation{Name: "Operation"}, nil, out)
err := r.Send()
assert.NotNil(t, err)
if e, ok := err.(awserr.RequestFailure); ok {
assert.Equal(t, 401, e.StatusCode())
} else {
assert.Fail(t, "Expected error to be a service failure")
}
assert.Equal(t, "SignatureDoesNotMatch", err.(awserr.Error).Code())
assert.Equal(t, "Signature does not match.", err.(awserr.Error).Message())
assert.Equal(t, 0, int(r.RetryCount))
}
func TestRequestExhaustRetries(t *testing.T) {
delays := []time.Duration{}
sleepDelay = func(delay time.Duration) {
delays = append(delays, delay)
}
reqNum := 0
reqs := []http.Response{
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
}
s := NewService(NewConfig().WithMaxRetries(DefaultRetries))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
r := NewRequest(s, &Operation{Name: "Operation"}, nil, nil)
err := r.Send()
assert.NotNil(t, err)
if e, ok := err.(awserr.RequestFailure); ok {
assert.Equal(t, 500, e.StatusCode())
} else {
assert.Fail(t, "Expected error to be a service failure")
}
assert.Equal(t, "UnknownError", err.(awserr.Error).Code())
assert.Equal(t, "An error occurred.", err.(awserr.Error).Message())
assert.Equal(t, 3, int(r.RetryCount))
expectDelays := []struct{ min, max time.Duration }{{30, 59}, {60, 118}, {120, 236}}
for i, v := range delays {
min := expectDelays[i].min * time.Millisecond
max := expectDelays[i].max * time.Millisecond
assert.True(t, min <= v && v <= max,
"Expect delay to be within range, i:%d, v:%s, min:%s, max:%s", i, v, min, max)
}
}
// test that the request is retried after the credentials are expired.
func TestRequestRecoverExpiredCreds(t *testing.T) {
reqNum := 0
reqs := []http.Response{
{StatusCode: 400, Body: body(`{"__type":"ExpiredTokenException","message":"expired token"}`)},
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
s := NewService(&Config{MaxRetries: Int(10), Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "")})
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.PushBack(unmarshal)
s.Handlers.UnmarshalError.PushBack(unmarshalError)
credExpiredBeforeRetry := false
credExpiredAfterRetry := false
s.Handlers.AfterRetry.PushBack(func(r *Request) {
credExpiredAfterRetry = r.Config.Credentials.IsExpired()
})
s.Handlers.Sign.Clear()
s.Handlers.Sign.PushBack(func(r *Request) {
r.Config.Credentials.Get()
})
s.Handlers.Send.Clear() // mock sending
s.Handlers.Send.PushBack(func(r *Request) {
r.HTTPResponse = &reqs[reqNum]
reqNum++
})
out := &testData{}
r := NewRequest(s, &Operation{Name: "Operation"}, nil, out)
err := r.Send()
assert.Nil(t, err)
assert.False(t, credExpiredBeforeRetry, "Expect valid creds before retry check")
assert.True(t, credExpiredAfterRetry, "Expect expired creds after retry check")
assert.False(t, s.Config.Credentials.IsExpired(), "Expect valid creds after cred expired recovery")
assert.Equal(t, 1, int(r.RetryCount))
assert.Equal(t, "valid", out.Data)
}

View file

@ -1,194 +0,0 @@
package aws
import (
"fmt"
"math"
"math/rand"
"net/http"
"net/http/httputil"
"regexp"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/endpoints"
)
// A Service implements the base service request and response handling
// used by all services.
type Service struct {
Config *Config
Handlers Handlers
ServiceName string
APIVersion string
Endpoint string
SigningName string
SigningRegion string
JSONVersion string
TargetPrefix string
RetryRules func(*Request) time.Duration
ShouldRetry func(*Request) bool
DefaultMaxRetries uint
}
var schemeRE = regexp.MustCompile("^([^:]+)://")
// NewService will return a pointer to a new, initialized Service object.
func NewService(config *Config) *Service {
svc := &Service{Config: config}
svc.Initialize()
return svc
}
// Initialize initializes the service.
func (s *Service) Initialize() {
if s.Config == nil {
s.Config = &Config{}
}
if s.Config.HTTPClient == nil {
s.Config.HTTPClient = http.DefaultClient
}
if s.RetryRules == nil {
s.RetryRules = retryRules
}
if s.ShouldRetry == nil {
s.ShouldRetry = shouldRetry
}
s.DefaultMaxRetries = 3
s.Handlers.Validate.PushBack(ValidateEndpointHandler)
s.Handlers.Build.PushBack(UserAgentHandler)
s.Handlers.Sign.PushBack(BuildContentLength)
s.Handlers.Send.PushBack(SendHandler)
s.Handlers.AfterRetry.PushBack(AfterRetryHandler)
s.Handlers.ValidateResponse.PushBack(ValidateResponseHandler)
s.AddDebugHandlers()
s.buildEndpoint()
if !BoolValue(s.Config.DisableParamValidation) {
s.Handlers.Validate.PushBack(ValidateParameters)
}
}
// buildEndpoint builds the endpoint values the service will use to make requests with.
func (s *Service) buildEndpoint() {
if StringValue(s.Config.Endpoint) != "" {
s.Endpoint = *s.Config.Endpoint
} else {
s.Endpoint, s.SigningRegion =
endpoints.EndpointForRegion(s.ServiceName, StringValue(s.Config.Region))
}
if s.Endpoint != "" && !schemeRE.MatchString(s.Endpoint) {
scheme := "https"
if BoolValue(s.Config.DisableSSL) {
scheme = "http"
}
s.Endpoint = scheme + "://" + s.Endpoint
}
}
// AddDebugHandlers injects debug logging handlers into the service to log request
// debug information.
func (s *Service) AddDebugHandlers() {
if !s.Config.LogLevel.AtLeast(LogDebug) {
return
}
s.Handlers.Send.PushFront(logRequest)
s.Handlers.Send.PushBack(logResponse)
}
const logReqMsg = `DEBUG: Request %s/%s Details:
---[ REQUEST POST-SIGN ]-----------------------------
%s
-----------------------------------------------------`
func logRequest(r *Request) {
logBody := r.Config.LogLevel.Matches(LogDebugWithHTTPBody)
dumpedBody, _ := httputil.DumpRequestOut(r.HTTPRequest, logBody)
r.Config.Logger.Log(fmt.Sprintf(logReqMsg, r.ServiceName, r.Operation.Name, string(dumpedBody)))
}
const logRespMsg = `DEBUG: Response %s/%s Details:
---[ RESPONSE ]--------------------------------------
%s
-----------------------------------------------------`
func logResponse(r *Request) {
var msg = "no reponse data"
if r.HTTPResponse != nil {
logBody := r.Config.LogLevel.Matches(LogDebugWithHTTPBody)
dumpedBody, _ := httputil.DumpResponse(r.HTTPResponse, logBody)
msg = string(dumpedBody)
} else if r.Error != nil {
msg = r.Error.Error()
}
r.Config.Logger.Log(fmt.Sprintf(logRespMsg, r.ServiceName, r.Operation.Name, msg))
}
// MaxRetries returns the maximum number of retries the service will make for
// an individual API request.
func (s *Service) MaxRetries() uint {
if IntValue(s.Config.MaxRetries) < 0 {
return s.DefaultMaxRetries
}
return uint(IntValue(s.Config.MaxRetries))
}
var seededRand = rand.New(rand.NewSource(time.Now().UnixNano()))
// retryRules returns the delay duration before retrying this request again
func retryRules(r *Request) time.Duration {
delay := int(math.Pow(2, float64(r.RetryCount))) * (seededRand.Intn(30) + 30)
return time.Duration(delay) * time.Millisecond
}
// retryableCodes is a collection of service response codes which are retry-able
// without any further action.
var retryableCodes = map[string]struct{}{
"RequestError": {},
"ProvisionedThroughputExceededException": {},
"Throttling": {},
"ThrottlingException": {},
"RequestLimitExceeded": {},
"RequestThrottled": {},
}
// credsExpiredCodes is a collection of error codes which signify the credentials
// need to be refreshed. Expired tokens require refreshing of credentials and
// re-signing of the request before it can be retried.
var credsExpiredCodes = map[string]struct{}{
"ExpiredToken": {},
"ExpiredTokenException": {},
"RequestExpired": {}, // EC2 Only
}
func isCodeRetryable(code string) bool {
if _, ok := retryableCodes[code]; ok {
return true
}
return isCodeExpiredCreds(code)
}
func isCodeExpiredCreds(code string) bool {
_, ok := credsExpiredCodes[code]
return ok
}
// shouldRetry returns if the request should be retried.
func shouldRetry(r *Request) bool {
if r.HTTPResponse.StatusCode >= 500 {
return true
}
if r.Error != nil {
if err, ok := r.Error.(awserr.Error); ok {
return isCodeRetryable(err.Code())
}
}
return false
}
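
retryRules above computes the backoff as 2^RetryCount multiplied by a random 30-59 millisecond factor, which yields exactly the 30-59ms, 60-118ms and 120-236ms windows asserted in TestRequestExhaustRetries earlier in this diff. A standalone sketch of the same formula:

package main

import (
	"fmt"
	"math/rand"
	"time"
)

// backoff mirrors retryRules: exponential in the retry count, scaled by
// a random factor between 30 and 59 milliseconds.
func backoff(retryCount uint, rng *rand.Rand) time.Duration {
	delay := (1 << retryCount) * (rng.Intn(30) + 30)
	return time.Duration(delay) * time.Millisecond
}

func main() {
	rng := rand.New(rand.NewSource(1))
	for retry := uint(0); retry < 3; retry++ {
		fmt.Printf("retry %d: %v (window %dms-%dms)\n",
			retry, backoff(retry, rng), 30*(1<<retry), 59*(1<<retry))
	}
}

The random factor keeps many concurrent clients from retrying in lockstep against an already struggling service.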

View file

@ -1,55 +0,0 @@
package aws
import (
"io"
)
// ReadSeekCloser wraps an io.Reader, returning a ReaderSeekerCloser
func ReadSeekCloser(r io.Reader) ReaderSeekerCloser {
return ReaderSeekerCloser{r}
}
// ReaderSeekerCloser represents a reader that can also delegate io.Seeker and
// io.Closer interfaces to the underlying object if they are available.
type ReaderSeekerCloser struct {
r io.Reader
}
// Read reads from the reader up to the size of p. The number of bytes read
// and any error that occurred will be returned.
//
// If the reader is not an io.Reader, zero bytes read and a nil error will be returned.
//
// Performs the same functionality as io.Reader Read
func (r ReaderSeekerCloser) Read(p []byte) (int, error) {
switch t := r.r.(type) {
case io.Reader:
return t.Read(p)
}
return 0, nil
}
// Seek sets the offset for the next Read to offset, interpreted according to
// whence: 0 means relative to the origin of the file, 1 means relative to the
// current offset, and 2 means relative to the end. Seek returns the new offset
// and an error, if any.
//
// If the ReaderSeekerCloser is not an io.Seeker nothing will be done.
func (r ReaderSeekerCloser) Seek(offset int64, whence int) (int64, error) {
switch t := r.r.(type) {
case io.Seeker:
return t.Seek(offset, whence)
}
return int64(0), nil
}
// Close closes the ReaderSeekerCloser.
//
// If the ReaderSeekerCloser is not an io.Closer nothing will be done.
func (r ReaderSeekerCloser) Close() error {
switch t := r.r.(type) {
case io.Closer:
return t.Close()
}
return nil
}
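
ReaderSeekerCloser above forwards Seek and Close only when the wrapped reader actually implements io.Seeker or io.Closer, and silently no-ops otherwise. A compact standalone illustration of that optional-interface check (the helper name is an assumption, not SDK API):

package main

import (
	"bytes"
	"fmt"
	"io"
	"strings"
)

// seekIfPossible forwards to Seek only when the reader supports it,
// mirroring ReaderSeekerCloser.Seek; otherwise it is a no-op.
func seekIfPossible(r io.Reader, offset int64, whence int) (int64, error) {
	if s, ok := r.(io.Seeker); ok {
		return s.Seek(offset, whence)
	}
	return 0, nil
}

func main() {
	seekable := bytes.NewReader([]byte("abcdef"))
	plain := io.LimitReader(strings.NewReader("abcdef"), 3) // not an io.Seeker

	n, _ := seekIfPossible(seekable, 2, io.SeekStart)
	fmt.Println("seekable reader offset:", n) // 2

	n, _ = seekIfPossible(plain, 2, io.SeekStart)
	fmt.Println("plain reader offset:", n) // 0 (no-op)
}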

View file

@ -1,31 +0,0 @@
// Package endpoints validates regional endpoints for services.
package endpoints
//go:generate go run ../model/cli/gen-endpoints/main.go endpoints.json endpoints_map.go
//go:generate gofmt -s -w endpoints_map.go
import "strings"
// EndpointForRegion returns an endpoint and its signing region for a service and region.
// If the service and region pair is not found, endpoint and signingRegion will be empty.
func EndpointForRegion(svcName, region string) (endpoint, signingRegion string) {
derivedKeys := []string{
region + "/" + svcName,
region + "/*",
"*/" + svcName,
"*/*",
}
for _, key := range derivedKeys {
if val, ok := endpointsMap.Endpoints[key]; ok {
ep := val.Endpoint
ep = strings.Replace(ep, "{region}", region, -1)
ep = strings.Replace(ep, "{service}", svcName, -1)
endpoint = ep
signingRegion = val.SigningRegion
return
}
}
return
}
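
EndpointForRegion above tries increasingly generic keys - region/service, region/*, */service, */* - and substitutes {region} and {service} into the first match. A self-contained sketch of that lookup over a trimmed-down copy of the endpoints table that follows:

package main

import (
	"fmt"
	"strings"
)

// endpointsTable is a small excerpt of the generated endpoints map.
var endpointsTable = map[string]string{
	"*/*":          "{service}.{region}.amazonaws.com",
	"*/iam":        "iam.amazonaws.com",
	"us-east-1/s3": "s3.amazonaws.com",
}

// endpointFor mirrors EndpointForRegion's most-specific-first lookup.
func endpointFor(service, region string) string {
	for _, key := range []string{
		region + "/" + service,
		region + "/*",
		"*/" + service,
		"*/*",
	} {
		if ep, ok := endpointsTable[key]; ok {
			ep = strings.Replace(ep, "{region}", region, -1)
			return strings.Replace(ep, "{service}", service, -1)
		}
	}
	return ""
}

func main() {
	fmt.Println(endpointFor("s3", "us-east-1"))  // s3.amazonaws.com
	fmt.Println(endpointFor("iam", "eu-west-1")) // iam.amazonaws.com
	fmt.Println(endpointFor("sqs", "eu-west-1")) // sqs.eu-west-1.amazonaws.com
}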

View file

@ -1,77 +0,0 @@
{
"version": 2,
"endpoints": {
"*/*": {
"endpoint": "{service}.{region}.amazonaws.com"
},
"cn-north-1/*": {
"endpoint": "{service}.{region}.amazonaws.com.cn",
"signatureVersion": "v4"
},
"us-gov-west-1/iam": {
"endpoint": "iam.us-gov.amazonaws.com"
},
"us-gov-west-1/sts": {
"endpoint": "sts.us-gov-west-1.amazonaws.com"
},
"us-gov-west-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"*/cloudfront": {
"endpoint": "cloudfront.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/cloudsearchdomain": {
"endpoint": "",
"signingRegion": "us-east-1"
},
"*/iam": {
"endpoint": "iam.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/importexport": {
"endpoint": "importexport.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/route53": {
"endpoint": "route53.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/sts": {
"endpoint": "sts.amazonaws.com",
"signingRegion": "us-east-1"
},
"us-east-1/sdb": {
"endpoint": "sdb.amazonaws.com",
"signingRegion": "us-east-1"
},
"us-east-1/s3": {
"endpoint": "s3.amazonaws.com"
},
"us-west-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"us-west-2/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"eu-west-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"ap-southeast-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"ap-southeast-2/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"ap-northeast-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"sa-east-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"eu-central-1/s3": {
"endpoint": "{service}.{region}.amazonaws.com",
"signatureVersion": "v4"
}
}
}

View file

@ -1,89 +0,0 @@
package endpoints
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
type endpointStruct struct {
Version int
Endpoints map[string]endpointEntry
}
type endpointEntry struct {
Endpoint string
SigningRegion string
}
var endpointsMap = endpointStruct{
Version: 2,
Endpoints: map[string]endpointEntry{
"*/*": {
Endpoint: "{service}.{region}.amazonaws.com",
},
"*/cloudfront": {
Endpoint: "cloudfront.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/cloudsearchdomain": {
Endpoint: "",
SigningRegion: "us-east-1",
},
"*/iam": {
Endpoint: "iam.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/importexport": {
Endpoint: "importexport.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/route53": {
Endpoint: "route53.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/sts": {
Endpoint: "sts.amazonaws.com",
SigningRegion: "us-east-1",
},
"ap-northeast-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"ap-southeast-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"ap-southeast-2/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"cn-north-1/*": {
Endpoint: "{service}.{region}.amazonaws.com.cn",
},
"eu-central-1/s3": {
Endpoint: "{service}.{region}.amazonaws.com",
},
"eu-west-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"sa-east-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"us-east-1/s3": {
Endpoint: "s3.amazonaws.com",
},
"us-east-1/sdb": {
Endpoint: "sdb.amazonaws.com",
SigningRegion: "us-east-1",
},
"us-gov-west-1/iam": {
Endpoint: "iam.us-gov.amazonaws.com",
},
"us-gov-west-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"us-gov-west-1/sts": {
Endpoint: "sts.us-gov-west-1.amazonaws.com",
},
"us-west-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"us-west-2/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
},
}

View file

@ -1,28 +0,0 @@
package endpoints
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestGlobalEndpoints(t *testing.T) {
region := "mock-region-1"
svcs := []string{"cloudfront", "iam", "importexport", "route53", "sts"}
for _, name := range svcs {
ep, sr := EndpointForRegion(name, region)
assert.Equal(t, name+".amazonaws.com", ep)
assert.Equal(t, "us-east-1", sr)
}
}
func TestServicesInCN(t *testing.T) {
region := "cn-north-1"
svcs := []string{"cloudfront", "iam", "importexport", "route53", "sts", "s3"}
for _, name := range svcs {
ep, _ := EndpointForRegion(name, region)
assert.Equal(t, name+"."+region+".amazonaws.com.cn", ep)
}
}

File diff suppressed because it is too large

View file

@ -1,29 +0,0 @@
package query
//go:generate go run ../../fixtures/protocol/generate.go ../../fixtures/protocol/output/query.json unmarshal_test.go
import (
"encoding/xml"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/protocol/xml/xmlutil"
)
// Unmarshal unmarshals a response for an AWS Query service.
func Unmarshal(r *aws.Request) {
defer r.HTTPResponse.Body.Close()
if r.DataFilled() {
decoder := xml.NewDecoder(r.HTTPResponse.Body)
err := xmlutil.UnmarshalXML(r.Data, decoder, r.Operation.Name+"Result")
if err != nil {
r.Error = awserr.New("SerializationError", "failed decoding Query response", err)
return
}
}
}
// UnmarshalMeta unmarshals header response values for an AWS Query service.
func UnmarshalMeta(r *aws.Request) {
// TODO implement unmarshaling of request IDs
}

View file

@ -1,33 +0,0 @@
package query
import (
"encoding/xml"
"io"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
)
type xmlErrorResponse struct {
XMLName xml.Name `xml:"ErrorResponse"`
Code string `xml:"Error>Code"`
Message string `xml:"Error>Message"`
RequestID string `xml:"RequestId"`
}
// UnmarshalError unmarshals an error response for an AWS Query service.
func UnmarshalError(r *aws.Request) {
defer r.HTTPResponse.Body.Close()
resp := &xmlErrorResponse{}
err := xml.NewDecoder(r.HTTPResponse.Body).Decode(resp)
if err != nil && err != io.EOF {
r.Error = awserr.New("SerializationError", "failed to decode query XML error response", err)
} else {
r.Error = awserr.NewRequestFailure(
awserr.New(resp.Code, resp.Message, nil),
r.HTTPResponse.StatusCode,
resp.RequestID,
)
}
}

File diff suppressed because it is too large

View file

@ -1,208 +0,0 @@
// Package rest provides RESTful serialisation of AWS requests and responses.
package rest
import (
"bytes"
"encoding/base64"
"fmt"
"io"
"net/url"
"reflect"
"strconv"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
)
// RFC822 is the RFC822 timestamp format used by AWS protocols
const RFC822 = "Mon, 2 Jan 2006 15:04:05 GMT"
// Whether the byte value can be sent without escaping in AWS URLs
var noEscape [256]bool
func init() {
for i := 0; i < len(noEscape); i++ {
// AWS expects every character except these to be escaped
noEscape[i] = (i >= 'A' && i <= 'Z') ||
(i >= 'a' && i <= 'z') ||
(i >= '0' && i <= '9') ||
i == '-' ||
i == '.' ||
i == '_' ||
i == '~'
}
}
// Build builds the REST component of a service request.
func Build(r *aws.Request) {
if r.ParamsFilled() {
v := reflect.ValueOf(r.Params).Elem()
buildLocationElements(r, v)
buildBody(r, v)
}
}
func buildLocationElements(r *aws.Request, v reflect.Value) {
query := r.HTTPRequest.URL.Query()
for i := 0; i < v.NumField(); i++ {
m := v.Field(i)
if n := v.Type().Field(i).Name; n[0:1] == strings.ToLower(n[0:1]) {
continue
}
if m.IsValid() {
field := v.Type().Field(i)
name := field.Tag.Get("locationName")
if name == "" {
name = field.Name
}
if m.Kind() == reflect.Ptr {
m = m.Elem()
}
if !m.IsValid() {
continue
}
switch field.Tag.Get("location") {
case "headers": // header maps
buildHeaderMap(r, m, field.Tag.Get("locationName"))
case "header":
buildHeader(r, m, name)
case "uri":
buildURI(r, m, name)
case "querystring":
buildQueryString(r, m, name, query)
}
}
if r.Error != nil {
return
}
}
r.HTTPRequest.URL.RawQuery = query.Encode()
updatePath(r.HTTPRequest.URL, r.HTTPRequest.URL.Path)
}
func buildBody(r *aws.Request, v reflect.Value) {
if field, ok := v.Type().FieldByName("SDKShapeTraits"); ok {
if payloadName := field.Tag.Get("payload"); payloadName != "" {
pfield, _ := v.Type().FieldByName(payloadName)
if ptag := pfield.Tag.Get("type"); ptag != "" && ptag != "structure" {
payload := reflect.Indirect(v.FieldByName(payloadName))
if payload.IsValid() && payload.Interface() != nil {
switch reader := payload.Interface().(type) {
case io.ReadSeeker:
r.SetReaderBody(reader)
case []byte:
r.SetBufferBody(reader)
case string:
r.SetStringBody(reader)
default:
r.Error = awserr.New("SerializationError",
"failed to encode REST request",
fmt.Errorf("unknown payload type %s", payload.Type()))
}
}
}
}
}
}
func buildHeader(r *aws.Request, v reflect.Value, name string) {
str, err := convertType(v)
if err != nil {
r.Error = awserr.New("SerializationError", "failed to encode REST request", err)
} else if str != nil {
r.HTTPRequest.Header.Add(name, *str)
}
}
func buildHeaderMap(r *aws.Request, v reflect.Value, prefix string) {
for _, key := range v.MapKeys() {
str, err := convertType(v.MapIndex(key))
if err != nil {
r.Error = awserr.New("SerializationError", "failed to encode REST request", err)
} else if str != nil {
r.HTTPRequest.Header.Add(prefix+key.String(), *str)
}
}
}
func buildURI(r *aws.Request, v reflect.Value, name string) {
value, err := convertType(v)
if err != nil {
r.Error = awserr.New("SerializationError", "failed to encode REST request", err)
} else if value != nil {
uri := r.HTTPRequest.URL.Path
uri = strings.Replace(uri, "{"+name+"}", EscapePath(*value, true), -1)
uri = strings.Replace(uri, "{"+name+"+}", EscapePath(*value, false), -1)
r.HTTPRequest.URL.Path = uri
}
}
func buildQueryString(r *aws.Request, v reflect.Value, name string, query url.Values) {
str, err := convertType(v)
if err != nil {
r.Error = awserr.New("SerializationError", "failed to encode REST request", err)
} else if str != nil {
query.Set(name, *str)
}
}
func updatePath(url *url.URL, urlPath string) {
scheme, query := url.Scheme, url.RawQuery
// get formatted URL minus scheme so we can build this into Opaque
url.Scheme, url.Path, url.RawQuery = "", "", ""
s := url.String()
url.Scheme = scheme
url.RawQuery = query
// build opaque URI
url.Opaque = s + urlPath
}
// EscapePath escapes part of a URL path in Amazon style
func EscapePath(path string, encodeSep bool) string {
var buf bytes.Buffer
for i := 0; i < len(path); i++ {
c := path[i]
if noEscape[c] || (c == '/' && !encodeSep) {
buf.WriteByte(c)
} else {
buf.WriteByte('%')
buf.WriteString(strings.ToUpper(strconv.FormatUint(uint64(c), 16)))
}
}
return buf.String()
}
func convertType(v reflect.Value) (*string, error) {
v = reflect.Indirect(v)
if !v.IsValid() {
return nil, nil
}
var str string
switch value := v.Interface().(type) {
case string:
str = value
case []byte:
str = base64.StdEncoding.EncodeToString(value)
case bool:
str = strconv.FormatBool(value)
case int64:
str = strconv.FormatInt(value, 10)
case float64:
str = strconv.FormatFloat(value, 'f', -1, 64)
case time.Time:
str = value.UTC().Format(RFC822)
default:
err := fmt.Errorf("Unsupported value for param %v (%s)", v.Interface(), v.Type())
return nil, err
}
return &str, nil
}
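
EscapePath above percent-encodes every byte outside the unreserved set (A-Z, a-z, 0-9, '-', '.', '_', '~'), optionally leaving '/' intact so object keys keep their path separators. A standalone sketch of the same rule applied to a hypothetical key:

package main

import (
	"bytes"
	"fmt"
	"strconv"
	"strings"
)

// escapePath mirrors rest.EscapePath: unreserved bytes pass through,
// everything else becomes %XX, and '/' is kept when encodeSep is false.
func escapePath(path string, encodeSep bool) string {
	var buf bytes.Buffer
	for i := 0; i < len(path); i++ {
		c := path[i]
		unreserved := (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') ||
			(c >= '0' && c <= '9') || c == '-' || c == '.' || c == '_' || c == '~'
		if unreserved || (c == '/' && !encodeSep) {
			buf.WriteByte(c)
		} else {
			buf.WriteByte('%')
			buf.WriteString(strings.ToUpper(strconv.FormatUint(uint64(c), 16)))
		}
	}
	return buf.String()
}

func main() {
	fmt.Println(escapePath("photos/2017 05.jpg", false)) // photos/2017%2005.jpg
	fmt.Println(escapePath("photos/2017 05.jpg", true))  // photos%2F2017%2005.jpg
}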

File diff suppressed because it is too large

View file

@ -1,55 +0,0 @@
// Package restxml provides RESTful XML serialisation of AWS
// requests and responses.
package restxml
//go:generate go run ../../fixtures/protocol/generate.go ../../fixtures/protocol/input/rest-xml.json build_test.go
//go:generate go run ../../fixtures/protocol/generate.go ../../fixtures/protocol/output/rest-xml.json unmarshal_test.go
import (
"bytes"
"encoding/xml"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/protocol/query"
"github.com/aws/aws-sdk-go/internal/protocol/rest"
"github.com/aws/aws-sdk-go/internal/protocol/xml/xmlutil"
)
// Build builds a request payload for the REST XML protocol.
func Build(r *aws.Request) {
rest.Build(r)
if t := rest.PayloadType(r.Params); t == "structure" || t == "" {
var buf bytes.Buffer
err := xmlutil.BuildXML(r.Params, xml.NewEncoder(&buf))
if err != nil {
r.Error = awserr.New("SerializationError", "failed to enode rest XML request", err)
return
}
r.SetBufferBody(buf.Bytes())
}
}
// Unmarshal unmarshals a payload response for the REST XML protocol.
func Unmarshal(r *aws.Request) {
if t := rest.PayloadType(r.Data); t == "structure" || t == "" {
defer r.HTTPResponse.Body.Close()
decoder := xml.NewDecoder(r.HTTPResponse.Body)
err := xmlutil.UnmarshalXML(r.Data, decoder, "")
if err != nil {
r.Error = awserr.New("SerializationError", "failed to decode REST XML response", err)
return
}
}
}
// UnmarshalMeta unmarshals response headers for the REST XML protocol.
func UnmarshalMeta(r *aws.Request) {
rest.Unmarshal(r)
}
// UnmarshalError unmarshals a response error for the REST XML protocol.
func UnmarshalError(r *aws.Request) {
query.UnmarshalError(r)
}

File diff suppressed because it is too large

View file

@ -1,43 +0,0 @@
package v4_test
import (
"net/url"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/test/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
)
var _ = unit.Imported
func TestPresignHandler(t *testing.T) {
svc := s3.New(nil)
req, _ := svc.PutObjectRequest(&s3.PutObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
ContentDisposition: aws.String("a+b c$d"),
ACL: aws.String("public-read"),
})
req.Time = time.Unix(0, 0)
urlstr, err := req.Presign(5 * time.Minute)
assert.NoError(t, err)
expectedDate := "19700101T000000Z"
expectedHeaders := "host;x-amz-acl"
expectedSig := "7edcb4e3a1bf12f4989018d75acbe3a7f03df24bd6f3112602d59fc551f0e4e2"
expectedCred := "AKID/19700101/mock-region/s3/aws4_request"
u, _ := url.Parse(urlstr)
urlQ := u.Query()
assert.Equal(t, expectedSig, urlQ.Get("X-Amz-Signature"))
assert.Equal(t, expectedCred, urlQ.Get("X-Amz-Credential"))
assert.Equal(t, expectedHeaders, urlQ.Get("X-Amz-SignedHeaders"))
assert.Equal(t, expectedDate, urlQ.Get("X-Amz-Date"))
assert.Equal(t, "300", urlQ.Get("X-Amz-Expires"))
assert.NotContains(t, urlstr, "+") // + encoded as %20
}

View file

@ -1,364 +0,0 @@
// Package v4 implements the AWS Signature Version 4 request signer
package v4
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/internal/protocol/rest"
)
const (
authHeaderPrefix = "AWS4-HMAC-SHA256"
timeFormat = "20060102T150405Z"
shortTimeFormat = "20060102"
)
var ignoredHeaders = map[string]bool{
"Authorization": true,
"Content-Type": true,
"Content-Length": true,
"User-Agent": true,
}
type signer struct {
Request *http.Request
Time time.Time
ExpireTime time.Duration
ServiceName string
Region string
CredValues credentials.Value
Credentials *credentials.Credentials
Query url.Values
Body io.ReadSeeker
Debug aws.LogLevelType
Logger aws.Logger
isPresign bool
formattedTime string
formattedShortTime string
signedHeaders string
canonicalHeaders string
canonicalString string
credentialString string
stringToSign string
signature string
authorization string
}
// Sign requests with signature version 4.
//
// Will sign the request with the service config's Credentials object.
// Signing is skipped if the credentials are the credentials.AnonymousCredentials
// object.
func Sign(req *aws.Request) {
// Skip signing if the AnonymousCredentials object is used; such
// requests do not need to be signed.
if req.Service.Config.Credentials == credentials.AnonymousCredentials {
return
}
region := req.Service.SigningRegion
if region == "" {
region = aws.StringValue(req.Service.Config.Region)
}
name := req.Service.SigningName
if name == "" {
name = req.Service.ServiceName
}
s := signer{
Request: req.HTTPRequest,
Time: req.Time,
ExpireTime: req.ExpireTime,
Query: req.HTTPRequest.URL.Query(),
Body: req.Body,
ServiceName: name,
Region: region,
Credentials: req.Service.Config.Credentials,
Debug: req.Service.Config.LogLevel.Value(),
Logger: req.Service.Config.Logger,
}
req.Error = s.sign()
}
func (v4 *signer) sign() error {
if v4.ExpireTime != 0 {
v4.isPresign = true
}
if v4.isRequestSigned() {
if !v4.Credentials.IsExpired() {
// If the request is already signed, and the credentials have not
// expired yet, skip re-signing the request.
return nil
}
// The credentials have expired for this request. The current signing
// is invalid and needs to be redone because the request will fail.
if v4.isPresign {
v4.removePresign()
// Update the request's query string to ensure the values stay in
// sync in case retrieving the new credentials fails.
v4.Request.URL.RawQuery = v4.Query.Encode()
}
}
var err error
v4.CredValues, err = v4.Credentials.Get()
if err != nil {
return err
}
if v4.isPresign {
v4.Query.Set("X-Amz-Algorithm", authHeaderPrefix)
if v4.CredValues.SessionToken != "" {
v4.Query.Set("X-Amz-Security-Token", v4.CredValues.SessionToken)
} else {
v4.Query.Del("X-Amz-Security-Token")
}
} else if v4.CredValues.SessionToken != "" {
v4.Request.Header.Set("X-Amz-Security-Token", v4.CredValues.SessionToken)
}
v4.build()
if v4.Debug.Matches(aws.LogDebugWithSigning) {
v4.logSigningInfo()
}
return nil
}
const logSignInfoMsg = `DEBUG: Request Signature:
---[ CANONICAL STRING ]-----------------------------
%s
---[ STRING TO SIGN ]--------------------------------
%s%s
-----------------------------------------------------`
const logSignedURLMsg = `
---[ SIGNED URL ]------------------------------------
%s`
func (v4 *signer) logSigningInfo() {
signedURLMsg := ""
if v4.isPresign {
signedURLMsg = fmt.Sprintf(logSignedURLMsg, v4.Request.URL.String())
}
msg := fmt.Sprintf(logSignInfoMsg, v4.canonicalString, v4.stringToSign, signedURLMsg)
v4.Logger.Log(msg)
}
func (v4 *signer) build() {
v4.buildTime() // no depends
v4.buildCredentialString() // no depends
if v4.isPresign {
v4.buildQuery() // no depends
}
v4.buildCanonicalHeaders() // depends on cred string
v4.buildCanonicalString() // depends on canon headers / signed headers
v4.buildStringToSign() // depends on canon string
v4.buildSignature() // depends on string to sign
if v4.isPresign {
v4.Request.URL.RawQuery += "&X-Amz-Signature=" + v4.signature
} else {
parts := []string{
authHeaderPrefix + " Credential=" + v4.CredValues.AccessKeyID + "/" + v4.credentialString,
"SignedHeaders=" + v4.signedHeaders,
"Signature=" + v4.signature,
}
v4.Request.Header.Set("Authorization", strings.Join(parts, ", "))
}
}
func (v4 *signer) buildTime() {
v4.formattedTime = v4.Time.UTC().Format(timeFormat)
v4.formattedShortTime = v4.Time.UTC().Format(shortTimeFormat)
if v4.isPresign {
duration := int64(v4.ExpireTime / time.Second)
v4.Query.Set("X-Amz-Date", v4.formattedTime)
v4.Query.Set("X-Amz-Expires", strconv.FormatInt(duration, 10))
} else {
v4.Request.Header.Set("X-Amz-Date", v4.formattedTime)
}
}
func (v4 *signer) buildCredentialString() {
v4.credentialString = strings.Join([]string{
v4.formattedShortTime,
v4.Region,
v4.ServiceName,
"aws4_request",
}, "/")
if v4.isPresign {
v4.Query.Set("X-Amz-Credential", v4.CredValues.AccessKeyID+"/"+v4.credentialString)
}
}
func (v4 *signer) buildQuery() {
for k, h := range v4.Request.Header {
if strings.HasPrefix(http.CanonicalHeaderKey(k), "X-Amz-") {
continue // never hoist x-amz-* headers, they must be signed
}
if _, ok := ignoredHeaders[http.CanonicalHeaderKey(k)]; ok {
continue // never hoist ignored headers
}
v4.Request.Header.Del(k)
v4.Query.Del(k)
for _, v := range h {
v4.Query.Add(k, v)
}
}
}
func (v4 *signer) buildCanonicalHeaders() {
var headers []string
headers = append(headers, "host")
for k := range v4.Request.Header {
if _, ok := ignoredHeaders[http.CanonicalHeaderKey(k)]; ok {
continue // ignored header
}
headers = append(headers, strings.ToLower(k))
}
sort.Strings(headers)
v4.signedHeaders = strings.Join(headers, ";")
if v4.isPresign {
v4.Query.Set("X-Amz-SignedHeaders", v4.signedHeaders)
}
headerValues := make([]string, len(headers))
for i, k := range headers {
if k == "host" {
headerValues[i] = "host:" + v4.Request.URL.Host
} else {
headerValues[i] = k + ":" +
strings.Join(v4.Request.Header[http.CanonicalHeaderKey(k)], ",")
}
}
v4.canonicalHeaders = strings.Join(headerValues, "\n")
}
func (v4 *signer) buildCanonicalString() {
v4.Request.URL.RawQuery = strings.Replace(v4.Query.Encode(), "+", "%20", -1)
uri := v4.Request.URL.Opaque
if uri != "" {
uri = "/" + strings.Join(strings.Split(uri, "/")[3:], "/")
} else {
uri = v4.Request.URL.Path
}
if uri == "" {
uri = "/"
}
if v4.ServiceName != "s3" {
uri = rest.EscapePath(uri, false)
}
v4.canonicalString = strings.Join([]string{
v4.Request.Method,
uri,
v4.Request.URL.RawQuery,
v4.canonicalHeaders + "\n",
v4.signedHeaders,
v4.bodyDigest(),
}, "\n")
}
func (v4 *signer) buildStringToSign() {
v4.stringToSign = strings.Join([]string{
authHeaderPrefix,
v4.formattedTime,
v4.credentialString,
hex.EncodeToString(makeSha256([]byte(v4.canonicalString))),
}, "\n")
}
func (v4 *signer) buildSignature() {
secret := v4.CredValues.SecretAccessKey
date := makeHmac([]byte("AWS4"+secret), []byte(v4.formattedShortTime))
region := makeHmac(date, []byte(v4.Region))
service := makeHmac(region, []byte(v4.ServiceName))
credentials := makeHmac(service, []byte("aws4_request"))
signature := makeHmac(credentials, []byte(v4.stringToSign))
v4.signature = hex.EncodeToString(signature)
}
func (v4 *signer) bodyDigest() string {
hash := v4.Request.Header.Get("X-Amz-Content-Sha256")
if hash == "" {
if v4.isPresign && v4.ServiceName == "s3" {
hash = "UNSIGNED-PAYLOAD"
} else if v4.Body == nil {
hash = hex.EncodeToString(makeSha256([]byte{}))
} else {
hash = hex.EncodeToString(makeSha256Reader(v4.Body))
}
v4.Request.Header.Add("X-Amz-Content-Sha256", hash)
}
return hash
}
// isRequestSigned returns if the request is currently signed or presigned
func (v4 *signer) isRequestSigned() bool {
if v4.isPresign && v4.Query.Get("X-Amz-Signature") != "" {
return true
}
if v4.Request.Header.Get("Authorization") != "" {
return true
}
return false
}
// removePresign removes the presigning query string parameters so the request
// can be re-signed.
func (v4 *signer) removePresign() {
v4.Query.Del("X-Amz-Algorithm")
v4.Query.Del("X-Amz-Signature")
v4.Query.Del("X-Amz-Security-Token")
v4.Query.Del("X-Amz-Date")
v4.Query.Del("X-Amz-Expires")
v4.Query.Del("X-Amz-Credential")
v4.Query.Del("X-Amz-SignedHeaders")
}
func makeHmac(key []byte, data []byte) []byte {
hash := hmac.New(sha256.New, key)
hash.Write(data)
return hash.Sum(nil)
}
func makeSha256(data []byte) []byte {
hash := sha256.New()
hash.Write(data)
return hash.Sum(nil)
}
func makeSha256Reader(reader io.ReadSeeker) []byte {
hash := sha256.New()
start, _ := reader.Seek(0, 1)
defer reader.Seek(start, 0)
io.Copy(hash, reader)
return hash.Sum(nil)
}
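The buildSignature method above derives the signing key through a fixed HMAC-SHA256 chain. Here is a self-contained sketch of that derivation with only the standard library; the secret, date, region, service, and string-to-sign values are placeholders, not real credentials.

// Standalone sketch of the SigV4 signing-key derivation used by
// buildSignature above; all input values are placeholders.
package main

import (
    "crypto/hmac"
    "crypto/sha256"
    "encoding/hex"
    "fmt"
)

func hmacSHA256(key, data []byte) []byte {
    h := hmac.New(sha256.New, key)
    h.Write(data)
    return h.Sum(nil)
}

func main() {
    secret := "EXAMPLESECRET"
    stringToSign := "AWS4-HMAC-SHA256\n20150830T123600Z\n20150830/us-east-1/s3/aws4_request\n<hashed canonical request>"

    // Key derivation chain: kSecret -> kDate -> kRegion -> kService -> kSigning.
    kDate := hmacSHA256([]byte("AWS4"+secret), []byte("20150830"))
    kRegion := hmacSHA256(kDate, []byte("us-east-1"))
    kService := hmacSHA256(kRegion, []byte("s3"))
    kSigning := hmacSHA256(kService, []byte("aws4_request"))

    signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign)))
    fmt.Println("signature:", signature)
}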


@ -1,245 +0,0 @@
package v4
import (
"net/http"
"strings"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/stretchr/testify/assert"
)
func buildSigner(serviceName string, region string, signTime time.Time, expireTime time.Duration, body string) signer {
endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"
reader := strings.NewReader(body)
req, _ := http.NewRequest("POST", endpoint, reader)
req.URL.Opaque = "//example.org/bucket/key-._~,!@#$%^&*()"
req.Header.Add("X-Amz-Target", "prefix.Operation")
req.Header.Add("Content-Type", "application/x-amz-json-1.0")
req.Header.Add("Content-Length", string(len(body)))
req.Header.Add("X-Amz-Meta-Other-Header", "some-value=!@#$%^&* (+)")
return signer{
Request: req,
Time: signTime,
ExpireTime: expireTime,
Query: req.URL.Query(),
Body: reader,
ServiceName: serviceName,
Region: region,
Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
}
}
func removeWS(text string) string {
text = strings.Replace(text, " ", "", -1)
text = strings.Replace(text, "\n", "", -1)
text = strings.Replace(text, "\t", "", -1)
return text
}
func assertEqual(t *testing.T, expected, given string) {
if removeWS(expected) != removeWS(given) {
t.Errorf("\nExpected: %s\nGiven: %s", expected, given)
}
}
func TestPresignRequest(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Unix(0, 0), 300*time.Second, "{}")
signer.sign()
expectedDate := "19700101T000000Z"
expectedHeaders := "host;x-amz-meta-other-header;x-amz-target"
expectedSig := "5eeedebf6f995145ce56daa02902d10485246d3defb34f97b973c1f40ab82d36"
expectedCred := "AKID/19700101/us-east-1/dynamodb/aws4_request"
q := signer.Request.URL.Query()
assert.Equal(t, expectedSig, q.Get("X-Amz-Signature"))
assert.Equal(t, expectedCred, q.Get("X-Amz-Credential"))
assert.Equal(t, expectedHeaders, q.Get("X-Amz-SignedHeaders"))
assert.Equal(t, expectedDate, q.Get("X-Amz-Date"))
}
func TestSignRequest(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Unix(0, 0), 0, "{}")
signer.sign()
expectedDate := "19700101T000000Z"
expectedSig := "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=host;x-amz-date;x-amz-meta-other-header;x-amz-security-token;x-amz-target, Signature=69ada33fec48180dab153576e4dd80c4e04124f80dda3eccfed8a67c2b91ed5e"
q := signer.Request.Header
assert.Equal(t, expectedSig, q.Get("Authorization"))
assert.Equal(t, expectedDate, q.Get("X-Amz-Date"))
}
func TestSignEmptyBody(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, "")
signer.Body = nil
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", hash)
}
func TestSignBody(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, "hello")
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hash)
}
func TestSignSeekedBody(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, " hello")
signer.Body.Read(make([]byte, 3)) // consume first 3 bytes so body is now "hello"
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hash)
start, _ := signer.Body.Seek(0, 1)
assert.Equal(t, int64(3), start)
}
func TestPresignEmptyBodyS3(t *testing.T) {
signer := buildSigner("s3", "us-east-1", time.Now(), 5*time.Minute, "hello")
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "UNSIGNED-PAYLOAD", hash)
}
func TestSignPrecomputedBodyChecksum(t *testing.T) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, "hello")
signer.Request.Header.Set("X-Amz-Content-Sha256", "PRECOMPUTED")
signer.sign()
hash := signer.Request.Header.Get("X-Amz-Content-Sha256")
assert.Equal(t, "PRECOMPUTED", hash)
}
func TestAnonymousCredentials(t *testing.T) {
r := aws.NewRequest(
aws.NewService(&aws.Config{Credentials: credentials.AnonymousCredentials}),
&aws.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
Sign(r)
urlQ := r.HTTPRequest.URL.Query()
assert.Empty(t, urlQ.Get("X-Amz-Signature"))
assert.Empty(t, urlQ.Get("X-Amz-Credential"))
assert.Empty(t, urlQ.Get("X-Amz-SignedHeaders"))
assert.Empty(t, urlQ.Get("X-Amz-Date"))
hQ := r.HTTPRequest.Header
assert.Empty(t, hQ.Get("Authorization"))
assert.Empty(t, hQ.Get("X-Amz-Date"))
}
func TestIgnoreResignRequestWithValidCreds(t *testing.T) {
r := aws.NewRequest(
aws.NewService(&aws.Config{
Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
Region: aws.String("us-west-2"),
}),
&aws.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
Sign(r)
sig := r.HTTPRequest.Header.Get("Authorization")
Sign(r)
assert.Equal(t, sig, r.HTTPRequest.Header.Get("Authorization"))
}
func TestIgnorePreResignRequestWithValidCreds(t *testing.T) {
r := aws.NewRequest(
aws.NewService(&aws.Config{
Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
Region: aws.String("us-west-2"),
}),
&aws.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
r.ExpireTime = time.Minute * 10
Sign(r)
sig := r.HTTPRequest.Header.Get("X-Amz-Signature")
Sign(r)
assert.Equal(t, sig, r.HTTPRequest.Header.Get("X-Amz-Signature"))
}
func TestResignRequestExpiredCreds(t *testing.T) {
creds := credentials.NewStaticCredentials("AKID", "SECRET", "SESSION")
r := aws.NewRequest(
aws.NewService(&aws.Config{Credentials: creds}),
&aws.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
Sign(r)
querySig := r.HTTPRequest.Header.Get("Authorization")
creds.Expire()
Sign(r)
assert.NotEqual(t, querySig, r.HTTPRequest.Header.Get("Authorization"))
}
func TestPreResignRequestExpiredCreds(t *testing.T) {
provider := &credentials.StaticProvider{credentials.Value{"AKID", "SECRET", "SESSION"}}
creds := credentials.NewCredentials(provider)
r := aws.NewRequest(
aws.NewService(&aws.Config{Credentials: creds}),
&aws.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
r.ExpireTime = time.Minute * 10
Sign(r)
querySig := r.HTTPRequest.URL.Query().Get("X-Amz-Signature")
creds.Expire()
r.Time = time.Now().Add(time.Hour * 48)
Sign(r)
assert.NotEqual(t, querySig, r.HTTPRequest.URL.Query().Get("X-Amz-Signature"))
}
func BenchmarkPresignRequest(b *testing.B) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 300*time.Second, "{}")
for i := 0; i < b.N; i++ {
signer.sign()
}
}
func BenchmarkSignRequest(b *testing.B) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 0, "{}")
for i := 0; i < b.N; i++ {
signer.sign()
}
}

File diff suppressed because it is too large


@ -1,42 +0,0 @@
package s3
import (
"io/ioutil"
"regexp"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
)
var reBucketLocation = regexp.MustCompile(`>([^<>]+)<\/Location`)
func buildGetBucketLocation(r *aws.Request) {
if r.DataFilled() {
out := r.Data.(*GetBucketLocationOutput)
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
r.Error = awserr.New("SerializationError", "failed reading response body", err)
return
}
match := reBucketLocation.FindSubmatch(b)
if len(match) > 1 {
loc := string(match[1])
out.LocationConstraint = &loc
}
}
}
func populateLocationConstraint(r *aws.Request) {
if r.ParamsFilled() && aws.StringValue(r.Config.Region) != "us-east-1" {
in := r.Params.(*CreateBucketInput)
if in.CreateBucketConfiguration == nil {
r.Params = awsutil.CopyOf(r.Params)
in = r.Params.(*CreateBucketInput)
in.CreateBucketConfiguration = &CreateBucketConfiguration{
LocationConstraint: r.Config.Region,
}
}
}
}
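buildGetBucketLocation above extracts the constraint from the raw XML with a regular expression instead of the generated unmarshaler. A standalone sketch of that extraction follows; the sample body mirrors the test fixtures below, and an empty constraint conventionally means the classic us-east-1 region.

// Standalone sketch of the regex extraction performed by
// buildGetBucketLocation above; the response body is a sample.
package main

import (
    "fmt"
    "regexp"
)

var reBucketLocation = regexp.MustCompile(`>([^<>]+)<\/Location`)

func main() {
    body := `<?xml version="1.0" encoding="UTF-8"?>` +
        `<LocationConstraint xmlns="http://s3.amazonaws.com/doc/2006-03-01/">EU</LocationConstraint>`

    loc := "" // empty means the bucket lives in the classic us-east-1 region
    if m := reBucketLocation.FindStringSubmatch(body); len(m) > 1 {
        loc = m[1]
    }
    fmt.Println("LocationConstraint:", loc) // prints "EU"
}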


@ -1,75 +0,0 @@
package s3_test
import (
"bytes"
"io/ioutil"
"net/http"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/internal/test/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
)
var _ = unit.Imported
var s3LocationTests = []struct {
body string
loc string
}{
{`<?xml version="1.0" encoding="UTF-8"?><LocationConstraint xmlns="http://s3.amazonaws.com/doc/2006-03-01/"/>`, ``},
{`<?xml version="1.0" encoding="UTF-8"?><LocationConstraint xmlns="http://s3.amazonaws.com/doc/2006-03-01/">EU</LocationConstraint>`, `EU`},
}
func TestGetBucketLocation(t *testing.T) {
for _, test := range s3LocationTests {
s := s3.New(nil)
s.Handlers.Send.Clear()
s.Handlers.Send.PushBack(func(r *aws.Request) {
reader := ioutil.NopCloser(bytes.NewReader([]byte(test.body)))
r.HTTPResponse = &http.Response{StatusCode: 200, Body: reader}
})
resp, err := s.GetBucketLocation(&s3.GetBucketLocationInput{Bucket: aws.String("bucket")})
assert.NoError(t, err)
if test.loc == "" {
assert.Nil(t, resp.LocationConstraint)
} else {
assert.Equal(t, test.loc, *resp.LocationConstraint)
}
}
}
func TestPopulateLocationConstraint(t *testing.T) {
s := s3.New(nil)
in := &s3.CreateBucketInput{
Bucket: aws.String("bucket"),
}
req, _ := s.CreateBucketRequest(in)
err := req.Build()
assert.NoError(t, err)
assert.Equal(t, "mock-region", awsutil.ValuesAtPath(req.Params, "CreateBucketConfiguration.LocationConstraint")[0])
assert.Nil(t, in.CreateBucketConfiguration) // don't modify original params
}
func TestNoPopulateLocationConstraintIfProvided(t *testing.T) {
s := s3.New(nil)
req, _ := s.CreateBucketRequest(&s3.CreateBucketInput{
Bucket: aws.String("bucket"),
CreateBucketConfiguration: &s3.CreateBucketConfiguration{},
})
err := req.Build()
assert.NoError(t, err)
assert.Equal(t, 0, len(awsutil.ValuesAtPath(req.Params, "CreateBucketConfiguration.LocationConstraint")))
}
func TestNoPopulateLocationConstraintIfClassic(t *testing.T) {
s := s3.New(&aws.Config{Region: aws.String("us-east-1")})
req, _ := s.CreateBucketRequest(&s3.CreateBucketInput{
Bucket: aws.String("bucket"),
})
err := req.Build()
assert.NoError(t, err)
assert.Equal(t, 0, len(awsutil.ValuesAtPath(req.Params, "CreateBucketConfiguration.LocationConstraint")))
}


@ -1,32 +0,0 @@
package s3
import "github.com/aws/aws-sdk-go/aws"
func init() {
initService = func(s *aws.Service) {
// Support building custom host-style bucket endpoints
s.Handlers.Build.PushFront(updateHostWithBucket)
// Require SSL when using SSE keys
s.Handlers.Validate.PushBack(validateSSERequiresSSL)
s.Handlers.Build.PushBack(computeSSEKeys)
// S3 uses custom error unmarshaling logic
s.Handlers.UnmarshalError.Clear()
s.Handlers.UnmarshalError.PushBack(unmarshalError)
}
initRequest = func(r *aws.Request) {
switch r.Operation.Name {
case opPutBucketCORS, opPutBucketLifecycle, opPutBucketPolicy, opPutBucketTagging, opDeleteObjects:
// These S3 operations require Content-MD5 to be set
r.Handlers.Build.PushBack(contentMD5)
case opGetBucketLocation:
// GetBucketLocation has custom parsing logic
r.Handlers.Unmarshal.PushFront(buildGetBucketLocation)
case opCreateBucket:
// Auto-populate LocationConstraint with current region
r.Handlers.Validate.PushFront(populateLocationConstraint)
}
}
}
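The init above customizes the shared handler lists that every S3 request runs through. Below is a hedged sketch of the same pattern against the newer handler API, assuming the v1.x request.Request type and the Handlers lists exposed on the client; the custom header is an invented example.

// Hedged sketch (assumes the aws-sdk-go v1.x request package rather than the
// removed aws.Request above): push a custom Build handler onto an S3 client.
package main

import (
    "fmt"

    "github.com/aws/aws-sdk-go/aws"
    "github.com/aws/aws-sdk-go/aws/request"
    "github.com/aws/aws-sdk-go/aws/session"
    "github.com/aws/aws-sdk-go/service/s3"
)

func main() {
    sess := session.Must(session.NewSession(&aws.Config{Region: aws.String("eu-west-1")}))
    svc := s3.New(sess)

    // Runs for every request built by this client, after the default handlers.
    svc.Handlers.Build.PushBack(func(r *request.Request) {
        r.HTTPRequest.Header.Set("X-Example-Trace", "s3sync") // assumed example header
    })

    req, _ := svc.ListObjectsRequest(&s3.ListObjectsInput{Bucket: aws.String("bucket")})
    if err := req.Build(); err != nil {
        fmt.Println("build failed:", err)
        return
    }
    fmt.Println(req.HTTPRequest.Header.Get("X-Example-Trace"))
}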


@ -1,90 +0,0 @@
package s3_test
import (
"crypto/md5"
"encoding/base64"
"io/ioutil"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/test/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
)
var _ = unit.Imported
func assertMD5(t *testing.T, req *aws.Request) {
err := req.Build()
assert.NoError(t, err)
b, _ := ioutil.ReadAll(req.HTTPRequest.Body)
out := md5.Sum(b)
assert.NotEmpty(t, b)
assert.Equal(t, base64.StdEncoding.EncodeToString(out[:]), req.HTTPRequest.Header.Get("Content-MD5"))
}
func TestMD5InPutBucketCORS(t *testing.T) {
svc := s3.New(nil)
req, _ := svc.PutBucketCORSRequest(&s3.PutBucketCORSInput{
Bucket: aws.String("bucketname"),
CORSConfiguration: &s3.CORSConfiguration{
CORSRules: []*s3.CORSRule{
{AllowedMethods: []*string{aws.String("GET")}},
},
},
})
assertMD5(t, req)
}
func TestMD5InPutBucketLifecycle(t *testing.T) {
svc := s3.New(nil)
req, _ := svc.PutBucketLifecycleRequest(&s3.PutBucketLifecycleInput{
Bucket: aws.String("bucketname"),
LifecycleConfiguration: &s3.LifecycleConfiguration{
Rules: []*s3.LifecycleRule{
{
ID: aws.String("ID"),
Prefix: aws.String("Prefix"),
Status: aws.String("Enabled"),
},
},
},
})
assertMD5(t, req)
}
func TestMD5InPutBucketPolicy(t *testing.T) {
svc := s3.New(nil)
req, _ := svc.PutBucketPolicyRequest(&s3.PutBucketPolicyInput{
Bucket: aws.String("bucketname"),
Policy: aws.String("{}"),
})
assertMD5(t, req)
}
func TestMD5InPutBucketTagging(t *testing.T) {
svc := s3.New(nil)
req, _ := svc.PutBucketTaggingRequest(&s3.PutBucketTaggingInput{
Bucket: aws.String("bucketname"),
Tagging: &s3.Tagging{
TagSet: []*s3.Tag{
{Key: aws.String("KEY"), Value: aws.String("VALUE")},
},
},
})
assertMD5(t, req)
}
func TestMD5InDeleteObjects(t *testing.T) {
svc := s3.New(nil)
req, _ := svc.DeleteObjectsRequest(&s3.DeleteObjectsInput{
Bucket: aws.String("bucketname"),
Delete: &s3.Delete{
Objects: []*s3.ObjectIdentifier{
{Key: aws.String("key")},
},
},
})
assertMD5(t, req)
}
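assertMD5 above recomputes the digest the contentMD5 handler is expected to set. Here is a standalone sketch of producing a Content-MD5 header value from a request body; the body is an arbitrary sample.

// Standalone sketch of the Content-MD5 computation the tests above verify:
// MD5 of the raw body, base64 encoded.
package main

import (
    "crypto/md5"
    "encoding/base64"
    "fmt"
)

func contentMD5(body []byte) string {
    sum := md5.Sum(body)
    return base64.StdEncoding.EncodeToString(sum[:])
}

func main() {
    body := []byte(`<Delete><Object><Key>key</Key></Object></Delete>`)
    fmt.Println("Content-MD5:", contentMD5(body))
}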

File diff suppressed because it is too large


@ -1,53 +0,0 @@
package s3
import (
"regexp"
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
)
var reDomain = regexp.MustCompile(`^[a-z0-9][a-z0-9\.\-]{1,61}[a-z0-9]$`)
var reIPAddress = regexp.MustCompile(`^(\d+\.){3}\d+$`)
// dnsCompatibleBucketName returns true if the bucket name is DNS compatible.
// Buckets created outside of the classic region MUST be DNS compatible.
func dnsCompatibleBucketName(bucket string) bool {
return reDomain.MatchString(bucket) &&
!reIPAddress.MatchString(bucket) &&
!strings.Contains(bucket, "..")
}
// hostStyleBucketName returns true if the request should put the bucket in
// the host. This is false if S3ForcePathStyle is explicitly set or if the
// bucket is not DNS compatible.
func hostStyleBucketName(r *aws.Request, bucket string) bool {
if aws.BoolValue(r.Config.S3ForcePathStyle) {
return false
}
// Bucket might be DNS compatible but dots in the hostname will fail
// certificate validation, so do not use host-style.
if r.HTTPRequest.URL.Scheme == "https" && strings.Contains(bucket, ".") {
return false
}
// Use host-style if the bucket is DNS compatible
return dnsCompatibleBucketName(bucket)
}
func updateHostWithBucket(r *aws.Request) {
b := awsutil.ValuesAtPath(r.Params, "Bucket")
if len(b) == 0 {
return
}
if bucket := b[0].(string); bucket != "" && hostStyleBucketName(r, bucket) {
r.HTTPRequest.URL.Host = bucket + "." + r.HTTPRequest.URL.Host
r.HTTPRequest.URL.Path = strings.Replace(r.HTTPRequest.URL.Path, "/{Bucket}", "", -1)
if r.HTTPRequest.URL.Path == "" {
r.HTTPRequest.URL.Path = "/"
}
}
}
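dnsCompatibleBucketName above decides whether a bucket name can be promoted into the hostname. A standalone sketch of the same check follows; the sample bucket names mirror the tests further down.

// Standalone sketch of the DNS-compatibility check used by
// updateHostWithBucket above.
package main

import (
    "fmt"
    "regexp"
    "strings"
)

var (
    reDomain    = regexp.MustCompile(`^[a-z0-9][a-z0-9\.\-]{1,61}[a-z0-9]$`)
    reIPAddress = regexp.MustCompile(`^(\d+\.){3}\d+$`)
)

func dnsCompatible(bucket string) bool {
    return reDomain.MatchString(bucket) &&
        !reIPAddress.MatchString(bucket) &&
        !strings.Contains(bucket, "..")
}

func main() {
    for _, b := range []string{"abc", "a.b.c", "a..bc", "a$b$c", "192.168.0.1"} {
        fmt.Printf("%-12s host-style eligible: %v\n", b, dnsCompatible(b))
    }
}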


@ -1,61 +0,0 @@
package s3_test
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/test/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
)
type s3BucketTest struct {
bucket string
url string
}
var (
_ = unit.Imported
sslTests = []s3BucketTest{
{"abc", "https://abc.s3.mock-region.amazonaws.com/"},
{"a$b$c", "https://s3.mock-region.amazonaws.com/a%24b%24c"},
{"a.b.c", "https://s3.mock-region.amazonaws.com/a.b.c"},
{"a..bc", "https://s3.mock-region.amazonaws.com/a..bc"},
}
nosslTests = []s3BucketTest{
{"a.b.c", "http://a.b.c.s3.mock-region.amazonaws.com/"},
{"a..bc", "http://s3.mock-region.amazonaws.com/a..bc"},
}
forcepathTests = []s3BucketTest{
{"abc", "https://s3.mock-region.amazonaws.com/abc"},
{"a$b$c", "https://s3.mock-region.amazonaws.com/a%24b%24c"},
{"a.b.c", "https://s3.mock-region.amazonaws.com/a.b.c"},
{"a..bc", "https://s3.mock-region.amazonaws.com/a..bc"},
}
)
func runTests(t *testing.T, svc *s3.S3, tests []s3BucketTest) {
for _, test := range tests {
req, _ := svc.ListObjectsRequest(&s3.ListObjectsInput{Bucket: &test.bucket})
req.Build()
assert.Equal(t, test.url, req.HTTPRequest.URL.String())
}
}
func TestHostStyleBucketBuild(t *testing.T) {
s := s3.New(nil)
runTests(t, s, sslTests)
}
func TestHostStyleBucketBuildNoSSL(t *testing.T) {
s := s3.New(&aws.Config{DisableSSL: aws.Bool(true)})
runTests(t, s, nosslTests)
}
func TestPathStyleBucketBuild(t *testing.T) {
s := s3.New(&aws.Config{S3ForcePathStyle: aws.Bool(true)})
runTests(t, s, forcepathTests)
}


@ -1,236 +0,0 @@
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
// Package s3iface provides an interface for the Amazon Simple Storage Service.
package s3iface
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3"
)
// S3API is the interface type for s3.S3.
type S3API interface {
AbortMultipartUploadRequest(*s3.AbortMultipartUploadInput) (*aws.Request, *s3.AbortMultipartUploadOutput)
AbortMultipartUpload(*s3.AbortMultipartUploadInput) (*s3.AbortMultipartUploadOutput, error)
CompleteMultipartUploadRequest(*s3.CompleteMultipartUploadInput) (*aws.Request, *s3.CompleteMultipartUploadOutput)
CompleteMultipartUpload(*s3.CompleteMultipartUploadInput) (*s3.CompleteMultipartUploadOutput, error)
CopyObjectRequest(*s3.CopyObjectInput) (*aws.Request, *s3.CopyObjectOutput)
CopyObject(*s3.CopyObjectInput) (*s3.CopyObjectOutput, error)
CreateBucketRequest(*s3.CreateBucketInput) (*aws.Request, *s3.CreateBucketOutput)
CreateBucket(*s3.CreateBucketInput) (*s3.CreateBucketOutput, error)
CreateMultipartUploadRequest(*s3.CreateMultipartUploadInput) (*aws.Request, *s3.CreateMultipartUploadOutput)
CreateMultipartUpload(*s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error)
DeleteBucketRequest(*s3.DeleteBucketInput) (*aws.Request, *s3.DeleteBucketOutput)
DeleteBucket(*s3.DeleteBucketInput) (*s3.DeleteBucketOutput, error)
DeleteBucketCORSRequest(*s3.DeleteBucketCORSInput) (*aws.Request, *s3.DeleteBucketCORSOutput)
DeleteBucketCORS(*s3.DeleteBucketCORSInput) (*s3.DeleteBucketCORSOutput, error)
DeleteBucketLifecycleRequest(*s3.DeleteBucketLifecycleInput) (*aws.Request, *s3.DeleteBucketLifecycleOutput)
DeleteBucketLifecycle(*s3.DeleteBucketLifecycleInput) (*s3.DeleteBucketLifecycleOutput, error)
DeleteBucketPolicyRequest(*s3.DeleteBucketPolicyInput) (*aws.Request, *s3.DeleteBucketPolicyOutput)
DeleteBucketPolicy(*s3.DeleteBucketPolicyInput) (*s3.DeleteBucketPolicyOutput, error)
DeleteBucketReplicationRequest(*s3.DeleteBucketReplicationInput) (*aws.Request, *s3.DeleteBucketReplicationOutput)
DeleteBucketReplication(*s3.DeleteBucketReplicationInput) (*s3.DeleteBucketReplicationOutput, error)
DeleteBucketTaggingRequest(*s3.DeleteBucketTaggingInput) (*aws.Request, *s3.DeleteBucketTaggingOutput)
DeleteBucketTagging(*s3.DeleteBucketTaggingInput) (*s3.DeleteBucketTaggingOutput, error)
DeleteBucketWebsiteRequest(*s3.DeleteBucketWebsiteInput) (*aws.Request, *s3.DeleteBucketWebsiteOutput)
DeleteBucketWebsite(*s3.DeleteBucketWebsiteInput) (*s3.DeleteBucketWebsiteOutput, error)
DeleteObjectRequest(*s3.DeleteObjectInput) (*aws.Request, *s3.DeleteObjectOutput)
DeleteObject(*s3.DeleteObjectInput) (*s3.DeleteObjectOutput, error)
DeleteObjectsRequest(*s3.DeleteObjectsInput) (*aws.Request, *s3.DeleteObjectsOutput)
DeleteObjects(*s3.DeleteObjectsInput) (*s3.DeleteObjectsOutput, error)
GetBucketACLRequest(*s3.GetBucketACLInput) (*aws.Request, *s3.GetBucketACLOutput)
GetBucketACL(*s3.GetBucketACLInput) (*s3.GetBucketACLOutput, error)
GetBucketCORSRequest(*s3.GetBucketCORSInput) (*aws.Request, *s3.GetBucketCORSOutput)
GetBucketCORS(*s3.GetBucketCORSInput) (*s3.GetBucketCORSOutput, error)
GetBucketLifecycleRequest(*s3.GetBucketLifecycleInput) (*aws.Request, *s3.GetBucketLifecycleOutput)
GetBucketLifecycle(*s3.GetBucketLifecycleInput) (*s3.GetBucketLifecycleOutput, error)
GetBucketLocationRequest(*s3.GetBucketLocationInput) (*aws.Request, *s3.GetBucketLocationOutput)
GetBucketLocation(*s3.GetBucketLocationInput) (*s3.GetBucketLocationOutput, error)
GetBucketLoggingRequest(*s3.GetBucketLoggingInput) (*aws.Request, *s3.GetBucketLoggingOutput)
GetBucketLogging(*s3.GetBucketLoggingInput) (*s3.GetBucketLoggingOutput, error)
GetBucketNotificationRequest(*s3.GetBucketNotificationConfigurationRequest) (*aws.Request, *s3.NotificationConfigurationDeprecated)
GetBucketNotification(*s3.GetBucketNotificationConfigurationRequest) (*s3.NotificationConfigurationDeprecated, error)
GetBucketNotificationConfigurationRequest(*s3.GetBucketNotificationConfigurationRequest) (*aws.Request, *s3.NotificationConfiguration)
GetBucketNotificationConfiguration(*s3.GetBucketNotificationConfigurationRequest) (*s3.NotificationConfiguration, error)
GetBucketPolicyRequest(*s3.GetBucketPolicyInput) (*aws.Request, *s3.GetBucketPolicyOutput)
GetBucketPolicy(*s3.GetBucketPolicyInput) (*s3.GetBucketPolicyOutput, error)
GetBucketReplicationRequest(*s3.GetBucketReplicationInput) (*aws.Request, *s3.GetBucketReplicationOutput)
GetBucketReplication(*s3.GetBucketReplicationInput) (*s3.GetBucketReplicationOutput, error)
GetBucketRequestPaymentRequest(*s3.GetBucketRequestPaymentInput) (*aws.Request, *s3.GetBucketRequestPaymentOutput)
GetBucketRequestPayment(*s3.GetBucketRequestPaymentInput) (*s3.GetBucketRequestPaymentOutput, error)
GetBucketTaggingRequest(*s3.GetBucketTaggingInput) (*aws.Request, *s3.GetBucketTaggingOutput)
GetBucketTagging(*s3.GetBucketTaggingInput) (*s3.GetBucketTaggingOutput, error)
GetBucketVersioningRequest(*s3.GetBucketVersioningInput) (*aws.Request, *s3.GetBucketVersioningOutput)
GetBucketVersioning(*s3.GetBucketVersioningInput) (*s3.GetBucketVersioningOutput, error)
GetBucketWebsiteRequest(*s3.GetBucketWebsiteInput) (*aws.Request, *s3.GetBucketWebsiteOutput)
GetBucketWebsite(*s3.GetBucketWebsiteInput) (*s3.GetBucketWebsiteOutput, error)
GetObjectRequest(*s3.GetObjectInput) (*aws.Request, *s3.GetObjectOutput)
GetObject(*s3.GetObjectInput) (*s3.GetObjectOutput, error)
GetObjectACLRequest(*s3.GetObjectACLInput) (*aws.Request, *s3.GetObjectACLOutput)
GetObjectACL(*s3.GetObjectACLInput) (*s3.GetObjectACLOutput, error)
GetObjectTorrentRequest(*s3.GetObjectTorrentInput) (*aws.Request, *s3.GetObjectTorrentOutput)
GetObjectTorrent(*s3.GetObjectTorrentInput) (*s3.GetObjectTorrentOutput, error)
HeadBucketRequest(*s3.HeadBucketInput) (*aws.Request, *s3.HeadBucketOutput)
HeadBucket(*s3.HeadBucketInput) (*s3.HeadBucketOutput, error)
HeadObjectRequest(*s3.HeadObjectInput) (*aws.Request, *s3.HeadObjectOutput)
HeadObject(*s3.HeadObjectInput) (*s3.HeadObjectOutput, error)
ListBucketsRequest(*s3.ListBucketsInput) (*aws.Request, *s3.ListBucketsOutput)
ListBuckets(*s3.ListBucketsInput) (*s3.ListBucketsOutput, error)
ListMultipartUploadsRequest(*s3.ListMultipartUploadsInput) (*aws.Request, *s3.ListMultipartUploadsOutput)
ListMultipartUploads(*s3.ListMultipartUploadsInput) (*s3.ListMultipartUploadsOutput, error)
ListMultipartUploadsPages(*s3.ListMultipartUploadsInput, func(*s3.ListMultipartUploadsOutput, bool) bool) error
ListObjectVersionsRequest(*s3.ListObjectVersionsInput) (*aws.Request, *s3.ListObjectVersionsOutput)
ListObjectVersions(*s3.ListObjectVersionsInput) (*s3.ListObjectVersionsOutput, error)
ListObjectVersionsPages(*s3.ListObjectVersionsInput, func(*s3.ListObjectVersionsOutput, bool) bool) error
ListObjectsRequest(*s3.ListObjectsInput) (*aws.Request, *s3.ListObjectsOutput)
ListObjects(*s3.ListObjectsInput) (*s3.ListObjectsOutput, error)
ListObjectsPages(*s3.ListObjectsInput, func(*s3.ListObjectsOutput, bool) bool) error
ListPartsRequest(*s3.ListPartsInput) (*aws.Request, *s3.ListPartsOutput)
ListParts(*s3.ListPartsInput) (*s3.ListPartsOutput, error)
ListPartsPages(*s3.ListPartsInput, func(*s3.ListPartsOutput, bool) bool) error
PutBucketACLRequest(*s3.PutBucketACLInput) (*aws.Request, *s3.PutBucketACLOutput)
PutBucketACL(*s3.PutBucketACLInput) (*s3.PutBucketACLOutput, error)
PutBucketCORSRequest(*s3.PutBucketCORSInput) (*aws.Request, *s3.PutBucketCORSOutput)
PutBucketCORS(*s3.PutBucketCORSInput) (*s3.PutBucketCORSOutput, error)
PutBucketLifecycleRequest(*s3.PutBucketLifecycleInput) (*aws.Request, *s3.PutBucketLifecycleOutput)
PutBucketLifecycle(*s3.PutBucketLifecycleInput) (*s3.PutBucketLifecycleOutput, error)
PutBucketLoggingRequest(*s3.PutBucketLoggingInput) (*aws.Request, *s3.PutBucketLoggingOutput)
PutBucketLogging(*s3.PutBucketLoggingInput) (*s3.PutBucketLoggingOutput, error)
PutBucketNotificationRequest(*s3.PutBucketNotificationInput) (*aws.Request, *s3.PutBucketNotificationOutput)
PutBucketNotification(*s3.PutBucketNotificationInput) (*s3.PutBucketNotificationOutput, error)
PutBucketNotificationConfigurationRequest(*s3.PutBucketNotificationConfigurationInput) (*aws.Request, *s3.PutBucketNotificationConfigurationOutput)
PutBucketNotificationConfiguration(*s3.PutBucketNotificationConfigurationInput) (*s3.PutBucketNotificationConfigurationOutput, error)
PutBucketPolicyRequest(*s3.PutBucketPolicyInput) (*aws.Request, *s3.PutBucketPolicyOutput)
PutBucketPolicy(*s3.PutBucketPolicyInput) (*s3.PutBucketPolicyOutput, error)
PutBucketReplicationRequest(*s3.PutBucketReplicationInput) (*aws.Request, *s3.PutBucketReplicationOutput)
PutBucketReplication(*s3.PutBucketReplicationInput) (*s3.PutBucketReplicationOutput, error)
PutBucketRequestPaymentRequest(*s3.PutBucketRequestPaymentInput) (*aws.Request, *s3.PutBucketRequestPaymentOutput)
PutBucketRequestPayment(*s3.PutBucketRequestPaymentInput) (*s3.PutBucketRequestPaymentOutput, error)
PutBucketTaggingRequest(*s3.PutBucketTaggingInput) (*aws.Request, *s3.PutBucketTaggingOutput)
PutBucketTagging(*s3.PutBucketTaggingInput) (*s3.PutBucketTaggingOutput, error)
PutBucketVersioningRequest(*s3.PutBucketVersioningInput) (*aws.Request, *s3.PutBucketVersioningOutput)
PutBucketVersioning(*s3.PutBucketVersioningInput) (*s3.PutBucketVersioningOutput, error)
PutBucketWebsiteRequest(*s3.PutBucketWebsiteInput) (*aws.Request, *s3.PutBucketWebsiteOutput)
PutBucketWebsite(*s3.PutBucketWebsiteInput) (*s3.PutBucketWebsiteOutput, error)
PutObjectRequest(*s3.PutObjectInput) (*aws.Request, *s3.PutObjectOutput)
PutObject(*s3.PutObjectInput) (*s3.PutObjectOutput, error)
PutObjectACLRequest(*s3.PutObjectACLInput) (*aws.Request, *s3.PutObjectACLOutput)
PutObjectACL(*s3.PutObjectACLInput) (*s3.PutObjectACLOutput, error)
RestoreObjectRequest(*s3.RestoreObjectInput) (*aws.Request, *s3.RestoreObjectOutput)
RestoreObject(*s3.RestoreObjectInput) (*s3.RestoreObjectOutput, error)
UploadPartRequest(*s3.UploadPartInput) (*aws.Request, *s3.UploadPartOutput)
UploadPart(*s3.UploadPartInput) (*s3.UploadPartOutput, error)
UploadPartCopyRequest(*s3.UploadPartCopyInput) (*aws.Request, *s3.UploadPartCopyOutput)
UploadPartCopy(*s3.UploadPartCopyInput) (*s3.UploadPartCopyOutput, error)
}
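The interface above exists so callers can swap in a fake client for testing. Below is a hedged sketch of that pattern; the method signatures assume the S3API shape of the vendored SDK version, and only the methods a test actually calls need to be overridden when the full interface is embedded.

// Hedged sketch of mocking through s3iface; unimplemented methods on the
// embedded interface would panic if called, which is acceptable for a
// focused test double.
package main

import (
    "fmt"

    "github.com/aws/aws-sdk-go/aws"
    "github.com/aws/aws-sdk-go/service/s3"
    "github.com/aws/aws-sdk-go/service/s3/s3iface"
)

type mockS3 struct {
    s3iface.S3API
}

func (m *mockS3) ListObjects(in *s3.ListObjectsInput) (*s3.ListObjectsOutput, error) {
    return &s3.ListObjectsOutput{
        Contents: []*s3.Object{{Key: aws.String("fake/key")}},
    }, nil
}

// listKeys depends only on the interface, so tests can pass in the mock.
func listKeys(c s3iface.S3API, bucket string) ([]string, error) {
    out, err := c.ListObjects(&s3.ListObjectsInput{Bucket: aws.String(bucket)})
    if err != nil {
        return nil, err
    }
    keys := make([]string, 0, len(out.Contents))
    for _, o := range out.Contents {
        keys = append(keys, aws.StringValue(o.Key))
    }
    return keys, nil
}

func main() {
    keys, _ := listKeys(&mockS3{}, "bucket")
    fmt.Println(keys)
}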


@ -1,15 +0,0 @@
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
package s3iface_test
import (
"testing"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
"github.com/stretchr/testify/assert"
)
func TestInterface(t *testing.T) {
assert.Implements(t, (*s3iface.S3API)(nil), s3.New(nil))
}


@ -1,257 +0,0 @@
package s3manager
import (
"fmt"
"io"
"strconv"
"strings"
"sync"
"time"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/service/s3"
)
// The default range of bytes to get at a time when using Download().
var DefaultDownloadPartSize int64 = 1024 * 1024 * 5
// The default number of goroutines to spin up when using Download().
var DefaultDownloadConcurrency = 5
// The default set of options used when opts is nil in Download().
var DefaultDownloadOptions = &DownloadOptions{
PartSize: DefaultDownloadPartSize,
Concurrency: DefaultDownloadConcurrency,
}
// DownloadOptions keeps track of extra options to pass to a Download() call.
type DownloadOptions struct {
// The size (in bytes) of the ranged GET requests used to download the object
// in chunks. If this value is set to zero, the DefaultDownloadPartSize value
// will be used.
PartSize int64
// The number of goroutines to spin up in parallel when downloading parts.
// If this is set to zero, the DefaultDownloadConcurrency value will be used.
Concurrency int
// An S3 client to use when performing downloads. Leave this as nil to use
// a default client.
S3 *s3.S3
}
// NewDownloader creates a new Downloader structure that downloads an object
// from S3 in concurrent chunks. Pass in an optional DownloadOptions struct
// to customize the downloader behavior.
func NewDownloader(opts *DownloadOptions) *Downloader {
if opts == nil {
opts = DefaultDownloadOptions
}
return &Downloader{opts: opts}
}
// The Downloader structure that calls Download(). It is safe to call Download()
// on this structure for multiple objects and across concurrent goroutines.
type Downloader struct {
opts *DownloadOptions
}
// Download downloads an object in S3 and writes the payload into w using
// concurrent GET requests.
//
// It is safe to call this method for multiple objects and across concurrent
// goroutines.
func (d *Downloader) Download(w io.WriterAt, input *s3.GetObjectInput) (n int64, err error) {
impl := downloader{w: w, in: input, opts: *d.opts}
return impl.download()
}
// downloader is the implementation structure used internally by Downloader.
type downloader struct {
opts DownloadOptions
in *s3.GetObjectInput
w io.WriterAt
wg sync.WaitGroup
m sync.Mutex
pos int64
totalBytes int64
written int64
err error
}
// init initializes the downloader with default options.
func (d *downloader) init() {
d.totalBytes = -1
if d.opts.Concurrency == 0 {
d.opts.Concurrency = DefaultDownloadConcurrency
}
if d.opts.PartSize == 0 {
d.opts.PartSize = DefaultDownloadPartSize
}
if d.opts.S3 == nil {
d.opts.S3 = s3.New(nil)
}
}
// download performs the object download across ranged GETs.
func (d *downloader) download() (n int64, err error) {
d.init()
// Spin up workers
ch := make(chan dlchunk, d.opts.Concurrency)
for i := 0; i < d.opts.Concurrency; i++ {
d.wg.Add(1)
go d.downloadPart(ch)
}
// Assign work
for d.geterr() == nil {
if d.pos != 0 {
// This is not the first chunk, let's wait until we know the total
// size of the payload so we can see if we have read the entire
// object.
total := d.getTotalBytes()
if total < 0 {
// Total has not yet been set, so sleep and loop around while
// waiting for our first worker to resolve this value.
time.Sleep(10 * time.Millisecond)
continue
} else if d.pos >= total {
break // We're finished queueing chunks
}
}
// Queue the next range of bytes to read.
ch <- dlchunk{w: d.w, start: d.pos, size: d.opts.PartSize}
d.pos += d.opts.PartSize
}
// Wait for completion
close(ch)
d.wg.Wait()
// Return error
return d.written, d.err
}
// downloadPart is an individual goroutine worker reading from the ch channel
// and performing a GetObject request on the data with a given byte range.
//
// If this is the first worker, this operation also resolves the total number
// of bytes to be read so that the worker manager knows when it is finished.
func (d *downloader) downloadPart(ch chan dlchunk) {
defer d.wg.Done()
for {
chunk, ok := <-ch
if !ok {
break
}
if d.geterr() == nil {
// Get the next byte range of data
in := &s3.GetObjectInput{}
awsutil.Copy(in, d.in)
rng := fmt.Sprintf("bytes=%d-%d",
chunk.start, chunk.start+chunk.size-1)
in.Range = &rng
resp, err := d.opts.S3.GetObject(in)
if err != nil {
d.seterr(err)
} else {
d.setTotalBytes(resp) // Set total if not yet set.
n, err := io.Copy(&chunk, resp.Body)
resp.Body.Close()
if err != nil {
d.seterr(err)
}
d.incrwritten(n)
}
}
}
}
// getTotalBytes is a thread-safe getter for retrieving the total byte status.
func (d *downloader) getTotalBytes() int64 {
d.m.Lock()
defer d.m.Unlock()
return d.totalBytes
}
// setTotalBytes is a thread-safe setter for the total byte status.
func (d *downloader) setTotalBytes(resp *s3.GetObjectOutput) {
d.m.Lock()
defer d.m.Unlock()
if d.totalBytes >= 0 {
return
}
parts := strings.Split(*resp.ContentRange, "/")
total, err := strconv.ParseInt(parts[len(parts)-1], 10, 64)
if err != nil {
d.err = err
return
}
d.totalBytes = total
}
func (d *downloader) incrwritten(n int64) {
d.m.Lock()
defer d.m.Unlock()
d.written += n
}
// geterr is a thread-safe getter for the error object
func (d *downloader) geterr() error {
d.m.Lock()
defer d.m.Unlock()
return d.err
}
// seterr is a thread-safe setter for the error object
func (d *downloader) seterr(e error) {
d.m.Lock()
defer d.m.Unlock()
d.err = e
}
// dlchunk represents a single chunk of data to write by the worker routine.
// This structure also implements an io.SectionReader style interface for
// io.WriterAt, effectively making it an io.SectionWriter (which does not
// exist).
type dlchunk struct {
w io.WriterAt
start int64
size int64
cur int64
}
// Write wraps io.WriterAt for the dlchunk, writing from the dlchunk's start
// position to its end (or EOF).
func (c *dlchunk) Write(p []byte) (n int, err error) {
if c.cur >= c.size {
return 0, io.EOF
}
n, err = c.w.WriteAt(p, c.start+c.cur)
c.cur += int64(n)
return
}
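The downloader above splits an object into PartSize-sized ranged GETs handled by a pool of workers. Below is a hedged sketch of driving the equivalent through the newer s3manager API this change vendors, assuming the v1.x NewDownloader signature and aws.NewWriteAtBuffer; bucket, key, and region are example values.

// Hedged sketch (assumes aws-sdk-go v1.x s3manager.NewDownloader and
// aws.NewWriteAtBuffer rather than the removed DownloadOptions above):
// download an object into memory with 5 MB ranged GETs.
package main

import (
    "fmt"
    "log"

    "github.com/aws/aws-sdk-go/aws"
    "github.com/aws/aws-sdk-go/aws/session"
    "github.com/aws/aws-sdk-go/service/s3"
    "github.com/aws/aws-sdk-go/service/s3/s3manager"
)

func main() {
    sess := session.Must(session.NewSession(&aws.Config{Region: aws.String("eu-west-1")}))

    dl := s3manager.NewDownloader(sess, func(d *s3manager.Downloader) {
        d.PartSize = 5 * 1024 * 1024 // size of each ranged GET
        d.Concurrency = 5            // parallel workers
    })

    buf := aws.NewWriteAtBuffer([]byte{})
    n, err := dl.Download(buf, &s3.GetObjectInput{
        Bucket: aws.String("bucket"), // assumed example bucket
        Key:    aws.String("key"),
    })
    if err != nil {
        log.Fatal(err)
    }
    fmt.Printf("downloaded %d bytes\n", n)
}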


@ -1,165 +0,0 @@
package s3manager_test
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"regexp"
"strconv"
"sync"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/test/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/stretchr/testify/assert"
)
var _ = unit.Imported
func dlLoggingSvc(data []byte) (*s3.S3, *[]string, *[]string) {
var m sync.Mutex
names := []string{}
ranges := []string{}
svc := s3.New(nil)
svc.Handlers.Send.Clear()
svc.Handlers.Send.PushBack(func(r *aws.Request) {
m.Lock()
defer m.Unlock()
names = append(names, r.Operation.Name)
ranges = append(ranges, *r.Params.(*s3.GetObjectInput).Range)
rerng := regexp.MustCompile(`bytes=(\d+)-(\d+)`)
rng := rerng.FindStringSubmatch(r.HTTPRequest.Header.Get("Range"))
start, _ := strconv.ParseInt(rng[1], 10, 64)
fin, _ := strconv.ParseInt(rng[2], 10, 64)
fin++
if fin > int64(len(data)) {
fin = int64(len(data))
}
r.HTTPResponse = &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(data[start:fin])),
Header: http.Header{},
}
r.HTTPResponse.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d",
start, fin, len(data)))
})
return svc, &names, &ranges
}
type dlwriter struct {
buf []byte
}
func newDLWriter(size int) *dlwriter {
return &dlwriter{buf: make([]byte, size)}
}
func (d dlwriter) WriteAt(p []byte, pos int64) (n int, err error) {
if pos > int64(len(d.buf)) {
return 0, io.EOF
}
written := 0
for i, b := range p {
if i >= len(d.buf) {
break
}
d.buf[pos+int64(i)] = b
written++
}
return written, nil
}
func TestDownloadOrder(t *testing.T) {
s, names, ranges := dlLoggingSvc(buf12MB)
opts := &s3manager.DownloadOptions{S3: s, Concurrency: 1}
d := s3manager.NewDownloader(opts)
w := newDLWriter(len(buf12MB))
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
assert.Nil(t, err)
assert.Equal(t, int64(len(buf12MB)), n)
assert.Equal(t, []string{"GetObject", "GetObject", "GetObject"}, *names)
assert.Equal(t, []string{"bytes=0-5242879", "bytes=5242880-10485759", "bytes=10485760-15728639"}, *ranges)
count := 0
for _, b := range w.buf {
count += int(b)
}
assert.Equal(t, 0, count)
}
func TestDownloadZero(t *testing.T) {
s, names, ranges := dlLoggingSvc([]byte{})
opts := &s3manager.DownloadOptions{S3: s}
d := s3manager.NewDownloader(opts)
w := newDLWriter(0)
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
assert.Nil(t, err)
assert.Equal(t, int64(0), n)
assert.Equal(t, []string{"GetObject"}, *names)
assert.Equal(t, []string{"bytes=0-5242879"}, *ranges)
}
func TestDownloadSetPartSize(t *testing.T) {
s, names, ranges := dlLoggingSvc([]byte{1, 2, 3})
opts := &s3manager.DownloadOptions{S3: s, PartSize: 1, Concurrency: 1}
d := s3manager.NewDownloader(opts)
w := newDLWriter(3)
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
assert.Nil(t, err)
assert.Equal(t, int64(3), n)
assert.Equal(t, []string{"GetObject", "GetObject", "GetObject"}, *names)
assert.Equal(t, []string{"bytes=0-0", "bytes=1-1", "bytes=2-2"}, *ranges)
assert.Equal(t, []byte{1, 2, 3}, w.buf)
}
func TestDownloadError(t *testing.T) {
s, names, _ := dlLoggingSvc([]byte{1, 2, 3})
opts := &s3manager.DownloadOptions{S3: s, PartSize: 1, Concurrency: 1}
num := 0
s.Handlers.Send.PushBack(func(r *aws.Request) {
num++
if num > 1 {
r.HTTPResponse.StatusCode = 400
r.HTTPResponse.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
}
})
d := s3manager.NewDownloader(opts)
w := newDLWriter(3)
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
assert.NotNil(t, err)
assert.Equal(t, int64(1), n)
assert.Equal(t, []string{"GetObject", "GetObject"}, *names)
assert.Equal(t, []byte{1, 0, 0}, w.buf)
}


@ -1,562 +0,0 @@
package s3manager
import (
"bytes"
"fmt"
"io"
"sort"
"sync"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/service/s3"
)
// The maximum allowed number of parts in a multi-part upload on Amazon S3.
var MaxUploadParts = 10000
// The minimum allowed part size when uploading a part to Amazon S3.
var MinUploadPartSize int64 = 1024 * 1024 * 5
// The default part size to buffer chunks of a payload into.
var DefaultUploadPartSize = MinUploadPartSize
// The default number of goroutines to spin up when using Upload().
var DefaultUploadConcurrency = 5
// The default set of options used when opts is nil in Upload().
var DefaultUploadOptions = &UploadOptions{
PartSize: DefaultUploadPartSize,
Concurrency: DefaultUploadConcurrency,
LeavePartsOnError: false,
S3: nil,
}
// A MultiUploadFailure wraps a failed S3 multipart upload. An error returned
// will satisfy this interface when a multipart upload failed to upload all
// chunks to S3. In the case of a failure the UploadID is needed to operate on
// the chunks, if any, which were uploaded.
//
// Example:
//
// u := s3manager.NewUploader(opts)
// output, err := u.Upload(input)
// if err != nil {
// if multierr, ok := err.(MultiUploadFailure); ok {
// // Process error and its associated uploadID
// fmt.Println("Error:", multierr.Code(), multierr.Message(), multierr.UploadID())
// } else {
// // Process error generically
// fmt.Println("Error:", err.Error())
// }
// }
//
type MultiUploadFailure interface {
awserr.Error
// Returns the upload id for the S3 multipart upload that failed.
UploadID() string
}
// So that the Error interface type can be included as an anonymous field
// in the multiUploadError struct and not conflict with the error.Error() method.
type awsError awserr.Error
// A multiUploadError wraps the upload ID of a failed s3 multipart upload.
// Composed of BaseError for code, message, and original error
//
// Should be used for an error that occurred during an S3 multipart upload when
// an upload ID is available. If an upload ID is not available, a more generic
// awserr.Error should be used instead.
type multiUploadError struct {
awsError
// ID for multipart upload which failed.
uploadID string
}
// Error returns the string representation of the error.
//
// See apierr.BaseError ErrorWithExtra for output format
//
// Satisfies the error interface.
func (m multiUploadError) Error() string {
extra := fmt.Sprintf("upload id: %s", m.uploadID)
return awserr.SprintError(m.Code(), m.Message(), extra, m.OrigErr())
}
// String returns the string representation of the error.
// Alias for Error to satisfy the stringer interface.
func (m multiUploadError) String() string {
return m.Error()
}
// UploadID returns the id of the S3 upload which failed.
func (m multiUploadError) UploadID() string {
return m.uploadID
}
// UploadInput contains all input for upload requests to Amazon S3.
type UploadInput struct {
// The canned ACL to apply to the object.
ACL *string `location:"header" locationName:"x-amz-acl" type:"string"`
Bucket *string `location:"uri" locationName:"Bucket" type:"string" required:"true"`
// Specifies caching behavior along the request/reply chain.
CacheControl *string `location:"header" locationName:"Cache-Control" type:"string"`
// Specifies presentational information for the object.
ContentDisposition *string `location:"header" locationName:"Content-Disposition" type:"string"`
// Specifies what content encodings have been applied to the object and thus
// what decoding mechanisms must be applied to obtain the media-type referenced
// by the Content-Type header field.
ContentEncoding *string `location:"header" locationName:"Content-Encoding" type:"string"`
// The language the content is in.
ContentLanguage *string `location:"header" locationName:"Content-Language" type:"string"`
// A standard MIME type describing the format of the object data.
ContentType *string `location:"header" locationName:"Content-Type" type:"string"`
// The date and time at which the object is no longer cacheable.
Expires *time.Time `location:"header" locationName:"Expires" type:"timestamp" timestampFormat:"rfc822"`
// Gives the grantee READ, READ_ACP, and WRITE_ACP permissions on the object.
GrantFullControl *string `location:"header" locationName:"x-amz-grant-full-control" type:"string"`
// Allows grantee to read the object data and its metadata.
GrantRead *string `location:"header" locationName:"x-amz-grant-read" type:"string"`
// Allows grantee to read the object ACL.
GrantReadACP *string `location:"header" locationName:"x-amz-grant-read-acp" type:"string"`
// Allows grantee to write the ACL for the applicable object.
GrantWriteACP *string `location:"header" locationName:"x-amz-grant-write-acp" type:"string"`
Key *string `location:"uri" locationName:"Key" type:"string" required:"true"`
// A map of metadata to store with the object in S3.
Metadata map[string]*string `location:"headers" locationName:"x-amz-meta-" type:"map"`
// Confirms that the requester knows that she or he will be charged for the
// request. Bucket owners need not specify this parameter in their requests.
// Documentation on downloading objects from requester pays buckets can be found
// at http://docs.aws.amazon.com/AmazonS3/latest/dev/ObjectsinRequesterPaysBuckets.html
RequestPayer *string `location:"header" locationName:"x-amz-request-payer" type:"string"`
// Specifies the algorithm to use when encrypting the object (e.g., AES256,
// aws:kms).
SSECustomerAlgorithm *string `location:"header" locationName:"x-amz-server-side-encryption-customer-algorithm" type:"string"`
// Specifies the customer-provided encryption key for Amazon S3 to use in encrypting
// data. This value is used to store the object and then it is discarded; Amazon
// does not store the encryption key. The key must be appropriate for use with
// the algorithm specified in the x-amz-server-side-encryption-customer-algorithm
// header.
SSECustomerKey *string `location:"header" locationName:"x-amz-server-side-encryption-customer-key" type:"string"`
// Specifies the 128-bit MD5 digest of the encryption key according to RFC 1321.
// Amazon S3 uses this header for a message integrity check to ensure the encryption
// key was transmitted without error.
SSECustomerKeyMD5 *string `location:"header" locationName:"x-amz-server-side-encryption-customer-key-MD5" type:"string"`
// Specifies the AWS KMS key ID to use for object encryption. All GET and PUT
// requests for an object protected by AWS KMS will fail if not made via SSL
// or using SigV4. Documentation on configuring any of the officially supported
// AWS SDKs and CLI can be found at http://docs.aws.amazon.com/AmazonS3/latest/dev/UsingAWSSDK.html#specify-signature-version
SSEKMSKeyID *string `location:"header" locationName:"x-amz-server-side-encryption-aws-kms-key-id" type:"string"`
// The Server-side encryption algorithm used when storing this object in S3
// (e.g., AES256, aws:kms).
ServerSideEncryption *string `location:"header" locationName:"x-amz-server-side-encryption" type:"string"`
// The type of storage to use for the object. Defaults to 'STANDARD'.
StorageClass *string `location:"header" locationName:"x-amz-storage-class" type:"string"`
// If the bucket is configured as a website, redirects requests for this object
// to another object in the same bucket or to an external URL. Amazon S3 stores
// the value of this header in the object metadata.
WebsiteRedirectLocation *string `location:"header" locationName:"x-amz-website-redirect-location" type:"string"`
// The readable body payload to send to S3.
Body io.Reader
}
// UploadOutput represents a response from the Upload() call.
type UploadOutput struct {
// The URL where the object was uploaded to.
Location string
// The ID for a multipart upload to S3. In the case of an error the error
// can be cast to the MultiUploadFailure interface to extract the upload ID.
UploadID string
}
// UploadOptions keeps track of extra options to pass to an Upload() call.
type UploadOptions struct {
// The buffer size (in bytes) to use when buffering data into chunks and
// sending them as parts to S3. The minimum allowed part size is 5MB, and
// if this value is set to zero, the DefaultUploadPartSize value will be used.
PartSize int64
// The number of goroutines to spin up in parallel when sending parts.
// If this is set to zero, the DefaultUploadConcurrency value will be used.
Concurrency int
// Setting this value to true will cause the SDK to avoid calling
// AbortMultipartUpload on a failure, leaving all successfully uploaded
// parts on S3 for manual recovery.
//
// Note that storing parts of an incomplete multipart upload counts towards
// space usage on S3 and will add additional costs if not cleaned up.
LeavePartsOnError bool
// The client to use when uploading to S3. Leave this as nil to use the
// default S3 client.
S3 *s3.S3
}
// NewUploader creates a new Uploader object to upload data to S3. Pass in
// an optional opts structure to customize the uploader behavior.
func NewUploader(opts *UploadOptions) *Uploader {
if opts == nil {
opts = DefaultUploadOptions
}
return &Uploader{opts: opts}
}
// The Uploader structure that calls Upload(). It is safe to call Upload()
// on this structure for multiple objects and across concurrent goroutines.
type Uploader struct {
opts *UploadOptions
}
// Upload uploads an object to S3, intelligently buffering large files into
// smaller chunks and sending them in parallel across multiple goroutines. You
// can configure the buffer size and concurrency through the opts parameter.
//
// If opts is set to nil, DefaultUploadOptions will be used.
//
// It is safe to call this method for multiple objects and across concurrent
// goroutines.
func (u *Uploader) Upload(input *UploadInput) (*UploadOutput, error) {
i := uploader{in: input, opts: *u.opts}
return i.upload()
}
// internal structure to manage an upload to S3.
type uploader struct {
in *UploadInput
opts UploadOptions
readerPos int64 // current reader position
totalSize int64 // set to -1 if the size is not known
}
// internal logic for deciding whether to upload a single part or use a
// multipart upload.
func (u *uploader) upload() (*UploadOutput, error) {
u.init()
if u.opts.PartSize < MinUploadPartSize {
msg := fmt.Sprintf("part size must be at least %d bytes", MinUploadPartSize)
return nil, awserr.New("ConfigError", msg, nil)
}
// Do one read to determine if we have more than one part
buf, err := u.nextReader()
if err == io.EOF || err == io.ErrUnexpectedEOF { // single part
return u.singlePart(buf)
} else if err != nil {
return nil, awserr.New("ReadRequestBody", "read upload data failed", err)
}
mu := multiuploader{uploader: u}
return mu.upload(buf)
}
// init will initialize all default options.
func (u *uploader) init() {
if u.opts.S3 == nil {
u.opts.S3 = s3.New(nil)
}
if u.opts.Concurrency == 0 {
u.opts.Concurrency = DefaultUploadConcurrency
}
if u.opts.PartSize == 0 {
u.opts.PartSize = DefaultUploadPartSize
}
// Try to get the total size for some optimizations
u.initSize()
}
// initSize tries to detect the total stream size, setting u.totalSize. If
// the size is not known, totalSize is set to -1.
func (u *uploader) initSize() {
u.totalSize = -1
switch r := u.in.Body.(type) {
case io.Seeker:
pos, _ := r.Seek(0, 1)
defer r.Seek(pos, 0)
n, err := r.Seek(0, 2)
if err != nil {
return
}
u.totalSize = n
// try to adjust partSize if it is too small
if u.totalSize/u.opts.PartSize >= int64(MaxUploadParts) {
u.opts.PartSize = u.totalSize / int64(MaxUploadParts)
}
}
}
// nextReader returns a seekable reader representing the next packet of data.
// This operation increases the shared u.readerPos counter, but note that it
// does not need to be wrapped in a mutex because nextReader is only called
// from the main thread.
func (u *uploader) nextReader() (io.ReadSeeker, error) {
switch r := u.in.Body.(type) {
case io.ReaderAt:
var err error
n := u.opts.PartSize
if u.totalSize >= 0 {
bytesLeft := u.totalSize - u.readerPos
if bytesLeft == 0 {
err = io.EOF
} else if bytesLeft <= u.opts.PartSize {
err = io.ErrUnexpectedEOF
n = bytesLeft
}
}
buf := io.NewSectionReader(r, u.readerPos, n)
u.readerPos += n
return buf, err
default:
packet := make([]byte, u.opts.PartSize)
n, err := io.ReadFull(u.in.Body, packet)
u.readerPos += int64(n)
return bytes.NewReader(packet[0:n]), err
}
}
// singlePart contains upload logic for uploading a single chunk via
// a regular PutObject request. Multipart requests require at least two
// parts, or at least 5MB of data.
func (u *uploader) singlePart(buf io.ReadSeeker) (*UploadOutput, error) {
params := &s3.PutObjectInput{}
awsutil.Copy(params, u.in)
params.Body = buf
req, _ := u.opts.S3.PutObjectRequest(params)
if err := req.Send(); err != nil {
return nil, err
}
url := req.HTTPRequest.URL.String()
return &UploadOutput{Location: url}, nil
}
// internal structure to manage a specific multipart upload to S3.
type multiuploader struct {
*uploader
wg sync.WaitGroup
m sync.Mutex
err error
uploadID string
parts completedParts
}
// keeps track of a single chunk of data being sent to S3.
type chunk struct {
buf io.ReadSeeker
num int64
}
// completedParts is a wrapper to make parts sortable by their part number,
// since S3 requires this list to be sent in sorted order.
type completedParts []*s3.CompletedPart
func (a completedParts) Len() int { return len(a) }
func (a completedParts) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a completedParts) Less(i, j int) bool { return *a[i].PartNumber < *a[j].PartNumber }
// upload will perform a multipart upload using the firstBuf buffer containing
// the first chunk of data.
func (u *multiuploader) upload(firstBuf io.ReadSeeker) (*UploadOutput, error) {
params := &s3.CreateMultipartUploadInput{}
awsutil.Copy(params, u.in)
// Create the multipart
resp, err := u.opts.S3.CreateMultipartUpload(params)
if err != nil {
return nil, err
}
u.uploadID = *resp.UploadID
// Create the workers
ch := make(chan chunk, u.opts.Concurrency)
for i := 0; i < u.opts.Concurrency; i++ {
u.wg.Add(1)
go u.readChunk(ch)
}
// Send part 1 to the workers
var num int64 = 1
ch <- chunk{buf: firstBuf, num: num}
// Read and queue the rest of the parts
for u.geterr() == nil {
// This upload exceeded maximum number of supported parts, error now.
if num > int64(MaxUploadParts) {
msg := fmt.Sprintf("exceeded total allowed parts (%d). "+
"Adjust PartSize to fit in this limit", MaxUploadParts)
u.seterr(awserr.New("TotalPartsExceeded", msg, nil))
break
}
num++
buf, err := u.nextReader()
if err == io.EOF {
break
}
ch <- chunk{buf: buf, num: num}
if err != nil && err != io.ErrUnexpectedEOF {
u.seterr(awserr.New(
"ReadRequestBody",
"read multipart upload data failed",
err))
break
}
}
// Close the channel, wait for workers, and complete upload
close(ch)
u.wg.Wait()
complete := u.complete()
if err := u.geterr(); err != nil {
return nil, &multiUploadError{
awsError: awserr.New(
"MultipartUpload",
"upload multipart failed",
err),
uploadID: u.uploadID,
}
}
return &UploadOutput{
Location: *complete.Location,
UploadID: u.uploadID,
}, nil
}
// readChunk runs in worker goroutines to pull chunks off of the ch channel
// and send() them as UploadPart requests.
func (u *multiuploader) readChunk(ch chan chunk) {
defer u.wg.Done()
for {
data, ok := <-ch
if !ok {
break
}
if u.geterr() == nil {
if err := u.send(data); err != nil {
u.seterr(err)
}
}
}
}
// send performs an UploadPart request and keeps track of the completed
// part information.
func (u *multiuploader) send(c chunk) error {
resp, err := u.opts.S3.UploadPart(&s3.UploadPartInput{
Bucket: u.in.Bucket,
Key: u.in.Key,
Body: c.buf,
UploadID: &u.uploadID,
PartNumber: &c.num,
})
if err != nil {
return err
}
n := c.num
completed := &s3.CompletedPart{ETag: resp.ETag, PartNumber: &n}
u.m.Lock()
u.parts = append(u.parts, completed)
u.m.Unlock()
return nil
}
// geterr is a thread-safe getter for the error object
func (u *multiuploader) geterr() error {
u.m.Lock()
defer u.m.Unlock()
return u.err
}
// seterr is a thread-safe setter for the error object
func (u *multiuploader) seterr(e error) {
u.m.Lock()
defer u.m.Unlock()
u.err = e
}
// fail will abort the multipart upload unless LeavePartsOnError is set to true.
func (u *multiuploader) fail() {
if u.opts.LeavePartsOnError {
return
}
u.opts.S3.AbortMultipartUpload(&s3.AbortMultipartUploadInput{
Bucket: u.in.Bucket,
Key: u.in.Key,
UploadID: &u.uploadID,
})
}
// complete successfully completes a multipart upload and returns the response.
func (u *multiuploader) complete() *s3.CompleteMultipartUploadOutput {
if u.geterr() != nil {
u.fail()
return nil
}
// Parts must be sorted in PartNumber order.
sort.Sort(u.parts)
resp, err := u.opts.S3.CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{
Bucket: u.in.Bucket,
Key: u.in.Key,
UploadID: &u.uploadID,
MultipartUpload: &s3.CompletedMultipartUpload{Parts: u.parts},
})
if err != nil {
u.seterr(err)
u.fail()
}
return resp
}

View file

@ -1,438 +0,0 @@
package s3manager_test
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"sort"
"sync"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/internal/test/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/stretchr/testify/assert"
)
var _ = unit.Imported
var buf12MB = make([]byte, 1024*1024*12)
var buf2MB = make([]byte, 1024*1024*2)
var emptyList = []string{}
func val(i interface{}, s string) interface{} {
return awsutil.ValuesAtPath(i, s)[0]
}
func contains(src []string, s string) bool {
for _, v := range src {
if s == v {
return true
}
}
return false
}
func loggingSvc(ignoreOps []string) (*s3.S3, *[]string, *[]interface{}) {
var m sync.Mutex
partNum := 0
names := []string{}
params := []interface{}{}
svc := s3.New(nil)
svc.Handlers.Unmarshal.Clear()
svc.Handlers.UnmarshalMeta.Clear()
svc.Handlers.UnmarshalError.Clear()
svc.Handlers.Send.Clear()
svc.Handlers.Send.PushBack(func(r *aws.Request) {
m.Lock()
defer m.Unlock()
if !contains(ignoreOps, r.Operation.Name) {
names = append(names, r.Operation.Name)
params = append(params, r.Params)
}
r.HTTPResponse = &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
switch data := r.Data.(type) {
case *s3.CreateMultipartUploadOutput:
data.UploadID = aws.String("UPLOAD-ID")
case *s3.UploadPartOutput:
partNum++
data.ETag = aws.String(fmt.Sprintf("ETAG%d", partNum))
case *s3.CompleteMultipartUploadOutput:
data.Location = aws.String("https://location")
}
})
return svc, &names, &params
}
func buflen(i interface{}) int {
r := i.(io.Reader)
b, _ := ioutil.ReadAll(r)
return len(b)
}
func TestUploadOrderMulti(t *testing.T) {
s, ops, args := loggingSvc(emptyList)
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s})
resp, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf12MB),
ServerSideEncryption: aws.String("AES256"),
ContentType: aws.String("content/type"),
})
assert.NoError(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *ops)
assert.Equal(t, "https://location", resp.Location)
assert.Equal(t, "UPLOAD-ID", resp.UploadID)
// Validate input values
// UploadPart
assert.Equal(t, "UPLOAD-ID", val((*args)[1], "UploadID"))
assert.Equal(t, "UPLOAD-ID", val((*args)[2], "UploadID"))
assert.Equal(t, "UPLOAD-ID", val((*args)[3], "UploadID"))
// CompleteMultipartUpload
assert.Equal(t, "UPLOAD-ID", val((*args)[4], "UploadID"))
assert.Equal(t, int64(1), val((*args)[4], "MultipartUpload.Parts[0].PartNumber"))
assert.Equal(t, int64(2), val((*args)[4], "MultipartUpload.Parts[1].PartNumber"))
assert.Equal(t, int64(3), val((*args)[4], "MultipartUpload.Parts[2].PartNumber"))
assert.Regexp(t, `^ETAG\d+$`, val((*args)[4], "MultipartUpload.Parts[0].ETag"))
assert.Regexp(t, `^ETAG\d+$`, val((*args)[4], "MultipartUpload.Parts[1].ETag"))
assert.Regexp(t, `^ETAG\d+$`, val((*args)[4], "MultipartUpload.Parts[2].ETag"))
// Custom headers
assert.Equal(t, "AES256", val((*args)[0], "ServerSideEncryption"))
assert.Equal(t, "content/type", val((*args)[0], "ContentType"))
}
func TestUploadOrderMultiDifferentPartSize(t *testing.T) {
s, ops, args := loggingSvc(emptyList)
mgr := s3manager.NewUploader(&s3manager.UploadOptions{
S3: s,
PartSize: 1024 * 1024 * 7,
Concurrency: 1,
})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf12MB),
})
assert.NoError(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *ops)
// Part lengths
assert.Equal(t, 1024*1024*7, buflen(val((*args)[1], "Body")))
assert.Equal(t, 1024*1024*5, buflen(val((*args)[2], "Body")))
}
func TestUploadIncreasePartSize(t *testing.T) {
s3manager.MaxUploadParts = 2
defer func() { s3manager.MaxUploadParts = 10000 }()
s, ops, args := loggingSvc(emptyList)
opts := &s3manager.UploadOptions{S3: s, Concurrency: 1}
mgr := s3manager.NewUploader(opts)
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf12MB),
})
assert.NoError(t, err)
assert.Equal(t, int64(0), opts.PartSize) // don't modify orig options
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *ops)
// Part lengths
assert.Equal(t, 1024*1024*6, buflen(val((*args)[1], "Body")))
assert.Equal(t, 1024*1024*6, buflen(val((*args)[2], "Body")))
}
func TestUploadFailIfPartSizeTooSmall(t *testing.T) {
opts := &s3manager.UploadOptions{PartSize: 5}
mgr := s3manager.NewUploader(opts)
resp, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf12MB),
})
assert.Nil(t, resp)
assert.NotNil(t, err)
aerr := err.(awserr.Error)
assert.Equal(t, "ConfigError", aerr.Code())
assert.Contains(t, aerr.Message(), "part size must be at least")
}
func TestUploadOrderSingle(t *testing.T) {
s, ops, args := loggingSvc(emptyList)
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s})
resp, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf2MB),
ServerSideEncryption: aws.String("AES256"),
ContentType: aws.String("content/type"),
})
assert.NoError(t, err)
assert.Equal(t, []string{"PutObject"}, *ops)
assert.NotEqual(t, "", resp.Location)
assert.Equal(t, "", resp.UploadID)
assert.Equal(t, "AES256", val((*args)[0], "ServerSideEncryption"))
assert.Equal(t, "content/type", val((*args)[0], "ContentType"))
}
func TestUploadOrderSingleFailure(t *testing.T) {
s, ops, _ := loggingSvc(emptyList)
s.Handlers.Send.PushBack(func(r *aws.Request) {
r.HTTPResponse.StatusCode = 400
})
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s})
resp, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf2MB),
})
assert.Error(t, err)
assert.Equal(t, []string{"PutObject"}, *ops)
assert.Nil(t, resp)
}
func TestUploadOrderZero(t *testing.T) {
s, ops, args := loggingSvc(emptyList)
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s})
resp, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(make([]byte, 0)),
})
assert.NoError(t, err)
assert.Equal(t, []string{"PutObject"}, *ops)
assert.NotEqual(t, "", resp.Location)
assert.Equal(t, "", resp.UploadID)
assert.Equal(t, 0, buflen(val((*args)[0], "Body")))
}
func TestUploadOrderMultiFailure(t *testing.T) {
s, ops, _ := loggingSvc(emptyList)
s.Handlers.Send.PushBack(func(r *aws.Request) {
switch t := r.Data.(type) {
case *s3.UploadPartOutput:
if *t.ETag == "ETAG2" {
r.HTTPResponse.StatusCode = 400
}
}
})
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s, Concurrency: 1})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf12MB),
})
assert.Error(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "AbortMultipartUpload"}, *ops)
}
func TestUploadOrderMultiFailureOnComplete(t *testing.T) {
s, ops, _ := loggingSvc(emptyList)
s.Handlers.Send.PushBack(func(r *aws.Request) {
switch r.Data.(type) {
case *s3.CompleteMultipartUploadOutput:
r.HTTPResponse.StatusCode = 400
}
})
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s, Concurrency: 1})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf12MB),
})
assert.Error(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart",
"UploadPart", "CompleteMultipartUpload", "AbortMultipartUpload"}, *ops)
}
func TestUploadOrderMultiFailureOnCreate(t *testing.T) {
s, ops, _ := loggingSvc(emptyList)
s.Handlers.Send.PushBack(func(r *aws.Request) {
switch r.Data.(type) {
case *s3.CreateMultipartUploadOutput:
r.HTTPResponse.StatusCode = 400
}
})
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(make([]byte, 1024*1024*12)),
})
assert.Error(t, err)
assert.Equal(t, []string{"CreateMultipartUpload"}, *ops)
}
func TestUploadOrderMultiFailureLeaveParts(t *testing.T) {
s, ops, _ := loggingSvc(emptyList)
s.Handlers.Send.PushBack(func(r *aws.Request) {
switch data := r.Data.(type) {
case *s3.UploadPartOutput:
if *data.ETag == "ETAG2" {
r.HTTPResponse.StatusCode = 400
}
}
})
mgr := s3manager.NewUploader(&s3manager.UploadOptions{
S3: s,
Concurrency: 1,
LeavePartsOnError: true,
})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(make([]byte, 1024*1024*12)),
})
assert.Error(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart"}, *ops)
}
type failreader struct {
times int
failCount int
}
func (f *failreader) Read(b []byte) (int, error) {
f.failCount++
if f.failCount >= f.times {
return 0, fmt.Errorf("random failure")
}
return len(b), nil
}
func TestUploadOrderReadFail1(t *testing.T) {
s, ops, _ := loggingSvc(emptyList)
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &failreader{times: 1},
})
assert.Equal(t, "ReadRequestBody", err.(awserr.Error).Code())
assert.EqualError(t, err.(awserr.Error).OrigErr(), "random failure")
assert.Equal(t, []string{}, *ops)
}
func TestUploadOrderReadFail2(t *testing.T) {
s, ops, _ := loggingSvc([]string{"UploadPart"})
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s, Concurrency: 1})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &failreader{times: 2},
})
assert.Equal(t, "ReadRequestBody", err.(awserr.Error).Code())
assert.EqualError(t, err.(awserr.Error).OrigErr(), "random failure")
assert.Equal(t, []string{"CreateMultipartUpload", "AbortMultipartUpload"}, *ops)
}
type sizedReader struct {
size int
cur int
}
func (s *sizedReader) Read(p []byte) (n int, err error) {
if s.cur >= s.size {
return 0, io.EOF
}
n = len(p)
s.cur += len(p)
if s.cur > s.size {
n -= s.cur - s.size
}
return
}
func TestUploadOrderMultiBufferedReader(t *testing.T) {
s, ops, args := loggingSvc(emptyList)
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &sizedReader{size: 1024 * 1024 * 12},
})
assert.NoError(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *ops)
// Part lengths
parts := []int{
buflen(val((*args)[1], "Body")),
buflen(val((*args)[2], "Body")),
buflen(val((*args)[3], "Body")),
}
sort.Ints(parts)
assert.Equal(t, []int{1024 * 1024 * 2, 1024 * 1024 * 5, 1024 * 1024 * 5}, parts)
}
func TestUploadOrderMultiBufferedReaderExceedTotalParts(t *testing.T) {
s3manager.MaxUploadParts = 2
defer func() { s3manager.MaxUploadParts = 10000 }()
s, ops, _ := loggingSvc([]string{"UploadPart"})
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s, Concurrency: 1})
resp, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &sizedReader{size: 1024 * 1024 * 12},
})
assert.Error(t, err)
assert.Nil(t, resp)
assert.Equal(t, []string{"CreateMultipartUpload", "AbortMultipartUpload"}, *ops)
aerr := err.(awserr.Error)
assert.Equal(t, "TotalPartsExceeded", aerr.Code())
assert.Contains(t, aerr.Message(), "exceeded total allowed parts (2)")
}
func TestUploadOrderSingleBufferedReader(t *testing.T) {
s, ops, _ := loggingSvc(emptyList)
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s})
resp, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &sizedReader{size: 1024 * 1024 * 2},
})
assert.NoError(t, err)
assert.Equal(t, []string{"PutObject"}, *ops)
assert.NotEqual(t, "", resp.Location)
assert.Equal(t, "", resp.UploadID)
}

View file

@ -1,57 +0,0 @@
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
package s3
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/protocol/restxml"
"github.com/aws/aws-sdk-go/internal/signer/v4"
)
// S3 is a client for Amazon S3.
type S3 struct {
*aws.Service
}
// Used for custom service initialization logic
var initService func(*aws.Service)
// Used for custom request initialization logic
var initRequest func(*aws.Request)
// New returns a new S3 client.
func New(config *aws.Config) *S3 {
service := &aws.Service{
Config: aws.DefaultConfig.Merge(config),
ServiceName: "s3",
APIVersion: "2006-03-01",
}
service.Initialize()
// Handlers
service.Handlers.Sign.PushBack(v4.Sign)
service.Handlers.Build.PushBack(restxml.Build)
service.Handlers.Unmarshal.PushBack(restxml.Unmarshal)
service.Handlers.UnmarshalMeta.PushBack(restxml.UnmarshalMeta)
service.Handlers.UnmarshalError.PushBack(restxml.UnmarshalError)
// Run custom service initialization if present
if initService != nil {
initService(service)
}
return &S3{service}
}
// newRequest creates a new request for a S3 operation and runs any
// custom request initialization.
func (c *S3) newRequest(op *aws.Operation, params, data interface{}) *aws.Request {
req := aws.NewRequest(c.Service, op, params, data)
// Run custom request initialization if present
if initRequest != nil {
initRequest(req)
}
return req
}
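// Illustrative usage sketch (not part of the original file): constructing the
// client defined above and issuing a request, mirroring the calls made in the
// tests included in this diff; a nil config falls back to aws.DefaultConfig.
//
//	svc := s3.New(nil)
//	_, err := svc.PutBucketACL(&s3.PutBucketACLInput{
//		Bucket: aws.String("bucket"),
//		ACL:    aws.String("public-read"),
//	})
//	if err != nil {
//		// handle the error
//	}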

View file

@ -1,81 +0,0 @@
package s3_test
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/test/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
)
var _ = unit.Imported
func TestSSECustomerKeyOverHTTPError(t *testing.T) {
s := s3.New(&aws.Config{DisableSSL: aws.Bool(true)})
req, _ := s.CopyObjectRequest(&s3.CopyObjectInput{
Bucket: aws.String("bucket"),
CopySource: aws.String("bucket/source"),
Key: aws.String("dest"),
SSECustomerKey: aws.String("key"),
})
err := req.Build()
assert.Error(t, err)
assert.Equal(t, "ConfigError", err.(awserr.Error).Code())
assert.Contains(t, err.(awserr.Error).Message(), "cannot send SSE keys over HTTP")
}
func TestCopySourceSSECustomerKeyOverHTTPError(t *testing.T) {
s := s3.New(&aws.Config{DisableSSL: aws.Bool(true)})
req, _ := s.CopyObjectRequest(&s3.CopyObjectInput{
Bucket: aws.String("bucket"),
CopySource: aws.String("bucket/source"),
Key: aws.String("dest"),
CopySourceSSECustomerKey: aws.String("key"),
})
err := req.Build()
assert.Error(t, err)
assert.Equal(t, "ConfigError", err.(awserr.Error).Code())
assert.Contains(t, err.(awserr.Error).Message(), "cannot send SSE keys over HTTP")
}
func TestComputeSSEKeys(t *testing.T) {
s := s3.New(nil)
req, _ := s.CopyObjectRequest(&s3.CopyObjectInput{
Bucket: aws.String("bucket"),
CopySource: aws.String("bucket/source"),
Key: aws.String("dest"),
SSECustomerKey: aws.String("key"),
CopySourceSSECustomerKey: aws.String("key"),
})
err := req.Build()
assert.NoError(t, err)
assert.Equal(t, "a2V5", req.HTTPRequest.Header.Get("x-amz-server-side-encryption-customer-key"))
assert.Equal(t, "a2V5", req.HTTPRequest.Header.Get("x-amz-copy-source-server-side-encryption-customer-key"))
assert.Equal(t, "PG4LipwVIkqCKLmpjKFTHQ==", req.HTTPRequest.Header.Get("x-amz-server-side-encryption-customer-key-md5"))
assert.Equal(t, "PG4LipwVIkqCKLmpjKFTHQ==", req.HTTPRequest.Header.Get("x-amz-copy-source-server-side-encryption-customer-key-md5"))
}
func TestComputeSSEKeysShortcircuit(t *testing.T) {
s := s3.New(nil)
req, _ := s.CopyObjectRequest(&s3.CopyObjectInput{
Bucket: aws.String("bucket"),
CopySource: aws.String("bucket/source"),
Key: aws.String("dest"),
SSECustomerKey: aws.String("key"),
CopySourceSSECustomerKey: aws.String("key"),
SSECustomerKeyMD5: aws.String("MD5"),
CopySourceSSECustomerKeyMD5: aws.String("MD5"),
})
err := req.Build()
assert.NoError(t, err)
assert.Equal(t, "a2V5", req.HTTPRequest.Header.Get("x-amz-server-side-encryption-customer-key"))
assert.Equal(t, "a2V5", req.HTTPRequest.Header.Get("x-amz-copy-source-server-side-encryption-customer-key"))
assert.Equal(t, "MD5", req.HTTPRequest.Header.Get("x-amz-server-side-encryption-customer-key-md5"))
assert.Equal(t, "MD5", req.HTTPRequest.Header.Get("x-amz-copy-source-server-side-encryption-customer-key-md5"))
}

View file

@ -1,42 +0,0 @@
package s3
import (
"encoding/xml"
"io"
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
)
type xmlErrorResponse struct {
XMLName xml.Name `xml:"Error"`
Code string `xml:"Code"`
Message string `xml:"Message"`
}
func unmarshalError(r *aws.Request) {
defer r.HTTPResponse.Body.Close()
if r.HTTPResponse.ContentLength == int64(0) {
// No body, use status code to generate an awserr.Error
r.Error = awserr.NewRequestFailure(
awserr.New(strings.Replace(r.HTTPResponse.Status, " ", "", -1), r.HTTPResponse.Status, nil),
r.HTTPResponse.StatusCode,
"",
)
return
}
resp := &xmlErrorResponse{}
err := xml.NewDecoder(r.HTTPResponse.Body).Decode(resp)
if err != nil && err != io.EOF {
r.Error = awserr.New("SerializationError", "failed to decode S3 XML error response", nil)
} else {
r.Error = awserr.NewRequestFailure(
awserr.New(resp.Code, resp.Message, nil),
r.HTTPResponse.StatusCode,
"",
)
}
}

View file

@ -1,53 +0,0 @@
package s3_test
import (
"bytes"
"io/ioutil"
"net/http"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/test/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
)
var _ = unit.Imported
var s3StatusCodeErrorTests = []struct {
scode int
status string
body string
code string
message string
}{
{301, "Moved Permanently", "", "MovedPermanently", "Moved Permanently"},
{403, "Forbidden", "", "Forbidden", "Forbidden"},
{400, "Bad Request", "", "BadRequest", "Bad Request"},
{404, "Not Found", "", "NotFound", "Not Found"},
{500, "Internal Error", "", "InternalError", "Internal Error"},
}
func TestStatusCodeError(t *testing.T) {
for _, test := range s3StatusCodeErrorTests {
s := s3.New(nil)
s.Handlers.Send.Clear()
s.Handlers.Send.PushBack(func(r *aws.Request) {
body := ioutil.NopCloser(bytes.NewReader([]byte(test.body)))
r.HTTPResponse = &http.Response{
ContentLength: int64(len(test.body)),
StatusCode: test.scode,
Status: test.status,
Body: body,
}
})
_, err := s.PutBucketACL(&s3.PutBucketACLInput{
Bucket: aws.String("bucket"), ACL: aws.String("public-read"),
})
assert.Error(t, err)
assert.Equal(t, test.code, err.(awserr.Error).Code())
assert.Equal(t, test.message, err.(awserr.Error).Message())
}
}

View file

@ -1,8 +0,0 @@
language: go
go:
- 1.3
- 1.4.2
- tip
script:
- go test ./...
- go build

View file

@ -1,485 +0,0 @@
# Cobra
A Commander for modern go CLI interactions
[![Build Status](https://travis-ci.org/spf13/cobra.svg)](https://travis-ci.org/spf13/cobra)
## Overview
Cobra is a commander providing a simple interface to create powerful modern CLI
interfaces similar to git & go tools. In addition to providing an interface, Cobra
simultaneously provides a controller to organize your application code.
Inspired by go, go-Commander, gh and subcommand, Cobra improves on these by
providing **fully posix compliant flags** (including short & long versions),
**nesting commands**, and the ability to **define your own help and usage** for any or
all commands.
Cobra has an exceptionally clean interface and simple design without needless
constructors or initialization methods.
Applications built with Cobra commands are designed to be as user friendly as
possible. Flags can be placed before or after the command (as long as a
confusing space isn't provided). Both short and long flags can be used. A
command need not even be fully typed. The shortest unambiguous string will
suffice. Help is automatically generated and available for the application or
for a specific command using either the help command or the --help flag.
## Concepts
Cobra is built on a structure of commands & flags.
**Commands** represent actions and **Flags** are modifiers for those actions.
In the following example 'server' is a command and 'port' is a flag.
hugo server --port=1313
### Commands
Command is the central point of the application. Each interaction that
the application supports will be contained in a Command. A command can
have children commands and optionally run an action.
In the example above, 'server' is the command.
A Command has the following structure:
type Command struct {
Use string // The one-line usage message.
Short string // The short description shown in the 'help' output.
Long string // The long message shown in the 'help <this-command>' output.
Run func(cmd *Command, args []string) // Run runs the command.
}
### Flags
A Flag is a way to modify the behavior of a command. Cobra supports
fully posix compliant flags as well as the go flag package.
A Cobra command can define flags that persist through to children commands
and flags that are only available to that command.
In the example above 'port' is the flag.
Flag functionality is provided by the [pflag
library](https://github.com/ogier/pflag), a fork of the flag standard library
which maintains the same interface while adding posix compliance.
## Usage
Cobra works by creating a set of commands and then organizing them into a tree.
The tree defines the structure of the application.
Once each command is defined with its corresponding flags, the
tree is assigned to the commander which is finally executed.
### Installing
Using Cobra is easy. First use go get to install the latest version
of the library.
$ go get github.com/spf13/cobra
Next include cobra in your application.
import "github.com/spf13/cobra"
### Create the root command
The root command represents your binary itself.
Cobra doesn't require any special constructors. Simply create your commands.
var HugoCmd = &cobra.Command{
Use: "hugo",
Short: "Hugo is a very fast static site generator",
Long: `A Fast and Flexible Static Site Generator built with
love by spf13 and friends in Go.
Complete documentation is available at http://hugo.spf13.com`,
Run: func(cmd *cobra.Command, args []string) {
// Do Stuff Here
},
}
### Create additional commands
Additional commands can be defined.
var versionCmd = &cobra.Command{
Use: "version",
Short: "Print the version number of Hugo",
Long: `All software has versions. This is Hugo's`,
Run: func(cmd *cobra.Command, args []string) {
fmt.Println("Hugo Static Site Generator v0.9 -- HEAD")
},
}
### Attach command to its parent
In this example we are attaching it to the root, but commands can be attached at any level.
HugoCmd.AddCommand(versionCmd)
### Assign flags to a command
Since the flags are defined and used in different locations, we need to
define variables with the appropriate scope outside the commands to hold
the flag values.
var Verbose bool
var Source string
There are two different approaches to assign a flag.
#### Persistent Flags
A flag can be 'persistent' meaning that this flag will be available to the
command it's assigned to as well as every command under that command. For
global flags assign a flag as a persistent flag on the root.
HugoCmd.PersistentFlags().BoolVarP(&Verbose, "verbose", "v", false, "verbose output")
#### Local Flags
A flag can also be assigned locally which will only apply to that specific command.
HugoCmd.Flags().StringVarP(&Source, "source", "s", "", "Source directory to read from")
### Remove a command from its parent
Removing a command is not a common action in simple programs but it allows 3rd parties to customize an existing command tree.
In this example, we remove the existing `VersionCmd` command of an existing root command, and we replace it by our own version.
mainlib.RootCmd.RemoveCommand(mainlib.VersionCmd)
mainlib.RootCmd.AddCommand(versionCmd)
### Once all commands and flags are defined, Execute the commands
Execute should be run on the root for clarity, though it can be called on any command.
HugoCmd.Execute()
## Example
In the example below we have defined three commands. Two are at the top level
and one (cmdTimes) is a child of one of the top commands. In this case the root
is not executable meaning that a subcommand is required. This is accomplished
by not providing a 'Run' for the 'rootCmd'.
We have only defined one flag for a single command.
More documentation about flags is available at https://github.com/spf13/pflag
import(
"github.com/spf13/cobra"
"fmt"
"strings"
)
func main() {
var echoTimes int
var cmdPrint = &cobra.Command{
Use: "print [string to print]",
Short: "Print anything to the screen",
Long: `print is for printing anything back to the screen.
For many years people have printed back to the screen.
`,
Run: func(cmd *cobra.Command, args []string) {
fmt.Println("Print: " + strings.Join(args, " "))
},
}
var cmdEcho = &cobra.Command{
Use: "echo [string to echo]",
Short: "Echo anything to the screen",
Long: `echo is for echoing anything back.
Echo works a lot like print, except it has a child command.
`,
Run: func(cmd *cobra.Command, args []string) {
fmt.Println("Print: " + strings.Join(args, " "))
},
}
var cmdTimes = &cobra.Command{
Use: "times [# times] [string to echo]",
Short: "Echo anything to the screen more times",
Long: `echo things multiple times back to the user by providing
a count and a string.`,
Run: func(cmd *cobra.Command, args []string) {
for i:=0; i < echoTimes; i++ {
fmt.Println("Echo: " + strings.Join(args, " "))
}
},
}
cmdTimes.Flags().IntVarP(&echoTimes, "times", "t", 1, "times to echo the input")
var rootCmd = &cobra.Command{Use: "app"}
rootCmd.AddCommand(cmdPrint, cmdEcho)
cmdEcho.AddCommand(cmdTimes)
rootCmd.Execute()
}
For a more complete example of a larger application, please check out [Hugo](http://hugo.spf13.com).
## The Help Command
Cobra automatically adds a help command to your application when you have subcommands.
This will be called when a user runs 'app help'. Additionally help will also
support all other commands as input. Say, for instance, you have a command called
'create'; without any additional configuration Cobra will work when 'app help
create' is called. Every command will automatically have the '--help' flag added.
### Example
The following output is automatically generated by cobra. Nothing beyond the
command and flag definitions are needed.
> hugo help
A Fast and Flexible Static Site Generator built with
love by spf13 and friends in Go.
Complete documentation is available at http://hugo.spf13.com
Usage:
hugo [flags]
hugo [command]
Available Commands:
server :: Hugo runs its own webserver to render the files
version :: Print the version number of Hugo
check :: Check content in the source directory
benchmark :: Benchmark hugo by building a site a number of times
help [command] :: Help about any command
Available Flags:
-b, --base-url="": hostname (and path) to the root eg. http://spf13.com/
-D, --build-drafts=false: include content marked as draft
--config="": config file (default is path/config.yaml|json|toml)
-d, --destination="": filesystem path to write files to
-s, --source="": filesystem path to read files relative from
--stepAnalysis=false: display memory and timing of different steps of the program
--uglyurls=false: if true, use /filename.html instead of /filename/
-v, --verbose=false: verbose output
-w, --watch=false: watch filesystem for changes and recreate as needed
Use "hugo help [command]" for more information about that command.
Help is just a command like any other. There is no special logic or behavior
around it. In fact you can provide your own if you want.
### Defining your own help
You can provide your own Help command or your own template for the default command to use.
The default help command is
func (c *Command) initHelp() {
if c.helpCommand == nil {
c.helpCommand = &Command{
Use: "help [command]",
Short: "Help about any command",
Long: `Help provides help for any command in the application.
Simply type ` + c.Name() + ` help [path to command] for full details.`,
Run: c.HelpFunc(),
}
}
c.AddCommand(c.helpCommand)
}
You can provide your own command, function or template through the following methods.
command.SetHelpCommand(cmd *Command)
command.SetHelpFunc(f func(*Command, []string))
command.SetHelpTemplate(s string)
The latter two will also apply to any children commands.
## Usage
When the user provides an invalid flag or invalid command Cobra responds by
showing the user the 'usage'
### Example
You may recognize this from the help above. That's because the default help
embeds the usage as part of its output.
Usage:
hugo [flags]
hugo [command]
Available Commands:
server Hugo runs its own webserver to render the files
version Print the version number of Hugo
check Check content in the source directory
benchmark Benchmark hugo by building a site a number of times
help [command] Help about any command
Available Flags:
-b, --base-url="": hostname (and path) to the root eg. http://spf13.com/
-D, --build-drafts=false: include content marked as draft
--config="": config file (default is path/config.yaml|json|toml)
-d, --destination="": filesystem path to write files to
-s, --source="": filesystem path to read files relative from
--stepAnalysis=false: display memory and timing of different steps of the program
--uglyurls=false: if true, use /filename.html instead of /filename/
-v, --verbose=false: verbose output
-w, --watch=false: watch filesystem for changes and recreate as needed
### Defining your own usage
You can provide your own usage function or template for cobra to use.
The default usage function is
return func(c *Command) error {
err := tmpl(c.Out(), c.UsageTemplate(), c)
return err
}
Like help, the function and template are overridable through public methods.
command.SetUsageFunc(f func(*Command) error)
command.SetUsageTemplate(s string)
## PreRun or PostRun Hooks
It is possible to run functions before or after the main `Run` function of your command. The `PersistentPreRun` and `PreRun` functions will be executed before `Run`. `PersistentPostRun` and `PostRun` will be executed after `Run`. The `Persistent*Run` functions will be inherited by children if they do not declare their own. These functions are run in the following order:
- `PersistentPreRun`
- `PreRun`
- `Run`
- `PostRun`
- `PersistentPostRun`
An example of two commands which use all of these features is below. When the subcommand is executed it will run the root command's `PersistentPreRun` but not the root command's `PersistentPostRun`:
```go
package main
import (
"fmt"
"github.com/spf13/cobra"
)
func main() {
var rootCmd = &cobra.Command{
Use: "root [sub]",
Short: "My root command",
PersistentPreRun: func(cmd *cobra.Command, args []string) {
fmt.Printf("Inside rootCmd PersistentPreRun with args: %v\n", args)
},
PreRun: func(cmd *cobra.Command, args []string) {
fmt.Printf("Inside rootCmd PreRun with args: %v\n", args)
},
Run: func(cmd *cobra.Command, args []string) {
fmt.Printf("Inside rootCmd Run with args: %v\n", args)
},
PostRun: func(cmd *cobra.Command, args []string) {
fmt.Printf("Inside rootCmd PostRun with args: %v\n", args)
},
PersistentPostRun: func(cmd *cobra.Command, args []string) {
fmt.Printf("Inside rootCmd PersistentPostRun with args: %v\n", args)
},
}
var subCmd = &cobra.Command{
Use: "sub [no options!]",
Short: "My sub command",
PreRun: func(cmd *cobra.Command, args []string) {
fmt.Printf("Inside subCmd PreRun with args: %v\n", args)
},
Run: func(cmd *cobra.Command, args []string) {
fmt.Printf("Inside subCmd Run with args: %v\n", args)
},
PostRun: func(cmd *cobra.Command, args []string) {
fmt.Printf("Inside subCmd PostRun with args: %v\n", args)
},
PersistentPostRun: func(cmd *cobra.Command, args []string) {
fmt.Printf("Inside subCmd PersistentPostRun with args: %v\n", args)
},
}
rootCmd.AddCommand(subCmd)
rootCmd.SetArgs([]string{""})
_ = rootCmd.Execute()
fmt.Print("\n")
rootCmd.SetArgs([]string{"sub", "arg1", "arg2"})
_ = rootCmd.Execute()
}
```
## Generating markdown formatted documentation for your command
Cobra can generate a markdown formatted document based on the subcommands, flags, etc. A simple example of how to do this for your command can be found in [Markdown Docs](md_docs.md)
## Generating bash completions for your command
Cobra can generate a bash completions file. If you add more information to your command these completions can be amazingly powerful and flexible. Read more about [Bash Completions](bash_completions.md)
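A short, hedged sketch of how this could look, using the `GenBashCompletionFile` method defined in this repository's bash completion source (also included in this diff); the output filename is a placeholder:
```go
// Write a bash completion script for the root command defined earlier.
if err := HugoCmd.GenBashCompletionFile("hugo_completion.sh"); err != nil {
    fmt.Println(err)
}
```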
## Debugging
Cobra provides a DebugFlags method on a command which, when called, will print
out everything Cobra knows about the flags for each command.
### Example
command.DebugFlags()
## Release Notes
* **0.9.0** June 17, 2014
* flags can appear anywhere in the args (provided they are unambiguous)
* --help prints usage screen for app or command
* Prefix matching for commands
* Cleaner looking help and usage output
* Extensive test suite
* **0.8.0** Nov 5, 2013
* Reworked interface to remove commander completely
* Command now primary structure
* No initialization needed
* Usage & Help templates & functions definable at any level
* Updated Readme
* **0.7.0** Sept 24, 2013
* Needs more eyes
* Test suite
* Support for automatic error messages
* Support for help command
* Support for printing to any io.Writer instead of os.Stderr
* Support for persistent flags which cascade down tree
* Ready for integration into Hugo
* **0.1.0** Sept 3, 2013
* Implement first draft
## ToDo
* Launch proper documentation site
## Contributing
1. Fork it
2. Create your feature branch (`git checkout -b my-new-feature`)
3. Commit your changes (`git commit -am 'Add some feature'`)
4. Push to the branch (`git push origin my-new-feature`)
5. Create new Pull Request
## Contributors
Names in no particular order:
* [spf13](https://github.com/spf13)
## License
Cobra is released under the Apache 2.0 license. See [LICENSE.txt](https://github.com/spf13/cobra/blob/master/LICENSE.txt)
[![Bitdeli Badge](https://d2weczhvl823v0.cloudfront.net/spf13/cobra/trend.png)](https://bitdeli.com/free "Bitdeli Badge")

View file

@ -1,370 +0,0 @@
package cobra
import (
"bytes"
"fmt"
"os"
"sort"
"strings"
"github.com/spf13/pflag"
)
const (
BashCompFilenameExt = "cobra_annotation_bash_completion_filename_extentions"
BashCompOneRequiredFlag = "cobra_annotation_bash_completion_one_required_flag"
)
func preamble(out *bytes.Buffer) {
fmt.Fprintf(out, `#!/bin/bash
__debug()
{
if [[ -n ${BASH_COMP_DEBUG_FILE} ]]; then
echo "$*" >> "${BASH_COMP_DEBUG_FILE}"
fi
}
__index_of_word()
{
local w word=$1
shift
index=0
for w in "$@"; do
[[ $w = "$word" ]] && return
index=$((index+1))
done
index=-1
}
__contains_word()
{
local w word=$1; shift
for w in "$@"; do
[[ $w = "$word" ]] && return
done
return 1
}
__handle_reply()
{
__debug "${FUNCNAME}"
case $cur in
-*)
compopt -o nospace
local allflags
if [ ${#must_have_one_flag[@]} -ne 0 ]; then
allflags=("${must_have_one_flag[@]}")
else
allflags=("${flags[*]} ${two_word_flags[*]}")
fi
COMPREPLY=( $(compgen -W "${allflags[*]}" -- "$cur") )
[[ $COMPREPLY == *= ]] || compopt +o nospace
return 0;
;;
esac
# check if we are handling a flag with special work handling
local index
__index_of_word "${prev}" "${flags_with_completion[@]}"
if [[ ${index} -ge 0 ]]; then
${flags_completion[${index}]}
return
fi
# we are parsing a flag and don't have a special handler, no completion
if [[ ${cur} != "${words[cword]}" ]]; then
return
fi
local completions
if [[ ${#must_have_one_flag[@]} -ne 0 ]]; then
completions=("${must_have_one_flag[@]}")
elif [[ ${#must_have_one_noun[@]} -ne 0 ]]; then
completions=("${must_have_one_noun[@]}")
else
completions=("${commands[@]}")
fi
COMPREPLY=( $(compgen -W "${completions[*]}" -- "$cur") )
if [[ ${#COMPREPLY[@]} -eq 0 ]]; then
declare -F __custom_func >/dev/null && __custom_func
fi
}
# The arguments should be in the form "ext1|ext2|extn"
__handle_filename_extension_flag()
{
local ext="$1"
_filedir "@(${ext})"
}
__handle_flag()
{
__debug "${FUNCNAME}: c is $c words[c] is ${words[c]}"
# if a command required a flag, and we found it, unset must_have_one_flag()
local flagname=${words[c]}
# if the word contained an =
if [[ ${words[c]} == *"="* ]]; then
flagname=${flagname%%=*} # strip everything after the =
flagname="${flagname}=" # but put the = back
fi
__debug "${FUNCNAME}: looking for ${flagname}"
if __contains_word "${flagname}" "${must_have_one_flag[@]}"; then
must_have_one_flag=()
fi
# skip the argument to a two word flag
if __contains_word "${words[c]}" "${two_word_flags[@]}"; then
c=$((c+1))
# if we are looking for a flags value, don't show commands
if [[ $c -eq $cword ]]; then
commands=()
fi
fi
# skip the flag itself
c=$((c+1))
}
__handle_noun()
{
__debug "${FUNCNAME}: c is $c words[c] is ${words[c]}"
if __contains_word "${words[c]}" "${must_have_one_noun[@]}"; then
must_have_one_noun=()
fi
nouns+=("${words[c]}")
c=$((c+1))
}
__handle_command()
{
__debug "${FUNCNAME}: c is $c words[c] is ${words[c]}"
local next_command
if [[ -n ${last_command} ]]; then
next_command="_${last_command}_${words[c]}"
else
next_command="_${words[c]}"
fi
c=$((c+1))
__debug "${FUNCNAME}: looking for ${next_command}"
declare -F $next_command >/dev/null && $next_command
}
__handle_word()
{
if [[ $c -ge $cword ]]; then
__handle_reply
return
fi
__debug "${FUNCNAME}: c is $c words[c] is ${words[c]}"
if [[ "${words[c]}" == -* ]]; then
__handle_flag
elif __contains_word "${words[c]}" "${commands[@]}"; then
__handle_command
else
__handle_noun
fi
__handle_word
}
`)
}
func postscript(out *bytes.Buffer, name string) {
fmt.Fprintf(out, "__start_%s()\n", name)
fmt.Fprintf(out, `{
local cur prev words cword
_init_completion -s || return
local c=0
local flags=()
local two_word_flags=()
local flags_with_completion=()
local flags_completion=()
local commands=("%s")
local must_have_one_flag=()
local must_have_one_noun=()
local last_command
local nouns=()
__handle_word
}
`, name)
fmt.Fprintf(out, "complete -F __start_%s %s\n", name, name)
fmt.Fprintf(out, "# ex: ts=4 sw=4 et filetype=sh\n")
}
func writeCommands(cmd *Command, out *bytes.Buffer) {
fmt.Fprintf(out, " commands=()\n")
for _, c := range cmd.Commands() {
if len(c.Deprecated) > 0 {
continue
}
fmt.Fprintf(out, " commands+=(%q)\n", c.Name())
}
fmt.Fprintf(out, "\n")
}
func writeFlagHandler(name string, annotations map[string][]string, out *bytes.Buffer) {
for key, value := range annotations {
switch key {
case BashCompFilenameExt:
fmt.Fprintf(out, " flags_with_completion+=(%q)\n", name)
if len(value) > 0 {
ext := "__handle_filename_extension_flag " + strings.Join(value, "|")
fmt.Fprintf(out, " flags_completion+=(%q)\n", ext)
} else {
ext := "_filedir"
fmt.Fprintf(out, " flags_completion+=(%q)\n", ext)
}
}
}
}
func writeShortFlag(flag *pflag.Flag, out *bytes.Buffer) {
b := (flag.Value.Type() == "bool")
name := flag.Shorthand
format := " "
if !b {
format += "two_word_"
}
format += "flags+=(\"-%s\")\n"
fmt.Fprintf(out, format, name)
writeFlagHandler("-"+name, flag.Annotations, out)
}
func writeFlag(flag *pflag.Flag, out *bytes.Buffer) {
b := (flag.Value.Type() == "bool")
name := flag.Name
format := " flags+=(\"--%s"
if !b {
format += "="
}
format += "\")\n"
fmt.Fprintf(out, format, name)
writeFlagHandler("--"+name, flag.Annotations, out)
}
func writeFlags(cmd *Command, out *bytes.Buffer) {
fmt.Fprintf(out, ` flags=()
two_word_flags=()
flags_with_completion=()
flags_completion=()
`)
cmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) {
writeFlag(flag, out)
if len(flag.Shorthand) > 0 {
writeShortFlag(flag, out)
}
})
fmt.Fprintf(out, "\n")
}
func writeRequiredFlag(cmd *Command, out *bytes.Buffer) {
fmt.Fprintf(out, " must_have_one_flag=()\n")
flags := cmd.NonInheritedFlags()
flags.VisitAll(func(flag *pflag.Flag) {
for key, _ := range flag.Annotations {
switch key {
case BashCompOneRequiredFlag:
format := " must_have_one_flag+=(\"--%s"
b := (flag.Value.Type() == "bool")
if !b {
format += "="
}
format += "\")\n"
fmt.Fprintf(out, format, flag.Name)
if len(flag.Shorthand) > 0 {
fmt.Fprintf(out, " must_have_one_flag+=(\"-%s\")\n", flag.Shorthand)
}
}
}
})
}
func writeRequiredNoun(cmd *Command, out *bytes.Buffer) {
fmt.Fprintf(out, " must_have_one_noun=()\n")
sort.Sort(sort.StringSlice(cmd.ValidArgs))
for _, value := range cmd.ValidArgs {
fmt.Fprintf(out, " must_have_one_noun+=(%q)\n", value)
}
}
func gen(cmd *Command, out *bytes.Buffer) {
for _, c := range cmd.Commands() {
if len(c.Deprecated) > 0 {
continue
}
gen(c, out)
}
commandName := cmd.CommandPath()
commandName = strings.Replace(commandName, " ", "_", -1)
fmt.Fprintf(out, "_%s()\n{\n", commandName)
fmt.Fprintf(out, " last_command=%q\n", commandName)
writeCommands(cmd, out)
writeFlags(cmd, out)
writeRequiredFlag(cmd, out)
writeRequiredNoun(cmd, out)
fmt.Fprintf(out, "}\n\n")
}
func (cmd *Command) GenBashCompletion(out *bytes.Buffer) {
preamble(out)
if len(cmd.BashCompletionFunction) > 0 {
fmt.Fprintf(out, "%s\n", cmd.BashCompletionFunction)
}
gen(cmd, out)
postscript(out, cmd.Name())
}
func (cmd *Command) GenBashCompletionFile(filename string) error {
out := new(bytes.Buffer)
cmd.GenBashCompletion(out)
outFile, err := os.Create(filename)
if err != nil {
return err
}
defer outFile.Close()
_, err = outFile.Write(out.Bytes())
if err != nil {
return err
}
return nil
}
// MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag, if it exists.
func (cmd *Command) MarkFlagRequired(name string) error {
return MarkFlagRequired(cmd.Flags(), name)
}
// MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag in the flag set, if it exists.
func MarkFlagRequired(flags *pflag.FlagSet, name string) error {
return flags.SetAnnotation(name, BashCompOneRequiredFlag, []string{"true"})
}
// MarkFlagFilename adds the BashCompFilenameExt annotation to the named flag, if it exists.
// Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided.
func (cmd *Command) MarkFlagFilename(name string, extensions ...string) error {
return MarkFlagFilename(cmd.Flags(), name, extensions...)
}
// MarkFlagFilename adds the BashCompFilenameExt annotation to the named flag in the flag set, if it exists.
// Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided.
func MarkFlagFilename(flags *pflag.FlagSet, name string, extensions ...string) error {
return flags.SetAnnotation(name, BashCompFilenameExt, extensions)
}
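// Illustrative usage sketch (not part of the original file): marking flags so
// the generated completion script treats them as required or as filenames,
// mirroring the completion test later in this diff. rootCmd and the "config"
// flag are placeholders.
//
//	var cfgFile string
//	rootCmd.Flags().StringVar(&cfgFile, "config", "", "config file to load")
//	rootCmd.MarkFlagRequired("config")
//	rootCmd.MarkFlagFilename("config", "yaml", "yml", "json")
//	rootCmd.GenBashCompletionFile("out.sh")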

View file

@ -1,80 +0,0 @@
package cobra
import (
"bytes"
"fmt"
"os"
"strings"
"testing"
)
var _ = fmt.Println
var _ = os.Stderr
func checkOmit(t *testing.T, found, unexpected string) {
if strings.Contains(found, unexpected) {
t.Errorf("Unexpected response.\nGot: %q\nBut should not have!\n", unexpected)
}
}
func check(t *testing.T, found, expected string) {
if !strings.Contains(found, expected) {
t.Errorf("Unexpected response.\nExpecting to contain: \n %q\nGot:\n %q\n", expected, found)
}
}
// World's worst custom function: it just keeps telling you to enter hello!
const (
bash_completion_func = `__custom_func() {
COMPREPLY=( "hello" )
}
`
)
func TestBashCompletions(t *testing.T) {
c := initializeWithRootCmd()
cmdEcho.AddCommand(cmdTimes)
c.AddCommand(cmdEcho, cmdPrint, cmdDeprecated)
// custom completion function
c.BashCompletionFunction = bash_completion_func
// required flag
c.MarkFlagRequired("introot")
// valid nouns
validArgs := []string{"pods", "nodes", "services", "replicationControllers"}
c.ValidArgs = validArgs
// filename
var flagval string
c.Flags().StringVar(&flagval, "filename", "", "Enter a filename")
c.MarkFlagFilename("filename", "json", "yaml", "yml")
// filename extensions
var flagvalExt string
c.Flags().StringVar(&flagvalExt, "filename-ext", "", "Enter a filename (extension limited)")
c.MarkFlagFilename("filename-ext")
out := new(bytes.Buffer)
c.GenBashCompletion(out)
str := out.String()
check(t, str, "_cobra-test")
check(t, str, "_cobra-test_echo")
check(t, str, "_cobra-test_echo_times")
check(t, str, "_cobra-test_print")
// check for required flags
check(t, str, `must_have_one_flag+=("--introot=")`)
// check for custom completion function
check(t, str, `COMPREPLY=( "hello" )`)
// check for required nouns
check(t, str, `must_have_one_noun+=("pods")`)
// check for filename extension flags
check(t, str, `flags_completion+=("_filedir")`)
// check for filename extension flags
check(t, str, `flags_completion+=("__handle_filename_extension_flag json|yaml|yml")`)
checkOmit(t, str, cmdDeprecated.Name())
}

View file

@ -1,112 +0,0 @@
// Copyright © 2013 Steve Francia <spf@spf13.com>.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Commands similar to git, go tools and other modern CLI tools
// inspired by go, go-Commander, gh and subcommand
package cobra
import (
"fmt"
"io"
"reflect"
"strconv"
"strings"
"text/template"
)
var initializers []func()
// automatic prefix matching can be a dangerous thing to automatically enable in CLI tools.
// Set this to true to enable it
var EnablePrefixMatching bool = false
// enables an information splash screen on Windows if the CLI is started from explorer.exe.
var EnableWindowsMouseTrap bool = true
var MousetrapHelpText string = `This is a command line tool
You need to open cmd.exe and run it from there.
`
//OnInitialize takes a series of func() arguments and appends them to a slice of func().
func OnInitialize(y ...func()) {
for _, x := range y {
initializers = append(initializers, x)
}
}
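// Illustrative usage sketch (not part of the original file): appending a
// caller-supplied setup function to the package's initializer list;
// loadConfig is a hypothetical function defined by the caller.
//
//	cobra.OnInitialize(loadConfig)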
//Gt takes two types and checks whether the first type is greater than the second. In case of types Arrays, Chans,
//Maps and Slices, Gt will compare their lengths. Ints are compared directly while strings are first parsed as
//ints and then compared.
func Gt(a interface{}, b interface{}) bool {
var left, right int64
av := reflect.ValueOf(a)
switch av.Kind() {
case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice:
left = int64(av.Len())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
left = av.Int()
case reflect.String:
left, _ = strconv.ParseInt(av.String(), 10, 64)
}
bv := reflect.ValueOf(b)
switch bv.Kind() {
case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice:
right = int64(bv.Len())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
right = bv.Int()
case reflect.String:
right, _ = strconv.ParseInt(bv.String(), 10, 64)
}
return left > right
}
//Eq takes two types and checks whether they are equal. Supported types are int and string. Unsupported types will panic.
func Eq(a interface{}, b interface{}) bool {
av := reflect.ValueOf(a)
bv := reflect.ValueOf(b)
switch av.Kind() {
case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice:
panic("Eq called on unsupported type")
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return av.Int() == bv.Int()
case reflect.String:
return av.String() == bv.String()
}
return false
}
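// Worked examples (not part of the original file), following the rules above:
//
//	Gt([]int{1, 2, 3}, 2)  // true: slice length 3 > 2
//	Gt("10", "9")          // true: strings are parsed as ints, 10 > 9
//	Eq("a", "a")           // true: string comparison
//	Eq([]int{1}, []int{1}) // panics: unsupported type for Eq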
//rpad adds padding to the right of a string
func rpad(s string, padding int) string {
template := fmt.Sprintf("%%-%ds", padding)
return fmt.Sprintf(template, s)
}
// tmpl executes the given template text on data, writing the result to w.
func tmpl(w io.Writer, text string, data interface{}) error {
t := template.New("top")
t.Funcs(template.FuncMap{
"trim": strings.TrimSpace,
"rpad": rpad,
"gt": Gt,
"eq": Eq,
})
template.Must(t.Parse(text))
return t.Execute(w, data)
}

View file

@ -1,965 +0,0 @@
package cobra
import (
"bytes"
"fmt"
"os"
"reflect"
"runtime"
"strings"
"testing"
"github.com/spf13/pflag"
)
var _ = fmt.Println
var _ = os.Stderr
var tp, te, tt, t1, tr []string
var rootPersPre, echoPre, echoPersPre, timesPersPre []string
var flagb1, flagb2, flagb3, flagbr, flagbp bool
var flags1, flags2a, flags2b, flags3 string
var flagi1, flagi2, flagi3, flagir int
var globalFlag1 bool
var flagEcho, rootcalled bool
var versionUsed int
const strtwoParentHelp = "help message for parent flag strtwo"
const strtwoChildHelp = "help message for child flag strtwo"
var cmdPrint = &Command{
Use: "print [string to print]",
Short: "Print anything to the screen",
Long: `an absolutely utterly useless command for testing.`,
Run: func(cmd *Command, args []string) {
tp = args
},
}
var cmdEcho = &Command{
Use: "echo [string to echo]",
Aliases: []string{"say"},
Short: "Echo anything to the screen",
Long: `an utterly useless command for testing.`,
Example: "Just run cobra-test echo",
PersistentPreRun: func(cmd *Command, args []string) {
echoPersPre = args
},
PreRun: func(cmd *Command, args []string) {
echoPre = args
},
Run: func(cmd *Command, args []string) {
te = args
},
}
var cmdEchoSub = &Command{
Use: "echosub [string to print]",
Short: "second sub command for echo",
Long: `an absolutely utterly useless command for testing gendocs!.`,
Run: func(cmd *Command, args []string) {
},
}
var cmdDeprecated = &Command{
Use: "deprecated [can't do anything here]",
Short: "A command which is deprecated",
Long: `an absolutely utterly useless command for testing deprecation!.`,
Deprecated: "Please use echo instead",
Run: func(cmd *Command, args []string) {
},
}
var cmdTimes = &Command{
Use: "times [# times] [string to echo]",
Short: "Echo anything to the screen more times",
Long: `a slightly useless command for testing.`,
PersistentPreRun: func(cmd *Command, args []string) {
timesPersPre = args
},
Run: func(cmd *Command, args []string) {
tt = args
},
}
var cmdRootNoRun = &Command{
Use: "cobra-test",
Short: "The root can run it's own function",
Long: "The root description for help",
PersistentPreRun: func(cmd *Command, args []string) {
rootPersPre = args
},
}
var cmdRootSameName = &Command{
Use: "print",
Short: "Root with the same name as a subcommand",
Long: "The root description for help",
}
var cmdRootWithRun = &Command{
Use: "cobra-test",
Short: "The root can run it's own function",
Long: "The root description for help",
Run: func(cmd *Command, args []string) {
tr = args
rootcalled = true
},
}
var cmdSubNoRun = &Command{
Use: "subnorun",
Short: "A subcommand without a Run function",
Long: "A long output about a subcommand without a Run function",
}
var cmdVersion1 = &Command{
Use: "version",
Short: "Print the version number",
Long: `First version of the version command`,
Run: func(cmd *Command, args []string) {
versionUsed = 1
},
}
var cmdVersion2 = &Command{
Use: "version",
Short: "Print the version number",
Long: `Second version of the version command`,
Run: func(cmd *Command, args []string) {
versionUsed = 2
},
}
func flagInit() {
cmdEcho.ResetFlags()
cmdPrint.ResetFlags()
cmdTimes.ResetFlags()
cmdRootNoRun.ResetFlags()
cmdRootSameName.ResetFlags()
cmdRootWithRun.ResetFlags()
cmdSubNoRun.ResetFlags()
cmdRootNoRun.PersistentFlags().StringVarP(&flags2a, "strtwo", "t", "two", strtwoParentHelp)
cmdEcho.Flags().IntVarP(&flagi1, "intone", "i", 123, "help message for flag intone")
cmdTimes.Flags().IntVarP(&flagi2, "inttwo", "j", 234, "help message for flag inttwo")
cmdPrint.Flags().IntVarP(&flagi3, "intthree", "i", 345, "help message for flag intthree")
cmdEcho.PersistentFlags().StringVarP(&flags1, "strone", "s", "one", "help message for flag strone")
cmdEcho.PersistentFlags().BoolVarP(&flagbp, "persistentbool", "p", false, "help message for flag persistentbool")
cmdTimes.PersistentFlags().StringVarP(&flags2b, "strtwo", "t", "2", strtwoChildHelp)
cmdPrint.PersistentFlags().StringVarP(&flags3, "strthree", "s", "three", "help message for flag strthree")
cmdEcho.Flags().BoolVarP(&flagb1, "boolone", "b", true, "help message for flag boolone")
cmdTimes.Flags().BoolVarP(&flagb2, "booltwo", "c", false, "help message for flag booltwo")
cmdPrint.Flags().BoolVarP(&flagb3, "boolthree", "b", true, "help message for flag boolthree")
cmdVersion1.ResetFlags()
cmdVersion2.ResetFlags()
}
func commandInit() {
cmdEcho.ResetCommands()
cmdPrint.ResetCommands()
cmdTimes.ResetCommands()
cmdRootNoRun.ResetCommands()
cmdRootSameName.ResetCommands()
cmdRootWithRun.ResetCommands()
cmdSubNoRun.ResetCommands()
}
func initialize() *Command {
tt, tp, te = nil, nil, nil
rootPersPre, echoPre, echoPersPre, timesPersPre = nil, nil, nil, nil
var c = cmdRootNoRun
flagInit()
commandInit()
return c
}
func initializeWithSameName() *Command {
tt, tp, te = nil, nil, nil
rootPersPre, echoPre, echoPersPre, timesPersPre = nil, nil, nil, nil
var c = cmdRootSameName
flagInit()
commandInit()
return c
}
func initializeWithRootCmd() *Command {
cmdRootWithRun.ResetCommands()
tt, tp, te, tr, rootcalled = nil, nil, nil, nil, false
flagInit()
cmdRootWithRun.Flags().BoolVarP(&flagbr, "boolroot", "b", false, "help message for flag boolroot")
cmdRootWithRun.Flags().IntVarP(&flagir, "introot", "i", 321, "help message for flag introot")
commandInit()
return cmdRootWithRun
}
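// resulter captures what a single test execution produced: the error returned by Execute, the collected output and the command that was run.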
type resulter struct {
Error error
Output string
Command *Command
}
func fullSetupTest(input string) resulter {
c := initializeWithRootCmd()
return fullTester(c, input)
}
func noRRSetupTest(input string) resulter {
c := initialize()
return fullTester(c, input)
}
func rootOnlySetupTest(input string) resulter {
c := initializeWithRootCmd()
return simpleTester(c, input)
}
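// simpleTester executes c with the space-separated input as its arguments and collects the output.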
func simpleTester(c *Command, input string) resulter {
buf := new(bytes.Buffer)
// Capture the command output in a buffer
c.SetOutput(buf)
c.SetArgs(strings.Split(input, " "))
err := c.Execute()
output := buf.String()
return resulter{err, output, c}
}
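// fullTester additionally wires up the full test command tree (print, echo, echo times, subnorun, deprecated) before executing the input.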
func fullTester(c *Command, input string) resulter {
buf := new(bytes.Buffer)
// Capture the command output in a buffer
c.SetOutput(buf)
cmdEcho.AddCommand(cmdTimes)
c.AddCommand(cmdPrint, cmdEcho, cmdSubNoRun, cmdDeprecated)
c.SetArgs(strings.Split(input, " "))
err := c.Execute()
output := buf.String()
return resulter{err, output, c}
}
func logErr(t *testing.T, found, expected string) {
out := new(bytes.Buffer)
_, _, line, ok := runtime.Caller(2)
if ok {
fmt.Fprintf(out, "Line: %d ", line)
}
fmt.Fprintf(out, "Unexpected response.\nExpecting to contain: \n %q\nGot:\n %q\n", expected, found)
t.Error(out.String())
}
func checkResultContains(t *testing.T, x resulter, check string) {
if !strings.Contains(x.Output, check) {
logErr(t, x.Output, check)
}
}
func checkResultOmits(t *testing.T, x resulter, check string) {
if strings.Contains(x.Output, check) {
logErr(t, x.Output, check)
}
}
func checkOutputContains(t *testing.T, c *Command, check string) {
buf := new(bytes.Buffer)
c.SetOutput(buf)
c.Execute()
if !strings.Contains(buf.String(), check) {
logErr(t, buf.String(), check)
}
}
func TestSingleCommand(t *testing.T) {
noRRSetupTest("print one two")
if te != nil || tt != nil {
t.Error("Wrong command called")
}
if tp == nil {
t.Error("Wrong command called")
}
if strings.Join(tp, " ") != "one two" {
t.Error("Command didn't parse correctly")
}
}
func TestChildCommand(t *testing.T) {
noRRSetupTest("echo times one two")
if te != nil || tp != nil {
t.Error("Wrong command called")
}
if tt == nil {
t.Error("Wrong command called")
}
if strings.Join(tt, " ") != "one two" {
t.Error("Command didn't parse correctly")
}
}
func TestCommandAlias(t *testing.T) {
noRRSetupTest("say times one two")
if te != nil || tp != nil {
t.Error("Wrong command called")
}
if tt == nil {
t.Error("Wrong command called")
}
if strings.Join(tt, " ") != "one two" {
t.Error("Command didn't parse correctly")
}
}
func TestPrefixMatching(t *testing.T) {
EnablePrefixMatching = true
noRRSetupTest("ech times one two")
if te != nil || tp != nil {
t.Error("Wrong command called")
}
if tt == nil {
t.Error("Wrong command called")
}
if strings.Join(tt, " ") != "one two" {
t.Error("Command didn't parse correctly")
}
EnablePrefixMatching = false
}
func TestNoPrefixMatching(t *testing.T) {
EnablePrefixMatching = false
noRRSetupTest("ech times one two")
if !(tt == nil && te == nil && tp == nil) {
t.Error("Wrong command called")
}
}
func TestAliasPrefixMatching(t *testing.T) {
EnablePrefixMatching = true
noRRSetupTest("sa times one two")
if te != nil || tp != nil {
t.Error("Wrong command called")
}
if tt == nil {
t.Error("Wrong command called")
}
if strings.Join(tt, " ") != "one two" {
t.Error("Command didn't parse correctly")
}
EnablePrefixMatching = false
}
func TestChildSameName(t *testing.T) {
c := initializeWithSameName()
c.AddCommand(cmdPrint, cmdEcho)
c.SetArgs(strings.Split("print one two", " "))
c.Execute()
if te != nil || tt != nil {
t.Error("Wrong command called")
}
if tp == nil {
t.Error("Wrong command called")
}
if strings.Join(tp, " ") != "one two" {
t.Error("Command didn't parse correctly")
}
}
func TestGrandChildSameName(t *testing.T) {
c := initializeWithSameName()
cmdTimes.AddCommand(cmdPrint)
c.AddCommand(cmdTimes)
c.SetArgs(strings.Split("times print one two", " "))
c.Execute()
if te != nil || tt != nil {
t.Error("Wrong command called")
}
if tp == nil {
t.Error("Wrong command called")
}
if strings.Join(tp, " ") != "one two" {
t.Error("Command didn't parse correctly")
}
}
func TestFlagLong(t *testing.T) {
noRRSetupTest("echo --intone=13 something here")
if strings.Join(te, " ") != "something here" {
t.Errorf("flags didn't leave proper args remaining..%s given", te)
}
if flagi1 != 13 {
t.Errorf("int flag didn't get correct value, had %d", flagi1)
}
if flagi2 != 234 {
t.Errorf("default flag value changed, 234 expected, %d given", flagi2)
}
}
func TestFlagShort(t *testing.T) {
noRRSetupTest("echo -i13 something here")
if strings.Join(te, " ") != "something here" {
t.Errorf("flags didn't leave proper args remaining..%s given", te)
}
if flagi1 != 13 {
t.Errorf("int flag didn't get correct value, had %d", flagi1)
}
if flagi2 != 234 {
t.Errorf("default flag value changed, 234 expected, %d given", flagi2)
}
noRRSetupTest("echo -i 13 something here")
if strings.Join(te, " ") != "something here" {
t.Errorf("flags didn't leave proper args remaining..%s given", te)
}
if flagi1 != 13 {
t.Errorf("int flag didn't get correct value, had %d", flagi1)
}
if flagi2 != 234 {
t.Errorf("default flag value changed, 234 expected, %d given", flagi2)
}
noRRSetupTest("print -i99 one two")
if strings.Join(tp, " ") != "one two" {
t.Errorf("flags didn't leave proper args remaining..%s given", tp)
}
if flagi3 != 99 {
t.Errorf("int flag didn't get correct value, had %d", flagi3)
}
if flagi1 != 123 {
t.Errorf("default flag value changed on different command with same shortname, 234 expected, %d given", flagi2)
}
}
func TestChildCommandFlags(t *testing.T) {
noRRSetupTest("echo times -j 99 one two")
if strings.Join(tt, " ") != "one two" {
t.Errorf("flags didn't leave proper args remaining..%s given", tt)
}
// Testing with flag that shouldn't be persistent
r := noRRSetupTest("echo times -j 99 -i77 one two")
if r.Error == nil {
t.Errorf("invalid flag should generate error")
}
if !strings.Contains(r.Output, "unknown shorthand") {
t.Errorf("Wrong error message displayed, \n %s", r.Output)
}
if flagi2 != 99 {
t.Errorf("flag value should be 99, %d given", flagi2)
}
if flagi1 != 123 {
t.Errorf("unset flag should have default value, expecting 123, given %d", flagi1)
}
// Testing with flag only existing on child
r = noRRSetupTest("echo -j 99 -i77 one two")
if r.Error == nil {
t.Errorf("invalid flag should generate error")
}
if !strings.Contains(r.Output, "unknown shorthand flag") {
t.Errorf("Wrong error message displayed, \n %s", r.Output)
}
// Testing with persistent flag overwritten by child
noRRSetupTest("echo times --strtwo=child one two")
if flags2b != "child" {
t.Errorf("flag value should be child, %s given", flags2b)
}
if flags2a != "two" {
t.Errorf("unset flag should have default value, expecting two, given %s", flags2a)
}
// Testing flag with invalid input
r = noRRSetupTest("echo -i10E")
if r.Error == nil {
t.Errorf("invalid input should generate error")
}
if !strings.Contains(r.Output, "invalid argument \"10E\" for i10E") {
t.Errorf("Wrong error message displayed, \n %s", r.Output)
}
}
func TestTrailingCommandFlags(t *testing.T) {
x := fullSetupTest("echo two -x")
if x.Error == nil {
t.Errorf("invalid flag should generate error")
}
}
func TestInvalidSubcommandFlags(t *testing.T) {
cmd := initializeWithRootCmd()
cmd.AddCommand(cmdTimes)
result := simpleTester(cmd, "times --inttwo=2 --badflag=bar")
checkResultContains(t, result, "unknown flag: --badflag")
if strings.Contains(result.Output, "unknown flag: --inttwo") {
t.Errorf("invalid --badflag flag shouldn't fail on 'unknown' --inttwo flag")
}
}
func TestSubcommandArgEvaluation(t *testing.T) {
cmd := initializeWithRootCmd()
first := &Command{
Use: "first",
Run: func(cmd *Command, args []string) {
},
}
cmd.AddCommand(first)
second := &Command{
Use: "second",
Run: func(cmd *Command, args []string) {
fmt.Fprintf(cmd.Out(), "%v", args)
},
}
first.AddCommand(second)
result := simpleTester(cmd, "first second first third")
expectedOutput := fmt.Sprintf("%v", []string{"first third"})
if result.Output != expectedOutput {
t.Errorf("exptected %v, got %v", expectedOutput, result.Output)
}
}
func TestPersistentFlags(t *testing.T) {
fullSetupTest("echo -s something -p more here")
// persistent flag should act like a normal flag on its own command
if strings.Join(te, " ") != "more here" {
t.Errorf("flags didn't leave proper args remaining..%s given", te)
}
if flags1 != "something" {
t.Errorf("string flag didn't get correct value, had %v", flags1)
}
if !flagbp {
t.Errorf("persistent bool flag not parsed correctly. Expected true, had %v", flagbp)
}
// persistent flags should also be available on the child command
fullSetupTest("echo times -s again -c -p test here")
if strings.Join(tt, " ") != "test here" {
t.Errorf("flags didn't leave proper args remaining..%s given", tt)
}
if flags1 != "again" {
t.Errorf("string flag didn't get correct value, had %v", flags1)
}
if !flagb2 {
t.Errorf("local flag not parsed correctly. Expected true, had %v", flagb2)
}
if !flagbp {
t.Errorf("persistent bool flag not parsed correctly. Expected true, had %v", flagbp)
}
}
func TestHelpCommand(t *testing.T) {
x := fullSetupTest("help")
checkResultContains(t, x, cmdRootWithRun.Long)
x = fullSetupTest("help echo")
checkResultContains(t, x, cmdEcho.Long)
x = fullSetupTest("help echo times")
checkResultContains(t, x, cmdTimes.Long)
}
func TestChildCommandHelp(t *testing.T) {
c := noRRSetupTest("print --help")
checkResultContains(t, c, strtwoParentHelp)
r := noRRSetupTest("echo times --help")
checkResultContains(t, r, strtwoChildHelp)
}
func TestNonRunChildHelp(t *testing.T) {
x := noRRSetupTest("subnorun")
checkResultContains(t, x, cmdSubNoRun.Long)
}
func TestRunnableRootCommand(t *testing.T) {
fullSetupTest("")
if rootcalled != true {
t.Errorf("Root Function was not called")
}
}
func TestRunnableRootCommandNilInput(t *testing.T) {
empty_arg := make([]string, 0)
c := initializeWithRootCmd()
buf := new(bytes.Buffer)
// Capture the command output in a buffer
c.SetOutput(buf)
cmdEcho.AddCommand(cmdTimes)
c.AddCommand(cmdPrint, cmdEcho)
c.SetArgs(empty_arg)
c.Execute()
if rootcalled != true {
t.Errorf("Root Function was not called")
}
}
func TestRunnableRootCommandEmptyInput(t *testing.T) {
args := make([]string, 3)
args[0] = ""
args[1] = "--introot=12"
args[2] = ""
c := initializeWithRootCmd()
buf := new(bytes.Buffer)
// Capture the command output in a buffer
c.SetOutput(buf)
cmdEcho.AddCommand(cmdTimes)
c.AddCommand(cmdPrint, cmdEcho)
c.SetArgs(args)
c.Execute()
if rootcalled != true {
t.Errorf("Root Function was not called.\n\nOutput was:\n\n%s\n", buf)
}
}
func TestInvalidSubcommandWhenArgsAllowed(t *testing.T) {
fullSetupTest("echo invalid-sub")
if te[0] != "invalid-sub" {
t.Errorf("Subcommand didn't work...")
}
}
func TestRootFlags(t *testing.T) {
fullSetupTest("-i 17 -b")
if flagbr != true {
t.Errorf("flag value should be true, %v given", flagbr)
}
if flagir != 17 {
t.Errorf("flag value should be 17, %d given", flagir)
}
}
func TestRootHelp(t *testing.T) {
x := fullSetupTest("--help")
checkResultContains(t, x, "Available Commands:")
checkResultContains(t, x, "for more information about a command")
if strings.Contains(x.Output, "unknown flag: --help") {
t.Errorf("--help shouldn't trigger an error, Got: \n %s", x.Output)
}
if strings.Contains(x.Output, cmdEcho.Use) {
t.Errorf("--help shouldn't display subcommand's usage, Got: \n %s", x.Output)
}
x = fullSetupTest("echo --help")
if strings.Contains(x.Output, cmdTimes.Use) {
t.Errorf("--help shouldn't display subsubcommand's usage, Got: \n %s", x.Output)
}
checkResultContains(t, x, "Available Commands:")
checkResultContains(t, x, "for more information about a command")
if strings.Contains(x.Output, "unknown flag: --help") {
t.Errorf("--help shouldn't trigger an error, Got: \n %s", x.Output)
}
}
func TestFlagAccess(t *testing.T) {
initialize()
local := cmdTimes.LocalFlags()
inherited := cmdTimes.InheritedFlags()
for _, f := range []string{"inttwo", "strtwo", "booltwo"} {
if local.Lookup(f) == nil {
t.Errorf("LocalFlags expected to contain %s, Got: nil", f)
}
}
if inherited.Lookup("strone") == nil {
t.Errorf("InheritedFlags expected to contain strone, Got: nil")
}
if inherited.Lookup("strtwo") != nil {
t.Errorf("InheritedFlags shouldn not contain overwritten flag strtwo")
}
}
func TestNonRunnableRootCommandNilInput(t *testing.T) {
args := make([]string, 0)
c := initialize()
buf := new(bytes.Buffer)
// Capture the command output in a buffer
c.SetOutput(buf)
cmdEcho.AddCommand(cmdTimes)
c.AddCommand(cmdPrint, cmdEcho)
c.SetArgs(args)
c.Execute()
if !strings.Contains(buf.String(), cmdRootNoRun.Long) {
t.Errorf("Expected to get help output, Got: \n %s", buf)
}
}
func TestRootNoCommandHelp(t *testing.T) {
x := rootOnlySetupTest("--help")
checkResultOmits(t, x, "Available Commands:")
checkResultOmits(t, x, "for more information about a command")
if strings.Contains(x.Output, "unknown flag: --help") {
t.Errorf("--help shouldn't trigger an error, Got: \n %s", x.Output)
}
x = rootOnlySetupTest("echo --help")
checkResultOmits(t, x, "Available Commands:")
checkResultOmits(t, x, "for more information about a command")
if strings.Contains(x.Output, "unknown flag: --help") {
t.Errorf("--help shouldn't trigger an error, Got: \n %s", x.Output)
}
}
func TestRootUnknownCommand(t *testing.T) {
r := noRRSetupTest("bogus")
s := "Error: unknown command \"bogus\" for \"cobra-test\"\nRun 'cobra-test --help' for usage.\n"
if r.Output != s {
t.Errorf("Unexpected response.\nExpecting to be:\n %q\nGot:\n %q\n", s, r.Output)
}
r = noRRSetupTest("--strtwo=a bogus")
if r.Output != s {
t.Errorf("Unexpected response.\nExpecting to be:\n %q\nGot:\n %q\n", s, r.Output)
}
}
func TestFlagsBeforeCommand(t *testing.T) {
// short without space
x := fullSetupTest("-i10 echo")
if x.Error != nil {
t.Errorf("Valid Input shouldn't have errors, got:\n %q", x.Error)
}
// short (int) with equals
// It appears that pflags doesn't support this...
// Commenting out until support can be added
//x = noRRSetupTest("echo -i=10")
//if x.Error != nil {
//t.Errorf("Valid Input shouldn't have errors, got:\n %s", x.Error)
//}
// long with equals
x = noRRSetupTest("--intone=123 echo one two")
if x.Error != nil {
t.Errorf("Valid Input shouldn't have errors, got:\n %s", x.Error)
}
// With parsing error properly reported
x = fullSetupTest("-i10E echo")
if !strings.Contains(x.Output, "invalid argument \"10E\" for i10E") {
t.Errorf("Wrong error message displayed, \n %s", x.Output)
}
//With quotes
x = fullSetupTest("-s=\"walking\" echo")
if x.Error != nil {
t.Errorf("Valid Input shouldn't have errors, got:\n %q", x.Error)
}
//With quotes and space
x = fullSetupTest("-s=\"walking fast\" echo")
if x.Error != nil {
t.Errorf("Valid Input shouldn't have errors, got:\n %q", x.Error)
}
//With inner quote
x = fullSetupTest("-s=\"walking \\\"Inner Quote\\\" fast\" echo")
if x.Error != nil {
t.Errorf("Valid Input shouldn't have errors, got:\n %q", x.Error)
}
//With quotes and space
x = fullSetupTest("-s=\"walking \\\"Inner Quote\\\" fast\" echo")
if x.Error != nil {
t.Errorf("Valid Input shouldn't have errors, got:\n %q", x.Error)
}
}
func TestRemoveCommand(t *testing.T) {
versionUsed = 0
c := initializeWithRootCmd()
c.AddCommand(cmdVersion1)
c.RemoveCommand(cmdVersion1)
x := fullTester(c, "version")
if x.Error == nil {
t.Errorf("Removed command should not have been called\n")
return
}
}
func TestCommandWithoutSubcommands(t *testing.T) {
c := initializeWithRootCmd()
x := simpleTester(c, "")
if x.Error != nil {
t.Errorf("Calling command without subcommands should not have error: %v", x.Error)
return
}
}
func TestCommandWithoutSubcommandsWithArg(t *testing.T) {
c := initializeWithRootCmd()
expectedArgs := []string{"arg"}
x := simpleTester(c, "arg")
if x.Error != nil {
t.Errorf("Calling command without subcommands but with arg should not have error: %v", x.Error)
return
}
if !reflect.DeepEqual(expectedArgs, tr) {
t.Errorf("Calling command without subcommands but with arg has wrong args: expected: %v, actual: %v", expectedArgs, tr)
return
}
}
func TestReplaceCommandWithRemove(t *testing.T) {
versionUsed = 0
c := initializeWithRootCmd()
c.AddCommand(cmdVersion1)
c.RemoveCommand(cmdVersion1)
c.AddCommand(cmdVersion2)
x := fullTester(c, "version")
if x.Error != nil {
t.Errorf("Valid Input shouldn't have errors, got:\n %q", x.Error)
return
}
if versionUsed == 1 {
t.Errorf("Removed command shouldn't be called\n")
}
if versionUsed != 2 {
t.Errorf("Replacing command should have been called but didn't\n")
}
}
func TestDeprecatedSub(t *testing.T) {
c := fullSetupTest("deprecated")
checkResultContains(t, c, cmdDeprecated.Deprecated)
}
func TestPreRun(t *testing.T) {
noRRSetupTest("echo one two")
if echoPre == nil || echoPersPre == nil {
t.Error("PreRun or PersistentPreRun not called")
}
if rootPersPre != nil || timesPersPre != nil {
t.Error("Wrong *Pre functions called!")
}
noRRSetupTest("echo times one two")
if timesPersPre == nil {
t.Error("PreRun or PersistentPreRun not called")
}
if echoPre != nil || echoPersPre != nil || rootPersPre != nil {
t.Error("Wrong *Pre functions called!")
}
noRRSetupTest("print one two")
if rootPersPre == nil {
t.Error("Parent PersistentPreRun not called but should not have been")
}
if echoPre != nil || echoPersPre != nil || timesPersPre != nil {
t.Error("Wrong *Pre functions called!")
}
}
// Check if cmdEchoSub gets PersistentPreRun from rootCmd even if it is added last
func TestPersistentPreRunPropagation(t *testing.T) {
rootCmd := initialize()
// First add the cmdEchoSub to cmdPrint
cmdPrint.AddCommand(cmdEchoSub)
// Now add cmdPrint to rootCmd
rootCmd.AddCommand(cmdPrint)
rootCmd.SetArgs(strings.Split("print echosub lala", " "))
rootCmd.Execute()
if rootPersPre == nil || len(rootPersPre) == 0 || rootPersPre[0] != "lala" {
t.Error("RootCmd PersistentPreRun not called but should have been")
}
}
func TestGlobalNormFuncPropagation(t *testing.T) {
normFunc := func(f *pflag.FlagSet, name string) pflag.NormalizedName {
return pflag.NormalizedName(name)
}
rootCmd := initialize()
rootCmd.SetGlobalNormalizationFunc(normFunc)
if reflect.ValueOf(normFunc) != reflect.ValueOf(rootCmd.GlobalNormalizationFunc()) {
t.Error("rootCmd seems to have a wrong normalization function")
}
// First add the cmdEchoSub to cmdPrint
cmdPrint.AddCommand(cmdEchoSub)
if cmdPrint.GlobalNormalizationFunc() != nil && cmdEchoSub.GlobalNormalizationFunc() != nil {
t.Error("cmdPrint and cmdEchoSub should had no normalization functions")
}
// Now add cmdPrint to rootCmd
rootCmd.AddCommand(cmdPrint)
if reflect.ValueOf(cmdPrint.GlobalNormalizationFunc()).Pointer() != reflect.ValueOf(rootCmd.GlobalNormalizationFunc()).Pointer() ||
reflect.ValueOf(cmdEchoSub.GlobalNormalizationFunc()).Pointer() != reflect.ValueOf(rootCmd.GlobalNormalizationFunc()).Pointer() {
t.Error("cmdPrint and cmdEchoSub should had the normalization function of rootCmd")
}
}
View file
@ -1,90 +0,0 @@
package cobra
import (
"reflect"
"testing"
)
func TestStripFlags(t *testing.T) {
tests := []struct {
input []string
output []string
}{
{
[]string{"foo", "bar"},
[]string{"foo", "bar"},
},
{
[]string{"foo", "--bar", "-b"},
[]string{"foo"},
},
{
[]string{"-b", "foo", "--bar", "bar"},
[]string{},
},
{
[]string{"-i10", "echo"},
[]string{"echo"},
},
{
[]string{"-i=10", "echo"},
[]string{"echo"},
},
{
[]string{"--int=100", "echo"},
[]string{"echo"},
},
{
[]string{"-ib", "echo", "-bfoo", "baz"},
[]string{"echo", "baz"},
},
{
[]string{"-i=baz", "bar", "-i", "foo", "blah"},
[]string{"bar", "blah"},
},
{
[]string{"--int=baz", "-bbar", "-i", "foo", "blah"},
[]string{"blah"},
},
{
[]string{"--cat", "bar", "-i", "foo", "blah"},
[]string{"bar", "blah"},
},
{
[]string{"-c", "bar", "-i", "foo", "blah"},
[]string{"bar", "blah"},
},
{
[]string{"--persist", "bar"},
[]string{"bar"},
},
{
[]string{"-p", "bar"},
[]string{"bar"},
},
}
cmdPrint := &Command{
Use: "print [string to print]",
Short: "Print anything to the screen",
Long: `an utterly useless command for testing.`,
Run: func(cmd *Command, args []string) {
tp = args
},
}
var flagi int
var flagstr string
var flagbool bool
cmdPrint.PersistentFlags().BoolVarP(&flagbool, "persist", "p", false, "help for persistent one")
cmdPrint.Flags().IntVarP(&flagi, "int", "i", 345, "help message for flag int")
cmdPrint.Flags().StringVarP(&flagstr, "bar", "b", "bar", "help message for flag string")
cmdPrint.Flags().BoolVarP(&flagbool, "cat", "c", false, "help message for flag bool")
for _, test := range tests {
output := stripFlags(test.input, cmdPrint)
if !reflect.DeepEqual(test.output, output) {
t.Errorf("expected: %v, got: %v", test.output, output)
}
}
}
View file
@ -1,138 +0,0 @@
//Copyright 2015 Red Hat Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cobra
import (
"bytes"
"fmt"
"os"
"sort"
"strings"
"time"
)
func printOptions(out *bytes.Buffer, cmd *Command, name string) {
flags := cmd.NonInheritedFlags()
flags.SetOutput(out)
if flags.HasFlags() {
fmt.Fprintf(out, "### Options\n\n```\n")
flags.PrintDefaults()
fmt.Fprintf(out, "```\n\n")
}
parentFlags := cmd.InheritedFlags()
parentFlags.SetOutput(out)
if parentFlags.HasFlags() {
fmt.Fprintf(out, "### Options inherited from parent commands\n\n```\n")
parentFlags.PrintDefaults()
fmt.Fprintf(out, "```\n\n")
}
}
type byName []*Command
func (s byName) Len() int { return len(s) }
func (s byName) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s byName) Less(i, j int) bool { return s[i].Name() < s[j].Name() }
func GenMarkdown(cmd *Command, out *bytes.Buffer) {
GenMarkdownCustom(cmd, out, func(s string) string { return s })
}
func GenMarkdownCustom(cmd *Command, out *bytes.Buffer, linkHandler func(string) string) {
name := cmd.CommandPath()
short := cmd.Short
long := cmd.Long
if len(long) == 0 {
long = short
}
fmt.Fprintf(out, "## %s\n\n", name)
fmt.Fprintf(out, "%s\n\n", short)
fmt.Fprintf(out, "### Synopsis\n\n")
fmt.Fprintf(out, "\n%s\n\n", long)
if cmd.Runnable() {
fmt.Fprintf(out, "```\n%s\n```\n\n", cmd.UseLine())
}
if len(cmd.Example) > 0 {
fmt.Fprintf(out, "### Examples\n\n")
fmt.Fprintf(out, "```\n%s\n```\n\n", cmd.Example)
}
printOptions(out, cmd, name)
if len(cmd.Commands()) > 0 || cmd.HasParent() {
fmt.Fprintf(out, "### SEE ALSO\n")
if cmd.HasParent() {
parent := cmd.Parent()
pname := parent.CommandPath()
link := pname + ".md"
link = strings.Replace(link, " ", "_", -1)
fmt.Fprintf(out, "* [%s](%s)\t - %s\n", pname, linkHandler(link), parent.Short)
}
children := cmd.Commands()
sort.Sort(byName(children))
for _, child := range children {
if len(child.Deprecated) > 0 {
continue
}
cname := name + " " + child.Name()
link := cname + ".md"
link = strings.Replace(link, " ", "_", -1)
fmt.Fprintf(out, "* [%s](%s)\t - %s\n", cname, linkHandler(link), child.Short)
}
fmt.Fprintf(out, "\n")
}
fmt.Fprintf(out, "###### Auto generated by spf13/cobra at %s\n", time.Now().UTC())
}
func GenMarkdownTree(cmd *Command, dir string) {
identity := func(s string) string { return s }
emptyStr := func(s string) string { return "" }
GenMarkdownTreeCustom(cmd, dir, emptyStr, identity)
}
func GenMarkdownTreeCustom(cmd *Command, dir string, filePrepender func(string) string, linkHandler func(string) string) {
for _, c := range cmd.Commands() {
GenMarkdownTreeCustom(c, dir, filePrepender, linkHandler)
}
out := new(bytes.Buffer)
GenMarkdownCustom(cmd, out, linkHandler)
filename := cmd.CommandPath()
filename = dir + strings.Replace(filename, " ", "_", -1) + ".md"
outFile, err := os.Create(filename)
if err != nil {
fmt.Println(err)
os.Exit(1)
}
defer outFile.Close()
_, err = outFile.WriteString(filePrepender(filename))
if err != nil {
fmt.Println(err)
os.Exit(1)
}
_, err = outFile.Write(out.Bytes())
if err != nil {
fmt.Println(err)
os.Exit(1)
}
}
View file
@ -1,81 +0,0 @@
# Generating Markdown Docs For Your Own cobra.Command
## Generate markdown docs for the entire command tree
This program can actually generate docs for the kubectl command in the kubernetes project:
```go
package main
import (
"io/ioutil"
"os"
"github.com/GoogleCloudPlatform/kubernetes/pkg/kubectl/cmd"
"github.com/spf13/cobra"
)
func main() {
kubectl := cmd.NewFactory(nil).NewKubectlCommand(os.Stdin, ioutil.Discard, ioutil.Discard)
cobra.GenMarkdownTree(kubectl, "./")
}
```
This will generate a whole series of files, one for each command in the tree, in the directory specified (in this case "./").
## Generate markdown docs for a single command
You may wish to have more control over the output, or to generate docs for only a single command instead of the entire command tree. In that case you may prefer to use `GenMarkdown` instead of `GenMarkdownTree`:
```go
out := new(bytes.Buffer)
cobra.GenMarkdown(cmd, out)
```
This will write the markdown doc for ONLY "cmd" into the out buffer.
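For a self-contained sketch of the same idea (the command definition and the `./example.md` output path below are purely illustrative, and note that in this version `GenMarkdown` writes into a `*bytes.Buffer` and does not return an error):

```go
package main

import (
	"bytes"
	"io/ioutil"
	"log"

	"github.com/spf13/cobra"
)

func main() {
	// Any *cobra.Command will do; this one is only a placeholder.
	cmd := &cobra.Command{
		Use:   "example",
		Short: "An example command",
		Run:   func(cmd *cobra.Command, args []string) {},
	}

	// Render the markdown for this single command into a buffer...
	out := new(bytes.Buffer)
	cobra.GenMarkdown(cmd, out)

	// ...and write it wherever you like.
	if err := ioutil.WriteFile("./example.md", out.Bytes(), 0644); err != nil {
		log.Fatal(err)
	}
}
```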
## Customize the output
Both `GenMarkdown` and `GenMarkdownTree` have alternate versions with callbacks that give you more control over the output:
```go
func GenMarkdownTreeCustom(cmd *Command, dir string, filePrepender func(string) string, linkHandler func(string) string) {
//...
}
```
```go
func GenMarkdownCustom(cmd *Command, out *bytes.Buffer, linkHandler func(string) string) {
//...
}
```
The `filePrepender` is called with the full filepath of the rendered Markdown file, and whatever it returns is prepended to that file. A common use case is to add front matter so the generated documentation can be used with [Hugo](http://gohugo.io/):
```go
const fmTemplate = `---
date: %s
title: "%s"
slug: %s
url: %s
---
`
filePrepender := func(filename string) string {
now := time.Now().Format(time.RFC3339)
name := filepath.Base(filename)
base := strings.TrimSuffix(name, path.Ext(name))
url := "/commands/" + strings.ToLower(base) + "/"
return fmt.Sprintf(fmTemplate, now, strings.Replace(base, "_", " ", -1), base, url)
}
```
The `linkHandler` can be used to customize the rendered internal links to the commands, given a filename:
```go
linkHandler := func(name string) string {
base := strings.TrimSuffix(name, path.Ext(name))
return "/commands/" + strings.ToLower(base) + "/"
}
```
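Putting the two callbacks together, a minimal sketch of a custom tree generation (this reuses the `filePrepender` and `linkHandler` snippets above and assumes `cmd` is your root command; the `./docs/` directory is arbitrary, and since the directory is concatenated directly with the file name in this version it should end with a trailing slash):

```go
// Generate one markdown file per command under ./docs/, with Hugo
// front matter prepended and Hugo-style links between the pages.
cobra.GenMarkdownTreeCustom(cmd, "./docs/", filePrepender, linkHandler)
```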
View file
@ -1,67 +0,0 @@
package cobra
import (
"bytes"
"fmt"
"os"
"strings"
"testing"
)
var _ = fmt.Println
var _ = os.Stderr
func TestGenMdDoc(t *testing.T) {
c := initializeWithRootCmd()
// Need at least two commands to exercise the alphabetical sorting of commands
cmdEcho.AddCommand(cmdTimes, cmdEchoSub, cmdDeprecated)
c.AddCommand(cmdPrint, cmdEcho)
cmdRootWithRun.PersistentFlags().StringVarP(&flags2a, "rootflag", "r", "two", strtwoParentHelp)
out := new(bytes.Buffer)
// We generate docs for a subcommand so we have both subcommands and parents
GenMarkdown(cmdEcho, out)
found := out.String()
// Our description
expected := cmdEcho.Long
if !strings.Contains(found, expected) {
t.Errorf("Unexpected response.\nExpecting to contain: \n %q\nGot:\n %q\n", expected, found)
}
// Better have our example
expected = cmdEcho.Example
if !strings.Contains(found, expected) {
t.Errorf("Unexpected response.\nExpecting to contain: \n %q\nGot:\n %q\n", expected, found)
}
// A local flag
expected = "boolone"
if !strings.Contains(found, expected) {
t.Errorf("Unexpected response.\nExpecting to contain: \n %q\nGot:\n %q\n", expected, found)
}
// persistent flag on parent
expected = "rootflag"
if !strings.Contains(found, expected) {
t.Errorf("Unexpected response.\nExpecting to contain: \n %q\nGot:\n %q\n", expected, found)
}
// We better output info about our parent
expected = cmdRootWithRun.Short
if !strings.Contains(found, expected) {
t.Errorf("Unexpected response.\nExpecting to contain: \n %q\nGot:\n %q\n", expected, found)
}
// And about subcommands
expected = cmdEchoSub.Short
if !strings.Contains(found, expected) {
t.Errorf("Unexpected response.\nExpecting to contain: \n %q\nGot:\n %q\n", expected, found)
}
unexpected := cmdDeprecated.Short
if strings.Contains(found, unexpected) {
t.Errorf("Unexpected response.\nFound: %v\nBut should not have!!\n", unexpected)
}
}
View file
@ -1,8 +0,0 @@
sudo: false
language: go
go:
- 1.3
- 1.4
- tip
View file
@ -1,180 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pflag
import (
"bytes"
"fmt"
"strconv"
"testing"
)
// This value can be a boolean ("true", "false") or "maybe"
type triStateValue int
const (
triStateFalse triStateValue = 0
triStateTrue triStateValue = 1
triStateMaybe triStateValue = 2
)
const strTriStateMaybe = "maybe"
func (v *triStateValue) IsBoolFlag() bool {
return true
}
func (v *triStateValue) Get() interface{} {
return triStateValue(*v)
}
func (v *triStateValue) Set(s string) error {
if s == strTriStateMaybe {
*v = triStateMaybe
return nil
}
boolVal, err := strconv.ParseBool(s)
if boolVal {
*v = triStateTrue
} else {
*v = triStateFalse
}
return err
}
func (v *triStateValue) String() string {
if *v == triStateMaybe {
return strTriStateMaybe
}
return fmt.Sprintf("%v", bool(*v == triStateTrue))
}
// The type of the flag as required by the pflag.Value interface
func (v *triStateValue) Type() string {
return "version"
}
func setUpFlagSet(tristate *triStateValue) *FlagSet {
f := NewFlagSet("test", ContinueOnError)
*tristate = triStateFalse
flag := f.VarPF(tristate, "tristate", "t", "tristate value (true, maybe or false)")
flag.NoOptDefVal = "true"
return f
}
func TestExplicitTrue(t *testing.T) {
var tristate triStateValue
f := setUpFlagSet(&tristate)
err := f.Parse([]string{"--tristate=true"})
if err != nil {
t.Fatal("expected no error; got", err)
}
if tristate != triStateTrue {
t.Fatal("expected", triStateTrue, "(triStateTrue) but got", tristate, "instead")
}
}
func TestImplicitTrue(t *testing.T) {
var tristate triStateValue
f := setUpFlagSet(&tristate)
err := f.Parse([]string{"--tristate"})
if err != nil {
t.Fatal("expected no error; got", err)
}
if tristate != triStateTrue {
t.Fatal("expected", triStateTrue, "(triStateTrue) but got", tristate, "instead")
}
}
func TestShortFlag(t *testing.T) {
var tristate triStateValue
f := setUpFlagSet(&tristate)
err := f.Parse([]string{"-t"})
if err != nil {
t.Fatal("expected no error; got", err)
}
if tristate != triStateTrue {
t.Fatal("expected", triStateTrue, "(triStateTrue) but got", tristate, "instead")
}
}
func TestShortFlagExtraArgument(t *testing.T) {
var tristate triStateValue
f := setUpFlagSet(&tristate)
// The"maybe"turns into an arg, since short boolean options will only do true/false
err := f.Parse([]string{"-t", "maybe"})
if err != nil {
t.Fatal("expected no error; got", err)
}
if tristate != triStateTrue {
t.Fatal("expected", triStateTrue, "(triStateTrue) but got", tristate, "instead")
}
args := f.Args()
if len(args) != 1 || args[0] != "maybe" {
t.Fatal("expected an extra 'maybe' argument to stick around")
}
}
func TestExplicitMaybe(t *testing.T) {
var tristate triStateValue
f := setUpFlagSet(&tristate)
err := f.Parse([]string{"--tristate=maybe"})
if err != nil {
t.Fatal("expected no error; got", err)
}
if tristate != triStateMaybe {
t.Fatal("expected", triStateMaybe, "(triStateMaybe) but got", tristate, "instead")
}
}
func TestExplicitFalse(t *testing.T) {
var tristate triStateValue
f := setUpFlagSet(&tristate)
err := f.Parse([]string{"--tristate=false"})
if err != nil {
t.Fatal("expected no error; got", err)
}
if tristate != triStateFalse {
t.Fatal("expected", triStateFalse, "(triStateFalse) but got", tristate, "instead")
}
}
func TestImplicitFalse(t *testing.T) {
var tristate triStateValue
f := setUpFlagSet(&tristate)
err := f.Parse([]string{})
if err != nil {
t.Fatal("expected no error; got", err)
}
if tristate != triStateFalse {
t.Fatal("expected", triStateFalse, "(triStateFalse) but got", tristate, "instead")
}
}
func TestInvalidValue(t *testing.T) {
var tristate triStateValue
f := setUpFlagSet(&tristate)
var buf bytes.Buffer
f.SetOutput(&buf)
err := f.Parse([]string{"--tristate=invalid"})
if err == nil {
t.Fatal("expected an error but did not get any, tristate has value", tristate)
}
}
func TestBoolP(t *testing.T) {
b := BoolP("bool", "b", false, "bool value in CommandLine")
c := BoolP("c", "c", false, "other bool value")
args := []string{"--bool"}
if err := CommandLine.Parse(args); err != nil {
t.Error("expected no error, got ", err)
}
if *b != true {
t.Errorf("expected b=true got b=%v", *b)
}
if *c != false {
t.Errorf("expected c=false got c=%v", *c)
}
}
View file
@ -1,77 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// These examples demonstrate more intricate uses of the flag package.
package pflag_test
import (
"errors"
"fmt"
"strings"
"time"
flag "github.com/spf13/pflag"
)
// Example 1: A single string flag called "species" with default value "gopher".
var species = flag.String("species", "gopher", "the species we are studying")
// Example 2: A flag with a shorthand letter.
var gopherType = flag.StringP("gopher_type", "g", "pocket", "the variety of gopher")
// Example 3: A user-defined flag type, a slice of durations.
type interval []time.Duration
// String is the method to format the flag's value, part of the flag.Value interface.
// The String method's output will be used in diagnostics.
func (i *interval) String() string {
return fmt.Sprint(*i)
}
func (i *interval) Type() string {
return "interval"
}
// Set is the method to set the flag value, part of the flag.Value interface.
// Set's argument is a string to be parsed to set the flag.
// It's a comma-separated list, so we split it.
func (i *interval) Set(value string) error {
// If we wanted to allow the flag to be set multiple times,
// accumulating values, we would delete this if statement.
// That would permit usages such as
// -deltaT 10s -deltaT 15s
// and other combinations.
if len(*i) > 0 {
return errors.New("interval flag already set")
}
for _, dt := range strings.Split(value, ",") {
duration, err := time.ParseDuration(dt)
if err != nil {
return err
}
*i = append(*i, duration)
}
return nil
}
// Define a flag to accumulate durations. Because it has a special type,
// we need to use the Var function and therefore create the flag during
// init.
var intervalFlag interval
func init() {
// Tie the command-line flag to the intervalFlag variable and
// set a usage message.
flag.Var(&intervalFlag, "deltaT", "comma-separated list of intervals to use between events")
}
func Example() {
// All the interesting pieces are with the variables declared above, but
// to enable the flag package to see the flags defined there, one must
// execute, typically at the start of main (not init!):
// flag.Parse()
// We don't run it here because this is not a main function and
// the testing suite has already parsed the flags.
}

View file

@ -1,29 +0,0 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pflag
import (
"io/ioutil"
"os"
)
// Additional routines compiled into the package only during testing.
// ResetForTesting clears all flag state and sets the usage function as directed.
// After calling ResetForTesting, parse errors in flag handling will not
// exit the program.
func ResetForTesting(usage func()) {
CommandLine = &FlagSet{
name: os.Args[0],
errorHandling: ContinueOnError,
output: ioutil.Discard,
}
Usage = usage
}
// GetCommandLine returns the default FlagSet.
func GetCommandLine() *FlagSet {
return CommandLine
}
View file
@ -1,755 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pflag
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net"
"os"
"reflect"
"sort"
"strings"
"testing"
"time"
)
var (
test_bool = Bool("test_bool", false, "bool value")
test_int = Int("test_int", 0, "int value")
test_int64 = Int64("test_int64", 0, "int64 value")
test_uint = Uint("test_uint", 0, "uint value")
test_uint64 = Uint64("test_uint64", 0, "uint64 value")
test_string = String("test_string", "0", "string value")
test_float64 = Float64("test_float64", 0, "float64 value")
test_duration = Duration("test_duration", 0, "time.Duration value")
test_optional_int = Int("test_optional_int", 0, "optional int value")
normalizeFlagNameInvocations = 0
)
func boolString(s string) string {
if s == "0" {
return "false"
}
return "true"
}
func TestEverything(t *testing.T) {
m := make(map[string]*Flag)
desired := "0"
visitor := func(f *Flag) {
if len(f.Name) > 5 && f.Name[0:5] == "test_" {
m[f.Name] = f
ok := false
switch {
case f.Value.String() == desired:
ok = true
case f.Name == "test_bool" && f.Value.String() == boolString(desired):
ok = true
case f.Name == "test_duration" && f.Value.String() == desired+"s":
ok = true
}
if !ok {
t.Error("Visit: bad value", f.Value.String(), "for", f.Name)
}
}
}
VisitAll(visitor)
if len(m) != 9 {
t.Error("VisitAll misses some flags")
for k, v := range m {
t.Log(k, *v)
}
}
m = make(map[string]*Flag)
Visit(visitor)
if len(m) != 0 {
t.Errorf("Visit sees unset flags")
for k, v := range m {
t.Log(k, *v)
}
}
// Now set all flags
Set("test_bool", "true")
Set("test_int", "1")
Set("test_int64", "1")
Set("test_uint", "1")
Set("test_uint64", "1")
Set("test_string", "1")
Set("test_float64", "1")
Set("test_duration", "1s")
Set("test_optional_int", "1")
desired = "1"
Visit(visitor)
if len(m) != 9 {
t.Error("Visit fails after set")
for k, v := range m {
t.Log(k, *v)
}
}
// Now test they're visited in sort order.
var flagNames []string
Visit(func(f *Flag) { flagNames = append(flagNames, f.Name) })
if !sort.StringsAreSorted(flagNames) {
t.Errorf("flag names not sorted: %v", flagNames)
}
}
func TestUsage(t *testing.T) {
called := false
ResetForTesting(func() { called = true })
if GetCommandLine().Parse([]string{"--x"}) == nil {
t.Error("parse did not fail for unknown flag")
}
if !called {
t.Error("did not call Usage for unknown flag")
}
}
func TestAnnotation(t *testing.T) {
f := NewFlagSet("shorthand", ContinueOnError)
if err := f.SetAnnotation("missing-flag", "key", nil); err == nil {
t.Errorf("Expected error setting annotation on non-existent flag")
}
f.StringP("stringa", "a", "", "string value")
if err := f.SetAnnotation("stringa", "key", nil); err != nil {
t.Errorf("Unexpected error setting new nil annotation: %v", err)
}
if annotation := f.Lookup("stringa").Annotations["key"]; annotation != nil {
t.Errorf("Unexpected annotation: %v", annotation)
}
f.StringP("stringb", "b", "", "string2 value")
if err := f.SetAnnotation("stringb", "key", []string{"value1"}); err != nil {
t.Errorf("Unexpected error setting new annotation: %v", err)
}
if annotation := f.Lookup("stringb").Annotations["key"]; !reflect.DeepEqual(annotation, []string{"value1"}) {
t.Errorf("Unexpected annotation: %v", annotation)
}
if err := f.SetAnnotation("stringb", "key", []string{"value2"}); err != nil {
t.Errorf("Unexpected error updating annotation: %v", err)
}
if annotation := f.Lookup("stringb").Annotations["key"]; !reflect.DeepEqual(annotation, []string{"value2"}) {
t.Errorf("Unexpected annotation: %v", annotation)
}
}
func testParse(f *FlagSet, t *testing.T) {
if f.Parsed() {
t.Error("f.Parse() = true before Parse")
}
boolFlag := f.Bool("bool", false, "bool value")
bool2Flag := f.Bool("bool2", false, "bool2 value")
bool3Flag := f.Bool("bool3", false, "bool3 value")
intFlag := f.Int("int", 0, "int value")
int8Flag := f.Int8("int8", 0, "int value")
int32Flag := f.Int32("int32", 0, "int value")
int64Flag := f.Int64("int64", 0, "int64 value")
uintFlag := f.Uint("uint", 0, "uint value")
uint8Flag := f.Uint8("uint8", 0, "uint value")
uint16Flag := f.Uint16("uint16", 0, "uint value")
uint32Flag := f.Uint32("uint32", 0, "uint value")
uint64Flag := f.Uint64("uint64", 0, "uint64 value")
stringFlag := f.String("string", "0", "string value")
float32Flag := f.Float32("float32", 0, "float32 value")
float64Flag := f.Float64("float64", 0, "float64 value")
ipFlag := f.IP("ip", net.ParseIP("127.0.0.1"), "ip value")
maskFlag := f.IPMask("mask", ParseIPv4Mask("0.0.0.0"), "mask value")
durationFlag := f.Duration("duration", 5*time.Second, "time.Duration value")
optionalIntNoValueFlag := f.Int("optional-int-no-value", 0, "int value")
f.Lookup("optional-int-no-value").NoOptDefVal = "9"
optionalIntWithValueFlag := f.Int("optional-int-with-value", 0, "int value")
f.Lookup("optional-int-no-value").NoOptDefVal = "9"
extra := "one-extra-argument"
args := []string{
"--bool",
"--bool2=true",
"--bool3=false",
"--int=22",
"--int8=-8",
"--int32=-32",
"--int64=0x23",
"--uint", "24",
"--uint8=8",
"--uint16=16",
"--uint32=32",
"--uint64=25",
"--string=hello",
"--float32=-172e12",
"--float64=2718e28",
"--ip=10.11.12.13",
"--mask=255.255.255.0",
"--duration=2m",
"--optional-int-no-value",
"--optional-int-with-value=42",
extra,
}
if err := f.Parse(args); err != nil {
t.Fatal(err)
}
if !f.Parsed() {
t.Error("f.Parse() = false after Parse")
}
if *boolFlag != true {
t.Error("bool flag should be true, is ", *boolFlag)
}
if v, err := f.GetBool("bool"); err != nil || v != *boolFlag {
t.Error("GetBool does not work.")
}
if *bool2Flag != true {
t.Error("bool2 flag should be true, is ", *bool2Flag)
}
if *bool3Flag != false {
t.Error("bool3 flag should be false, is ", *bool2Flag)
}
if *intFlag != 22 {
t.Error("int flag should be 22, is ", *intFlag)
}
if v, err := f.GetInt("int"); err != nil || v != *intFlag {
t.Error("GetInt does not work.")
}
if *int8Flag != -8 {
t.Error("int8 flag should be 0x23, is ", *int8Flag)
}
if v, err := f.GetInt8("int8"); err != nil || v != *int8Flag {
t.Error("GetInt8 does not work.")
}
if *int32Flag != -32 {
t.Error("int32 flag should be 0x23, is ", *int32Flag)
}
if v, err := f.GetInt32("int32"); err != nil || v != *int32Flag {
t.Error("GetInt32 does not work.")
}
if *int64Flag != 0x23 {
t.Error("int64 flag should be 0x23, is ", *int64Flag)
}
if v, err := f.GetInt64("int64"); err != nil || v != *int64Flag {
t.Error("GetInt64 does not work.")
}
if *uintFlag != 24 {
t.Error("uint flag should be 24, is ", *uintFlag)
}
if v, err := f.GetUint("uint"); err != nil || v != *uintFlag {
t.Error("GetUint does not work.")
}
if *uint8Flag != 8 {
t.Error("uint8 flag should be 8, is ", *uint8Flag)
}
if v, err := f.GetUint8("uint8"); err != nil || v != *uint8Flag {
t.Error("GetUint8 does not work.")
}
if *uint16Flag != 16 {
t.Error("uint16 flag should be 16, is ", *uint16Flag)
}
if v, err := f.GetUint16("uint16"); err != nil || v != *uint16Flag {
t.Error("GetUint16 does not work.")
}
if *uint32Flag != 32 {
t.Error("uint32 flag should be 32, is ", *uint32Flag)
}
if v, err := f.GetUint32("uint32"); err != nil || v != *uint32Flag {
t.Error("GetUint32 does not work.")
}
if *uint64Flag != 25 {
t.Error("uint64 flag should be 25, is ", *uint64Flag)
}
if v, err := f.GetUint64("uint64"); err != nil || v != *uint64Flag {
t.Error("GetUint64 does not work.")
}
if *stringFlag != "hello" {
t.Error("string flag should be `hello`, is ", *stringFlag)
}
if v, err := f.GetString("string"); err != nil || v != *stringFlag {
t.Error("GetString does not work.")
}
if *float32Flag != -172e12 {
t.Error("float32 flag should be -172e12, is ", *float32Flag)
}
if v, err := f.GetFloat32("float32"); err != nil || v != *float32Flag {
t.Errorf("GetFloat32 returned %v but float32Flag was %v", v, *float32Flag)
}
if *float64Flag != 2718e28 {
t.Error("float64 flag should be 2718e28, is ", *float64Flag)
}
if v, err := f.GetFloat64("float64"); err != nil || v != *float64Flag {
t.Errorf("GetFloat64 returned %v but float64Flag was %v", v, *float64Flag)
}
if !(*ipFlag).Equal(net.ParseIP("10.11.12.13")) {
t.Error("ip flag should be 10.11.12.13, is ", *ipFlag)
}
if v, err := f.GetIP("ip"); err != nil || !v.Equal(*ipFlag) {
t.Errorf("GetIP returned %v but ipFlag was %v", v, *ipFlag)
}
if (*maskFlag).String() != ParseIPv4Mask("255.255.255.0").String() {
t.Error("mask flag should be 255.255.255.0, is ", (*maskFlag).String())
}
if v, err := f.GetIPv4Mask("mask"); err != nil || v.String() != (*maskFlag).String() {
t.Errorf("GetIP returned %v but maskFlag was %v", v, *maskFlag, err)
}
if *durationFlag != 2*time.Minute {
t.Error("duration flag should be 2m, is ", *durationFlag)
}
if v, err := f.GetDuration("duration"); err != nil || v != *durationFlag {
t.Error("GetDuration does not work.")
}
if _, err := f.GetInt("duration"); err == nil {
t.Error("GetInt parsed a time.Duration?!?!")
}
if *optionalIntNoValueFlag != 9 {
t.Error("optional int flag should be the default value, is ", *optionalIntNoValueFlag)
}
if *optionalIntWithValueFlag != 42 {
t.Error("optional int flag should be 42, is ", *optionalIntWithValueFlag)
}
if len(f.Args()) != 1 {
t.Error("expected one argument, got", len(f.Args()))
} else if f.Args()[0] != extra {
t.Errorf("expected argument %q got %q", extra, f.Args()[0])
}
}
func TestShorthand(t *testing.T) {
f := NewFlagSet("shorthand", ContinueOnError)
if f.Parsed() {
t.Error("f.Parse() = true before Parse")
}
boolaFlag := f.BoolP("boola", "a", false, "bool value")
boolbFlag := f.BoolP("boolb", "b", false, "bool2 value")
boolcFlag := f.BoolP("boolc", "c", false, "bool3 value")
booldFlag := f.BoolP("boold", "d", false, "bool4 value")
stringaFlag := f.StringP("stringa", "s", "0", "string value")
stringzFlag := f.StringP("stringz", "z", "0", "string value")
extra := "interspersed-argument"
notaflag := "--i-look-like-a-flag"
args := []string{
"-ab",
extra,
"-cs",
"hello",
"-z=something",
"-d=true",
"--",
notaflag,
}
f.SetOutput(ioutil.Discard)
if err := f.Parse(args); err != nil {
t.Error("expected no error, got ", err)
}
if !f.Parsed() {
t.Error("f.Parse() = false after Parse")
}
if *boolaFlag != true {
t.Error("boola flag should be true, is ", *boolaFlag)
}
if *boolbFlag != true {
t.Error("boolb flag should be true, is ", *boolbFlag)
}
if *boolcFlag != true {
t.Error("boolc flag should be true, is ", *boolcFlag)
}
if *booldFlag != true {
t.Error("boold flag should be true, is ", *booldFlag)
}
if *stringaFlag != "hello" {
t.Error("stringa flag should be `hello`, is ", *stringaFlag)
}
if *stringzFlag != "something" {
t.Error("stringz flag should be `something`, is ", *stringzFlag)
}
if len(f.Args()) != 2 {
t.Error("expected one argument, got", len(f.Args()))
} else if f.Args()[0] != extra {
t.Errorf("expected argument %q got %q", extra, f.Args()[0])
} else if f.Args()[1] != notaflag {
t.Errorf("expected argument %q got %q", notaflag, f.Args()[1])
}
}
func TestParse(t *testing.T) {
ResetForTesting(func() { t.Error("bad parse") })
testParse(GetCommandLine(), t)
}
func TestFlagSetParse(t *testing.T) {
testParse(NewFlagSet("test", ContinueOnError), t)
}
func replaceSeparators(name string, from []string, to string) string {
result := name
for _, sep := range from {
result = strings.Replace(result, sep, to, -1)
}
// Type convert to indicate normalization has been done.
return result
}
func wordSepNormalizeFunc(f *FlagSet, name string) NormalizedName {
seps := []string{"-", "_"}
name = replaceSeparators(name, seps, ".")
normalizeFlagNameInvocations++
return NormalizedName(name)
}
func testWordSepNormalizedNames(args []string, t *testing.T) {
f := NewFlagSet("normalized", ContinueOnError)
if f.Parsed() {
t.Error("f.Parse() = true before Parse")
}
withDashFlag := f.Bool("with-dash-flag", false, "bool value")
// Set this after some flags have been added and before others.
f.SetNormalizeFunc(wordSepNormalizeFunc)
withUnderFlag := f.Bool("with_under_flag", false, "bool value")
withBothFlag := f.Bool("with-both_flag", false, "bool value")
if err := f.Parse(args); err != nil {
t.Fatal(err)
}
if !f.Parsed() {
t.Error("f.Parse() = false after Parse")
}
if *withDashFlag != true {
t.Error("withDashFlag flag should be true, is ", *withDashFlag)
}
if *withUnderFlag != true {
t.Error("withUnderFlag flag should be true, is ", *withUnderFlag)
}
if *withBothFlag != true {
t.Error("withBothFlag flag should be true, is ", *withBothFlag)
}
}
func TestWordSepNormalizedNames(t *testing.T) {
args := []string{
"--with-dash-flag",
"--with-under-flag",
"--with-both-flag",
}
testWordSepNormalizedNames(args, t)
args = []string{
"--with_dash_flag",
"--with_under_flag",
"--with_both_flag",
}
testWordSepNormalizedNames(args, t)
args = []string{
"--with-dash_flag",
"--with-under_flag",
"--with-both_flag",
}
testWordSepNormalizedNames(args, t)
}
func aliasAndWordSepFlagNames(f *FlagSet, name string) NormalizedName {
seps := []string{"-", "_"}
oldName := replaceSeparators("old-valid_flag", seps, ".")
newName := replaceSeparators("valid-flag", seps, ".")
name = replaceSeparators(name, seps, ".")
switch name {
case oldName:
name = newName
break
}
return NormalizedName(name)
}
func TestCustomNormalizedNames(t *testing.T) {
f := NewFlagSet("normalized", ContinueOnError)
if f.Parsed() {
t.Error("f.Parse() = true before Parse")
}
validFlag := f.Bool("valid-flag", false, "bool value")
f.SetNormalizeFunc(aliasAndWordSepFlagNames)
someOtherFlag := f.Bool("some-other-flag", false, "bool value")
args := []string{"--old_valid_flag", "--some-other_flag"}
if err := f.Parse(args); err != nil {
t.Fatal(err)
}
if *validFlag != true {
t.Errorf("validFlag is %v even though we set the alias --old_valid_falg", *validFlag)
}
if *someOtherFlag != true {
t.Error("someOtherFlag should be true, is ", *someOtherFlag)
}
}
// For every flag we add, the name (also displayed in usage) should be normalized
func TestNormalizationFuncShouldChangeFlagName(t *testing.T) {
// Test normalization after addition
f := NewFlagSet("normalized", ContinueOnError)
f.Bool("valid_flag", false, "bool value")
if f.Lookup("valid_flag").Name != "valid_flag" {
t.Error("The new flag should have the name 'valid_flag' instead of ", f.Lookup("valid_flag").Name)
}
f.SetNormalizeFunc(wordSepNormalizeFunc)
if f.Lookup("valid_flag").Name != "valid.flag" {
t.Error("The new flag should have the name 'valid.flag' instead of ", f.Lookup("valid_flag").Name)
}
// Test normalization before addition
f = NewFlagSet("normalized", ContinueOnError)
f.SetNormalizeFunc(wordSepNormalizeFunc)
f.Bool("valid_flag", false, "bool value")
if f.Lookup("valid_flag").Name != "valid.flag" {
t.Error("The new flag should have the name 'valid.flag' instead of ", f.Lookup("valid_flag").Name)
}
}
// Declare a user-defined flag type.
type flagVar []string
func (f *flagVar) String() string {
return fmt.Sprint([]string(*f))
}
func (f *flagVar) Set(value string) error {
*f = append(*f, value)
return nil
}
func (f *flagVar) Type() string {
return "flagVar"
}
func TestUserDefined(t *testing.T) {
var flags FlagSet
flags.Init("test", ContinueOnError)
var v flagVar
flags.VarP(&v, "v", "v", "usage")
if err := flags.Parse([]string{"--v=1", "-v2", "-v", "3"}); err != nil {
t.Error(err)
}
if len(v) != 3 {
t.Fatal("expected 3 args; got ", len(v))
}
expect := "[1 2 3]"
if v.String() != expect {
t.Errorf("expected value %q got %q", expect, v.String())
}
}
func TestSetOutput(t *testing.T) {
var flags FlagSet
var buf bytes.Buffer
flags.SetOutput(&buf)
flags.Init("test", ContinueOnError)
flags.Parse([]string{"--unknown"})
if out := buf.String(); !strings.Contains(out, "--unknown") {
t.Logf("expected output mentioning unknown; got %q", out)
}
}
// This tests that one can reset the flags. This still works but not well, and is
// superseded by FlagSet.
func TestChangingArgs(t *testing.T) {
ResetForTesting(func() { t.Fatal("bad parse") })
oldArgs := os.Args
defer func() { os.Args = oldArgs }()
os.Args = []string{"cmd", "--before", "subcmd"}
before := Bool("before", false, "")
if err := GetCommandLine().Parse(os.Args[1:]); err != nil {
t.Fatal(err)
}
cmd := Arg(0)
os.Args = []string{"subcmd", "--after", "args"}
after := Bool("after", false, "")
Parse()
args := Args()
if !*before || cmd != "subcmd" || !*after || len(args) != 1 || args[0] != "args" {
t.Fatalf("expected true subcmd true [args] got %v %v %v %v", *before, cmd, *after, args)
}
}
// Test that -help invokes the usage message and returns ErrHelp.
func TestHelp(t *testing.T) {
var helpCalled = false
fs := NewFlagSet("help test", ContinueOnError)
fs.Usage = func() { helpCalled = true }
var flag bool
fs.BoolVar(&flag, "flag", false, "regular flag")
// Regular flag invocation should work
err := fs.Parse([]string{"--flag=true"})
if err != nil {
t.Fatal("expected no error; got ", err)
}
if !flag {
t.Error("flag was not set by --flag")
}
if helpCalled {
t.Error("help called for regular flag")
helpCalled = false // reset for next test
}
// Help flag should work as expected.
err = fs.Parse([]string{"--help"})
if err == nil {
t.Fatal("error expected")
}
if err != ErrHelp {
t.Fatal("expected ErrHelp; got ", err)
}
if !helpCalled {
t.Fatal("help was not called")
}
// If we define a help flag, that should override.
var help bool
fs.BoolVar(&help, "help", false, "help flag")
helpCalled = false
err = fs.Parse([]string{"--help"})
if err != nil {
t.Fatal("expected no error for defined --help; got ", err)
}
if helpCalled {
t.Fatal("help was called; should not have been for defined help flag")
}
}
func TestNoInterspersed(t *testing.T) {
f := NewFlagSet("test", ContinueOnError)
f.SetInterspersed(false)
f.Bool("true", true, "always true")
f.Bool("false", false, "always false")
err := f.Parse([]string{"--true", "break", "--false"})
if err != nil {
t.Fatal("expected no error; got ", err)
}
args := f.Args()
if len(args) != 2 || args[0] != "break" || args[1] != "--false" {
t.Fatal("expected interspersed options/non-options to fail")
}
}
func TestTermination(t *testing.T) {
f := NewFlagSet("termination", ContinueOnError)
boolFlag := f.BoolP("bool", "l", false, "bool value")
if f.Parsed() {
t.Error("f.Parse() = true before Parse")
}
arg1 := "ls"
arg2 := "-l"
args := []string{
"--",
arg1,
arg2,
}
f.SetOutput(ioutil.Discard)
if err := f.Parse(args); err != nil {
t.Fatal("expected no error; got ", err)
}
if !f.Parsed() {
t.Error("f.Parse() = false after Parse")
}
if *boolFlag {
t.Error("expected boolFlag=false, got true")
}
if len(f.Args()) != 2 {
t.Errorf("expected 2 arguments, got %d: %v", len(f.Args()), f.Args())
}
if f.Args()[0] != arg1 {
t.Errorf("expected argument %q got %q", arg1, f.Args()[0])
}
if f.Args()[1] != arg2 {
t.Errorf("expected argument %q got %q", arg2, f.Args()[1])
}
}
func TestDeprecatedFlagInDocs(t *testing.T) {
f := NewFlagSet("bob", ContinueOnError)
f.Bool("badflag", true, "always true")
f.MarkDeprecated("badflag", "use --good-flag instead")
out := new(bytes.Buffer)
f.SetOutput(out)
f.PrintDefaults()
if strings.Contains(out.String(), "badflag") {
t.Errorf("found deprecated flag in usage!")
}
}
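// parseReturnStderr parses args on f while temporarily redirecting os.Stderr and returns whatever was written to it along with the parse error.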
func parseReturnStderr(t *testing.T, f *FlagSet, args []string) (string, error) {
oldStderr := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
err := f.Parse(args)
outC := make(chan string)
// copy the output in a separate goroutine so printing can't block indefinitely
go func() {
var buf bytes.Buffer
io.Copy(&buf, r)
outC <- buf.String()
}()
w.Close()
os.Stderr = oldStderr
out := <-outC
return out, err
}
func TestDeprecatedFlagUsage(t *testing.T) {
f := NewFlagSet("bob", ContinueOnError)
f.Bool("badflag", true, "always true")
usageMsg := "use --good-flag instead"
f.MarkDeprecated("badflag", usageMsg)
args := []string{"--badflag"}
out, err := parseReturnStderr(t, f, args)
if err != nil {
t.Fatal("expected no error; got ", err)
}
if !strings.Contains(out, usageMsg) {
t.Errorf("usageMsg not printed when using a deprecated flag!")
}
}
func TestDeprecatedFlagUsageNormalized(t *testing.T) {
f := NewFlagSet("bob", ContinueOnError)
f.Bool("bad-double_flag", true, "always true")
f.SetNormalizeFunc(wordSepNormalizeFunc)
usageMsg := "use --good-flag instead"
f.MarkDeprecated("bad_double-flag", usageMsg)
args := []string{"--bad_double_flag"}
out, err := parseReturnStderr(t, f, args)
if err != nil {
t.Fatal("expected no error; got ", err)
}
if !strings.Contains(out, usageMsg) {
t.Errorf("usageMsg not printed when using a deprecated flag!")
}
}
// Name normalization function should be called only once on flag addition
func TestMultipleNormalizeFlagNameInvocations(t *testing.T) {
normalizeFlagNameInvocations = 0
f := NewFlagSet("normalized", ContinueOnError)
f.SetNormalizeFunc(wordSepNormalizeFunc)
f.Bool("with_under_flag", false, "bool value")
if normalizeFlagNameInvocations != 1 {
t.Fatal("Expected normalizeFlagNameInvocations to be 1; got ", normalizeFlagNameInvocations)
}
}

View file

@ -1,49 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pflag
import (
"fmt"
"strconv"
"strings"
"testing"
)
func setUpISFlagSet(isp *[]int) *FlagSet {
f := NewFlagSet("test", ContinueOnError)
f.IntSliceVar(isp, "is", []int{}, "Command seperated list!")
return f
}
func TestIS(t *testing.T) {
var is []int
f := setUpISFlagSet(&is)
vals := []string{"1", "2", "4", "3"}
arg := fmt.Sprintf("--is=%s", strings.Join(vals, ","))
err := f.Parse([]string{arg})
if err != nil {
t.Fatal("expected no error; got", err)
}
for i, v := range is {
d, err := strconv.Atoi(vals[i])
if err != nil {
t.Fatalf("got error: %v", err)
}
if d != v {
t.Fatalf("expected is[%d] to be %s but got: %d", i, vals[i], v)
}
}
getIS, err := f.GetIntSlice("is")
for i, v := range getIS {
d, err := strconv.Atoi(vals[i])
if err != nil {
t.Fatalf("got error: %v", err)
}
if d != v {
t.Fatalf("expected is[%d] to be %s but got: %d from GetIntSlice", i, vals[i], v)
}
}
}

View file

@ -1,44 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pflag
import (
"fmt"
"strings"
"testing"
)
func setUpSSFlagSet(ssp *[]string) *FlagSet {
f := NewFlagSet("test", ContinueOnError)
f.StringSliceVar(ssp, "ss", []string{}, "Command seperated list!")
return f
}
func TestSS(t *testing.T) {
var ss []string
f := setUpSSFlagSet(&ss)
vals := []string{"one", "two", "4", "3"}
arg := fmt.Sprintf("--ss=%s", strings.Join(vals, ","))
err := f.Parse([]string{arg})
if err != nil {
t.Fatal("expected no error; got", err)
}
for i, v := range ss {
if vals[i] != v {
t.Fatal("expected ss[%d] to be %s but got: %s", i, vals[i], v)
}
}
getSS, err := f.GetStringSlice("ss")
if err != nil {
t.Fatal("got an error from GetStringSlice(): %v", err)
}
for i, v := range getSS {
if vals[i] != v {
t.Fatal("expected ss[%d] to be %s from GetStringSlice but got: %s", i, vals[i], v)
}
}
}

View file

@ -1,14 +0,0 @@
Copyright (c) 2013 Vaughan Newton
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View file

@ -1,70 +0,0 @@
go-ini
======
INI parsing library for Go (golang).
View the API documentation [here](http://godoc.org/github.com/vaughan0/go-ini).
Usage
-----
Parse an INI file:
```go
import "github.com/vaughan0/go-ini"
file, err := ini.LoadFile("myfile.ini")
```
Get data from the parsed file:
```go
name, ok := file.Get("person", "name")
if !ok {
panic("'name' variable missing from 'person' section")
}
```
Iterate through values in a section:
```go
for key, value := range file["mysection"] {
fmt.Printf("%s => %s\n", key, value)
}
```
Iterate through sections in a file:
```go
for name, section := range file {
fmt.Printf("Section name: %s\n", name)
}
```
File Format
-----------
INI files are parsed by go-ini line-by-line. Each line may be one of the following:
* A section definition: [section-name]
* A property: key = value
* A comment: #blahblah _or_ ;blahblah
* Blank. The line will be ignored.
Properties defined before any section headers are placed in the default section, which has
the empty string as its key.
Example:
```ini
# I am a comment
; So am I!
[apples]
colour = red or green
shape = applish
[oranges]
shape = square
colour = blue
```
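For reference, a minimal sketch of the go-ini API documented in the README above, assuming a local `config.ini` shaped like the example (file name, section and key names are illustrative):

```go
package main

import (
	"fmt"
	"log"

	"github.com/vaughan0/go-ini"
)

func main() {
	// Parse the INI file from disk; the path is illustrative.
	file, err := ini.LoadFile("config.ini")
	if err != nil {
		log.Fatal(err)
	}

	// Look up a single key in a named section.
	if colour, ok := file.Get("apples", "colour"); ok {
		fmt.Println("apples.colour =", colour)
	}

	// Iterate over every section and its key/value pairs.
	for name, section := range file {
		for key, value := range section {
			fmt.Printf("[%s] %s = %s\n", name, key, value)
		}
	}
}
```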

View file

@ -1,123 +0,0 @@
// Package ini provides functions for parsing INI configuration files.
package ini
import (
"bufio"
"fmt"
"io"
"os"
"regexp"
"strings"
)
var (
sectionRegex = regexp.MustCompile(`^\[(.*)\]$`)
assignRegex = regexp.MustCompile(`^([^=]+)=(.*)$`)
)
// ErrSyntax is returned when there is a syntax error in an INI file.
type ErrSyntax struct {
Line int
Source string // The contents of the erroneous line, without leading or trailing whitespace
}
func (e ErrSyntax) Error() string {
return fmt.Sprintf("invalid INI syntax on line %d: %s", e.Line, e.Source)
}
// A File represents a parsed INI file.
type File map[string]Section
// A Section represents a single section of an INI file.
type Section map[string]string
// Returns a named Section. A Section will be created if one does not already exist for the given name.
func (f File) Section(name string) Section {
section := f[name]
if section == nil {
section = make(Section)
f[name] = section
}
return section
}
// Looks up a value for a key in a section and returns that value, along with a boolean result similar to a map lookup.
func (f File) Get(section, key string) (value string, ok bool) {
if s := f[section]; s != nil {
value, ok = s[key]
}
return
}
// Loads INI data from a reader and stores the data in the File.
func (f File) Load(in io.Reader) (err error) {
bufin, ok := in.(*bufio.Reader)
if !ok {
bufin = bufio.NewReader(in)
}
return parseFile(bufin, f)
}
// Loads INI data from a named file and stores the data in the File.
func (f File) LoadFile(file string) (err error) {
in, err := os.Open(file)
if err != nil {
return
}
defer in.Close()
return f.Load(in)
}
func parseFile(in *bufio.Reader, file File) (err error) {
section := ""
lineNum := 0
for done := false; !done; {
var line string
if line, err = in.ReadString('\n'); err != nil {
if err == io.EOF {
done = true
} else {
return
}
}
lineNum++
line = strings.TrimSpace(line)
if len(line) == 0 {
// Skip blank lines
continue
}
if line[0] == ';' || line[0] == '#' {
// Skip comments
continue
}
if groups := assignRegex.FindStringSubmatch(line); groups != nil {
key, val := groups[1], groups[2]
key, val = strings.TrimSpace(key), strings.TrimSpace(val)
file.Section(section)[key] = val
} else if groups := sectionRegex.FindStringSubmatch(line); groups != nil {
name := strings.TrimSpace(groups[1])
section = name
// Create the section if it does not exist
file.Section(section)
} else {
return ErrSyntax{lineNum, line}
}
}
return nil
}
// Loads and returns a File from a reader.
func Load(in io.Reader) (File, error) {
file := make(File)
err := file.Load(in)
return file, err
}
// Loads and returns an INI File from a file on disk.
func LoadFile(filename string) (File, error) {
file := make(File)
err := file.LoadFile(filename)
return file, err
}

View file

@ -1,43 +0,0 @@
package ini
import (
"reflect"
"syscall"
"testing"
)
func TestLoadFile(t *testing.T) {
originalOpenFiles := numFilesOpen(t)
file, err := LoadFile("test.ini")
if err != nil {
t.Fatal(err)
}
if originalOpenFiles != numFilesOpen(t) {
t.Error("test.ini not closed")
}
if !reflect.DeepEqual(file, File{"default": {"stuff": "things"}}) {
t.Error("file not read correctly")
}
}
func numFilesOpen(t *testing.T) (num uint64) {
var rlimit syscall.Rlimit
err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rlimit)
if err != nil {
t.Fatal(err)
}
maxFds := int(rlimit.Cur)
var stat syscall.Stat_t
for i := 0; i < maxFds; i++ {
if syscall.Fstat(i, &stat) == nil {
num++
} else {
return
}
}
return
}

View file

@ -1,89 +0,0 @@
package ini
import (
"reflect"
"strings"
"testing"
)
func TestLoad(t *testing.T) {
src := `
# Comments are ignored
herp = derp
[foo]
hello=world
whitespace should = not matter
; sneaky semicolon-style comment
multiple = equals = signs
[bar]
this = that`
file, err := Load(strings.NewReader(src))
if err != nil {
t.Fatal(err)
}
check := func(section, key, expect string) {
if value, _ := file.Get(section, key); value != expect {
t.Errorf("Get(%q, %q): expected %q, got %q", section, key, expect, value)
}
}
check("", "herp", "derp")
check("foo", "hello", "world")
check("foo", "whitespace should", "not matter")
check("foo", "multiple", "equals = signs")
check("bar", "this", "that")
}
func TestSyntaxError(t *testing.T) {
src := `
# Line 2
[foo]
bar = baz
# Here's an error on line 6:
wut?
herp = derp`
_, err := Load(strings.NewReader(src))
t.Logf("%T: %v", err, err)
if err == nil {
t.Fatal("expected an error, got nil")
}
syntaxErr, ok := err.(ErrSyntax)
if !ok {
t.Fatal("expected an error of type ErrSyntax")
}
if syntaxErr.Line != 6 {
t.Fatal("incorrect line number")
}
if syntaxErr.Source != "wut?" {
t.Fatal("incorrect source")
}
}
func TestDefinedSectionBehaviour(t *testing.T) {
check := func(src string, expect File) {
file, err := Load(strings.NewReader(src))
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(file, expect) {
t.Errorf("expected %v, got %v", expect, file)
}
}
// No sections for an empty file
check("", File{})
// Default section only if there are actually values for it
check("foo=bar", File{"": {"foo": "bar"}})
// User-defined sections should always be present, even if empty
check("[a]\n[b]\nfoo=bar", File{
"a": {},
"b": {"foo": "bar"},
})
check("foo=bar\n[a]\nthis=that", File{
"": {"foo": "bar"},
"a": {"this": "that"},
})
}

View file

@ -1,2 +0,0 @@
[default]
stuff = things

4
s3.go
View file

@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/s3"
) )
@ -20,8 +21,9 @@ type s3Provider struct {
} }
func newS3Provider() (*s3Provider, error) { func newS3Provider() (*s3Provider, error) {
sess := session.Must(session.NewSession())
return &s3Provider{ return &s3Provider{
conn: s3.New(&aws.Config{}), conn: s3.New(sess),
}, nil }, nil
} }
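The s3.go hunk above is the functional heart of this commit: the S3 client is now constructed from a session rather than a bare `&aws.Config{}`. A minimal sketch of the new pattern, assuming region and credentials come from the environment or shared config (the bucket name is illustrative):

```go
package main

import (
	"fmt"
	"log"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
)

func main() {
	// session.Must panics if the session cannot be created,
	// e.g. when the shared config is malformed.
	sess := session.Must(session.NewSession())

	// The service client now takes the session instead of a *aws.Config.
	svc := s3.New(sess)

	// Illustrative call to verify the client works.
	out, err := svc.ListObjects(&s3.ListObjectsInput{
		Bucket: aws.String("example-bucket"),
	})
	if err != nil {
		log.Fatal(err)
	}
	for _, obj := range out.Contents {
		fmt.Println(aws.StringValue(obj.Key))
	}
}
```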

202
vendor/github.com/aws/aws-sdk-go/LICENSE.txt generated vendored Normal file
View file

@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

3
vendor/github.com/aws/aws-sdk-go/NOTICE.txt generated vendored Normal file
View file

@ -0,0 +1,3 @@
AWS SDK for Go
Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Copyright 2014-2015 Stripe, Inc.

View file

@ -14,13 +14,13 @@ package awserr
// if err != nil { // if err != nil {
// if awsErr, ok := err.(awserr.Error); ok { // if awsErr, ok := err.(awserr.Error); ok {
// // Get error details // // Get error details
// log.Println("Error:", err.Code(), err.Message()) // log.Println("Error:", awsErr.Code(), awsErr.Message())
// //
// // Prints out full error message, including original error if there was one. // // Prints out full error message, including original error if there was one.
// log.Println("Error:", err.Error()) // log.Println("Error:", awsErr.Error())
// //
// // Get original error // // Get original error
// if origErr := err.Err(); origErr != nil { // if origErr := awsErr.OrigErr(); origErr != nil {
// // operate on original error. // // operate on original error.
// } // }
// } else { // } else {
@ -42,15 +42,55 @@ type Error interface {
OrigErr() error OrigErr() error
} }
// BatchError is a batch of errors which also wraps lower level errors with
// code, message, and original errors. Calling Error() will include all errors
// that occurred in the batch.
//
// Deprecated: Replaced with BatchedErrors. Only defined for backwards
// compatibility.
type BatchError interface {
// Satisfy the generic error interface.
error
// Returns the short phrase depicting the classification of the error.
Code() string
// Returns the error details message.
Message() string
// Returns the original error if one was set. Nil is returned if not set.
OrigErrs() []error
}
// BatchedErrors is a batch of errors which also wraps lower level errors with
// code, message, and original errors. Calling Error() will include all errors
// that occurred in the batch.
//
// Replaces BatchError
type BatchedErrors interface {
// Satisfy the base Error interface.
Error
// Returns the original error if one was set. Nil is returned if not set.
OrigErrs() []error
}
// New returns an Error object described by the code, message, and origErr. // New returns an Error object described by the code, message, and origErr.
// //
// If origErr satisfies the Error interface it will not be wrapped within a new // If origErr satisfies the Error interface it will not be wrapped within a new
// Error object and will instead be returned. // Error object and will instead be returned.
func New(code, message string, origErr error) Error { func New(code, message string, origErr error) Error {
if e, ok := origErr.(Error); ok && e != nil { var errs []error
return e if origErr != nil {
errs = append(errs, origErr)
} }
return newBaseError(code, message, origErr) return newBaseError(code, message, errs)
}
// NewBatchError returns an BatchedErrors with a collection of errors as an
// array of errors.
func NewBatchError(code, message string, errs []error) BatchedErrors {
return newBaseError(code, message, errs)
} }
// A RequestFailure is an interface to extract request failure information from // A RequestFailure is an interface to extract request failure information from
@ -63,9 +103,9 @@ func New(code, message string, origErr error) Error {
// output, err := s3manage.Upload(svc, input, opts) // output, err := s3manage.Upload(svc, input, opts)
// if err != nil { // if err != nil {
// if reqerr, ok := err.(RequestFailure); ok { // if reqerr, ok := err.(RequestFailure); ok {
// log.Printf("Request failed", reqerr.Code(), reqerr.Message(), reqerr.RequestID()) // log.Println("Request failed", reqerr.Code(), reqerr.Message(), reqerr.RequestID())
// } else { // } else {
// log.Printf("Error:", err.Error() // log.Println("Error:", err.Error())
// } // }
// } // }
// //
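A short sketch of the error-handling pattern these doc comments describe, assuming `err` was returned by some SDK call (the helper name and the error fed to it in `main` are illustrative, not part of the SDK):

```go
package main

import (
	"log"

	"github.com/aws/aws-sdk-go/aws/awserr"
)

// handleAWSError logs the structured details an SDK error may carry.
func handleAWSError(err error) {
	if awsErr, ok := err.(awserr.Error); ok {
		// Classification code and human-readable message.
		log.Println("Error:", awsErr.Code(), awsErr.Message())

		// The wrapped original error, if any.
		if origErr := awsErr.OrigErr(); origErr != nil {
			log.Println("caused by:", origErr)
		}

		// Service responses additionally carry HTTP status and request ID.
		if reqErr, ok := err.(awserr.RequestFailure); ok {
			log.Println("status:", reqErr.StatusCode(), "request id:", reqErr.RequestID())
		}
		return
	}
	log.Println("Error:", err)
}

func main() {
	handleAWSError(awserr.New("ExampleCode", "example message", nil))
}
```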

View file

@ -31,23 +31,27 @@ type baseError struct {
// Optional original error this error is based off of. Allows building // Optional original error this error is based off of. Allows building
// chained errors. // chained errors.
origErr error errs []error
} }
// newBaseError returns an error object for the code, message, and err. // newBaseError returns an error object for the code, message, and errors.
// //
// code is a short no whitespace phrase depicting the classification of // code is a short no whitespace phrase depicting the classification of
// the error that is being created. // the error that is being created.
// //
// message is the free flow string containing detailed information about the error. // message is the free flow string containing detailed information about the
// error.
// //
// origErr is the error object which will be nested under the new error to be returned. // origErrs is the error objects which will be nested under the new errors to
func newBaseError(code, message string, origErr error) *baseError { // be returned.
return &baseError{ func newBaseError(code, message string, origErrs []error) *baseError {
b := &baseError{
code: code, code: code,
message: message, message: message,
origErr: origErr, errs: origErrs,
} }
return b
} }
// Error returns the string representation of the error. // Error returns the string representation of the error.
@ -56,7 +60,12 @@ func newBaseError(code, message string, origErr error) *baseError {
// //
// Satisfies the error interface. // Satisfies the error interface.
func (b baseError) Error() string { func (b baseError) Error() string {
return SprintError(b.code, b.message, "", b.origErr) size := len(b.errs)
if size > 0 {
return SprintError(b.code, b.message, "", errorList(b.errs))
}
return SprintError(b.code, b.message, "", nil)
} }
// String returns the string representation of the error. // String returns the string representation of the error.
@ -75,10 +84,28 @@ func (b baseError) Message() string {
return b.message return b.message
} }
// OrigErr returns the original error if one was set. Nil is returned if no error // OrigErr returns the original error if one was set. Nil is returned if no
// was set. // error was set. This only returns the first element in the list. If the full
// list is needed, use BatchedErrors.
func (b baseError) OrigErr() error { func (b baseError) OrigErr() error {
return b.origErr switch len(b.errs) {
case 0:
return nil
case 1:
return b.errs[0]
default:
if err, ok := b.errs[0].(Error); ok {
return NewBatchError(err.Code(), err.Message(), b.errs[1:])
}
return NewBatchError("BatchedErrors",
"multiple errors occurred", b.errs)
}
}
// OrigErrs returns the original errors if one was set. An empty slice is
// returned if no error was set.
func (b baseError) OrigErrs() []error {
return b.errs
} }
// So that the Error interface type can be included as an anonymous field // So that the Error interface type can be included as an anonymous field
@ -94,8 +121,8 @@ type requestError struct {
requestID string requestID string
} }
// newRequestError returns a wrapped error with additional information for request // newRequestError returns a wrapped error with additional information for
// status code, and service requestID. // request status code, and service requestID.
// //
// Should be used to wrap all request which involve service requests. Even if // Should be used to wrap all request which involve service requests. Even if
// the request failed without a service response, but had an HTTP status code // the request failed without a service response, but had an HTTP status code
@ -113,7 +140,7 @@ func newRequestError(err Error, statusCode int, requestID string) *requestError
// Error returns the string representation of the error. // Error returns the string representation of the error.
// Satisfies the error interface. // Satisfies the error interface.
func (r requestError) Error() string { func (r requestError) Error() string {
extra := fmt.Sprintf("status code: %d, request id: [%s]", extra := fmt.Sprintf("status code: %d, request id: %s",
r.statusCode, r.requestID) r.statusCode, r.requestID)
return SprintError(r.Code(), r.Message(), extra, r.OrigErr()) return SprintError(r.Code(), r.Message(), extra, r.OrigErr())
} }
@ -133,3 +160,35 @@ func (r requestError) StatusCode() int {
func (r requestError) RequestID() string { func (r requestError) RequestID() string {
return r.requestID return r.requestID
} }
// OrigErrs returns the original errors if one was set. An empty slice is
// returned if no error was set.
func (r requestError) OrigErrs() []error {
if b, ok := r.awsError.(BatchedErrors); ok {
return b.OrigErrs()
}
return []error{r.OrigErr()}
}
// An error list that satisfies the golang interface
type errorList []error
// Error returns the string representation of the error.
//
// Satisfies the error interface.
func (e errorList) Error() string {
msg := ""
// How do we want to handle the array size being zero
if size := len(e); size > 0 {
for i := 0; i < size; i++ {
msg += fmt.Sprintf("%s", e[i].Error())
// We check the next index to see if it is within the slice.
// If it is, then we append a newline. We do this, because unit tests
// could be broken with the additional '\n'
if i+1 < size {
msg += "\n"
}
}
}
return msg
}
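A small sketch of the batched-error API introduced above, showing how the individual errors can be recovered later (the error values are illustrative):

```go
package main

import (
	"errors"
	"fmt"

	"github.com/aws/aws-sdk-go/aws/awserr"
)

func main() {
	// Wrap several low level failures into one batched error.
	batch := awserr.NewBatchError("BatchedErrors", "multiple errors occurred", []error{
		errors.New("first failure"),
		errors.New("second failure"),
	})

	// Error() joins the messages; OrigErrs() exposes the originals.
	fmt.Println(batch.Error())
	for i, e := range batch.OrigErrs() {
		fmt.Printf("original %d: %v\n", i, e)
	}
}
```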

View file

@ -3,6 +3,7 @@ package awsutil
import ( import (
"io" "io"
"reflect" "reflect"
"time"
) )
// Copy deeply copies a src structure to dst. Useful for copying request and // Copy deeply copies a src structure to dst. Useful for copying request and
@ -49,7 +50,14 @@ func rcopy(dst, src reflect.Value, root bool) {
} else { } else {
e := src.Type().Elem() e := src.Type().Elem()
if dst.CanSet() && !src.IsNil() { if dst.CanSet() && !src.IsNil() {
if _, ok := src.Interface().(*time.Time); !ok {
dst.Set(reflect.New(e)) dst.Set(reflect.New(e))
} else {
tempValue := reflect.New(e)
tempValue.Elem().Set(src.Elem())
// Sets time.Time's unexported values
dst.Set(tempValue)
}
} }
if src.Elem().IsValid() { if src.Elem().IsValid() {
// Keep the current root state since the depth hasn't changed // Keep the current root state since the depth hasn't changed
@ -57,16 +65,13 @@ func rcopy(dst, src reflect.Value, root bool) {
} }
} }
case reflect.Struct: case reflect.Struct:
if !root {
dst.Set(reflect.New(src.Type()).Elem())
}
t := dst.Type() t := dst.Type()
for i := 0; i < t.NumField(); i++ { for i := 0; i < t.NumField(); i++ {
name := t.Field(i).Name name := t.Field(i).Name
srcval := src.FieldByName(name) srcVal := src.FieldByName(name)
if srcval.IsValid() { dstVal := dst.FieldByName(name)
rcopy(dst.FieldByName(name), srcval, false) if srcVal.IsValid() && dstVal.CanSet() {
rcopy(dstVal, srcVal, false)
} }
} }
case reflect.Slice: case reflect.Slice:

27
vendor/github.com/aws/aws-sdk-go/aws/awsutil/equal.go generated vendored Normal file
View file

@ -0,0 +1,27 @@
package awsutil
import (
"reflect"
)
// DeepEqual returns if the two values are deeply equal like reflect.DeepEqual.
// In addition to this, this method will also dereference the input values if
// possible so the DeepEqual performed will not fail if one parameter is a
// pointer and the other is not.
//
// DeepEqual will not perform indirection of nested values of the input parameters.
func DeepEqual(a, b interface{}) bool {
ra := reflect.Indirect(reflect.ValueOf(a))
rb := reflect.Indirect(reflect.ValueOf(b))
if raValid, rbValid := ra.IsValid(), rb.IsValid(); !raValid && !rbValid {
// If the elements are both nil, and of the same type they are equal
// If they are of different types they are not equal
return reflect.TypeOf(a) == reflect.TypeOf(b)
} else if raValid != rbValid {
// Both values must be valid to be equal
return false
}
return reflect.DeepEqual(ra.Interface(), rb.Interface())
}
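A tiny sketch of the pointer-indirection behaviour documented for DeepEqual (the values are illustrative):

```go
package main

import (
	"fmt"

	"github.com/aws/aws-sdk-go/aws/awsutil"
)

func main() {
	s := "hello"

	// The pointer is dereferenced before comparison, so a *string
	// and a plain string holding the same value compare equal.
	fmt.Println(awsutil.DeepEqual(&s, "hello")) // true

	// Different underlying values are still unequal.
	fmt.Println(awsutil.DeepEqual(&s, "world")) // false

	// Two nil pointers of the same type compare equal.
	var a, b *string
	fmt.Println(awsutil.DeepEqual(a, b)) // true
}
```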

View file

@ -5,18 +5,20 @@ import (
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"github.com/jmespath/go-jmespath"
) )
var indexRe = regexp.MustCompile(`(.+)\[(-?\d+)?\]$`) var indexRe = regexp.MustCompile(`(.+)\[(-?\d+)?\]$`)
// rValuesAtPath returns a slice of values found in value v. The values // rValuesAtPath returns a slice of values found in value v. The values
// in v are explored recursively so all nested values are collected. // in v are explored recursively so all nested values are collected.
func rValuesAtPath(v interface{}, path string, create bool, caseSensitive bool) []reflect.Value { func rValuesAtPath(v interface{}, path string, createPath, caseSensitive, nilTerm bool) []reflect.Value {
pathparts := strings.Split(path, "||") pathparts := strings.Split(path, "||")
if len(pathparts) > 1 { if len(pathparts) > 1 {
for _, pathpart := range pathparts { for _, pathpart := range pathparts {
vals := rValuesAtPath(v, pathpart, create, caseSensitive) vals := rValuesAtPath(v, pathpart, createPath, caseSensitive, nilTerm)
if vals != nil && len(vals) > 0 { if len(vals) > 0 {
return vals return vals
} }
} }
@ -74,7 +76,16 @@ func rValuesAtPath(v interface{}, path string, create bool, caseSensitive bool)
return false return false
}) })
if create && value.Kind() == reflect.Ptr && value.IsNil() { if nilTerm && value.Kind() == reflect.Ptr && len(components[1:]) == 0 {
if !value.IsNil() {
value.Set(reflect.Zero(value.Type()))
}
return []reflect.Value{value}
}
if createPath && value.Kind() == reflect.Ptr && value.IsNil() {
// TODO if the value is the terminus it should not be created
// if the value to be set to its position is nil.
value.Set(reflect.New(value.Type().Elem())) value.Set(reflect.New(value.Type().Elem()))
value = value.Elem() value = value.Elem()
} else { } else {
@ -82,7 +93,7 @@ func rValuesAtPath(v interface{}, path string, create bool, caseSensitive bool)
} }
if value.Kind() == reflect.Slice || value.Kind() == reflect.Map { if value.Kind() == reflect.Slice || value.Kind() == reflect.Map {
if !create && value.IsNil() { if !createPath && value.IsNil() {
value = reflect.ValueOf(nil) value = reflect.ValueOf(nil)
} }
} }
@ -95,8 +106,8 @@ func rValuesAtPath(v interface{}, path string, create bool, caseSensitive bool)
if indexStar || index != nil { if indexStar || index != nil {
nextvals = []reflect.Value{} nextvals = []reflect.Value{}
for _, value := range values { for _, valItem := range values {
value := reflect.Indirect(value) value := reflect.Indirect(valItem)
if value.Kind() != reflect.Slice { if value.Kind() != reflect.Slice {
continue continue
} }
@ -114,7 +125,7 @@ func rValuesAtPath(v interface{}, path string, create bool, caseSensitive bool)
// pull out index // pull out index
i := int(*index) i := int(*index)
if i >= value.Len() { // check out of bounds if i >= value.Len() { // check out of bounds
if create { if createPath {
// TODO resize slice // TODO resize slice
} else { } else {
continue continue
@ -125,7 +136,7 @@ func rValuesAtPath(v interface{}, path string, create bool, caseSensitive bool)
value = reflect.Indirect(value.Index(i)) value = reflect.Indirect(value.Index(i))
if value.Kind() == reflect.Slice || value.Kind() == reflect.Map { if value.Kind() == reflect.Slice || value.Kind() == reflect.Map {
if !create && value.IsNil() { if !createPath && value.IsNil() {
value = reflect.ValueOf(nil) value = reflect.ValueOf(nil)
} }
} }
@ -142,46 +153,70 @@ func rValuesAtPath(v interface{}, path string, create bool, caseSensitive bool)
return values return values
} }
// ValuesAtPath returns a list of objects at the lexical path inside of a structure // ValuesAtPath returns a list of values at the case insensitive lexical
func ValuesAtPath(i interface{}, path string) []interface{} { // path inside of a structure.
if rvals := rValuesAtPath(i, path, false, true); rvals != nil { func ValuesAtPath(i interface{}, path string) ([]interface{}, error) {
vals := make([]interface{}, len(rvals)) result, err := jmespath.Search(path, i)
for i, rval := range rvals { if err != nil {
vals[i] = rval.Interface() return nil, err
} }
return vals
v := reflect.ValueOf(result)
if !v.IsValid() || (v.Kind() == reflect.Ptr && v.IsNil()) {
return nil, nil
} }
return nil if s, ok := result.([]interface{}); ok {
return s, err
}
if v.Kind() == reflect.Map && v.Len() == 0 {
return nil, nil
}
if v.Kind() == reflect.Slice {
out := make([]interface{}, v.Len())
for i := 0; i < v.Len(); i++ {
out[i] = v.Index(i).Interface()
}
return out, nil
}
return []interface{}{result}, nil
} }
// ValuesAtAnyPath returns a list of objects at the case-insensitive lexical // SetValueAtPath sets a value at the case insensitive lexical path inside
// path inside of a structure // of a structure.
func ValuesAtAnyPath(i interface{}, path string) []interface{} {
if rvals := rValuesAtPath(i, path, false, false); rvals != nil {
vals := make([]interface{}, len(rvals))
for i, rval := range rvals {
vals[i] = rval.Interface()
}
return vals
}
return nil
}
// SetValueAtPath sets an object at the lexical path inside of a structure
func SetValueAtPath(i interface{}, path string, v interface{}) { func SetValueAtPath(i interface{}, path string, v interface{}) {
if rvals := rValuesAtPath(i, path, true, true); rvals != nil { if rvals := rValuesAtPath(i, path, true, false, v == nil); rvals != nil {
for _, rval := range rvals { for _, rval := range rvals {
rval.Set(reflect.ValueOf(v)) if rval.Kind() == reflect.Ptr && rval.IsNil() {
continue
}
setValue(rval, v)
} }
} }
} }
// SetValueAtAnyPath sets an object at the case insensitive lexical path inside func setValue(dstVal reflect.Value, src interface{}) {
// of a structure if dstVal.Kind() == reflect.Ptr {
func SetValueAtAnyPath(i interface{}, path string, v interface{}) { dstVal = reflect.Indirect(dstVal)
if rvals := rValuesAtPath(i, path, true, false); rvals != nil {
for _, rval := range rvals {
rval.Set(reflect.ValueOf(v))
} }
srcVal := reflect.ValueOf(src)
if !srcVal.IsValid() { // src is literal nil
if dstVal.CanAddr() {
// Convert to pointer so that pointer's value can be nil'ed
// dstVal = dstVal.Addr()
} }
dstVal.Set(reflect.Zero(dstVal.Type()))
} else if srcVal.Kind() == reflect.Ptr {
if srcVal.IsNil() {
srcVal = reflect.Zero(dstVal.Type())
} else {
srcVal = reflect.ValueOf(src).Elem()
}
dstVal.Set(srcVal)
} else {
dstVal.Set(srcVal)
}
} }
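ValuesAtPath is now evaluated as a JMESPath expression. A rough sketch against a plain nested map (the data is illustrative; the SDK applies the same call to API output structs, e.g. in waiters):

```go
package main

import (
	"fmt"
	"log"

	"github.com/aws/aws-sdk-go/aws/awsutil"
)

func main() {
	// Illustrative nested data; the path is a JMESPath expression.
	data := map[string]interface{}{
		"Contents": []interface{}{
			map[string]interface{}{"Key": "a.txt"},
			map[string]interface{}{"Key": "b.txt"},
		},
	}

	// Collect every Key under Contents.
	keys, err := awsutil.ValuesAtPath(data, "Contents[].Key")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(keys) // [a.txt b.txt]
}
```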

View file

@ -61,6 +61,12 @@ func prettify(v reflect.Value, indent int, buf *bytes.Buffer) {
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}") buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")
case reflect.Slice: case reflect.Slice:
strtype := v.Type().String()
if strtype == "[]uint8" {
fmt.Fprintf(buf, "<binary> len %d", v.Len())
break
}
nl, id, id2 := "", "", "" nl, id, id2 := "", "", ""
if v.Len() > 3 { if v.Len() > 3 {
nl, id, id2 = "\n", strings.Repeat(" ", indent), strings.Repeat(" ", indent+2) nl, id, id2 = "\n", strings.Repeat(" ", indent), strings.Repeat(" ", indent+2)
@ -91,6 +97,10 @@ func prettify(v reflect.Value, indent int, buf *bytes.Buffer) {
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}") buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")
default: default:
if !v.IsValid() {
fmt.Fprint(buf, "<invalid value>")
return
}
format := "%v" format := "%v"
switch v.Interface().(type) { switch v.Interface().(type) {
case string: case string:

View file

@ -0,0 +1,89 @@
package awsutil
import (
"bytes"
"fmt"
"reflect"
"strings"
)
// StringValue returns the string representation of a value.
func StringValue(i interface{}) string {
var buf bytes.Buffer
stringValue(reflect.ValueOf(i), 0, &buf)
return buf.String()
}
func stringValue(v reflect.Value, indent int, buf *bytes.Buffer) {
for v.Kind() == reflect.Ptr {
v = v.Elem()
}
switch v.Kind() {
case reflect.Struct:
buf.WriteString("{\n")
names := []string{}
for i := 0; i < v.Type().NumField(); i++ {
name := v.Type().Field(i).Name
f := v.Field(i)
if name[0:1] == strings.ToLower(name[0:1]) {
continue // ignore unexported fields
}
if (f.Kind() == reflect.Ptr || f.Kind() == reflect.Slice) && f.IsNil() {
continue // ignore unset fields
}
names = append(names, name)
}
for i, n := range names {
val := v.FieldByName(n)
buf.WriteString(strings.Repeat(" ", indent+2))
buf.WriteString(n + ": ")
stringValue(val, indent+2, buf)
if i < len(names)-1 {
buf.WriteString(",\n")
}
}
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")
case reflect.Slice:
nl, id, id2 := "", "", ""
if v.Len() > 3 {
nl, id, id2 = "\n", strings.Repeat(" ", indent), strings.Repeat(" ", indent+2)
}
buf.WriteString("[" + nl)
for i := 0; i < v.Len(); i++ {
buf.WriteString(id2)
stringValue(v.Index(i), indent+2, buf)
if i < v.Len()-1 {
buf.WriteString("," + nl)
}
}
buf.WriteString(nl + id + "]")
case reflect.Map:
buf.WriteString("{\n")
for i, k := range v.MapKeys() {
buf.WriteString(strings.Repeat(" ", indent+2))
buf.WriteString(k.String() + ": ")
stringValue(v.MapIndex(k), indent+2, buf)
if i < v.Len()-1 {
buf.WriteString(",\n")
}
}
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")
default:
format := "%v"
switch v.Interface().(type) {
case string:
format = "%q"
}
fmt.Fprintf(buf, format, v.Interface())
}
}

90
vendor/github.com/aws/aws-sdk-go/aws/client/client.go generated vendored Normal file
View file

@ -0,0 +1,90 @@
package client
import (
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/request"
)
// A Config provides configuration to a service client instance.
type Config struct {
Config *aws.Config
Handlers request.Handlers
Endpoint string
SigningRegion string
SigningName string
}
// ConfigProvider provides a generic way for a service client to receive
// the ClientConfig without circular dependencies.
type ConfigProvider interface {
ClientConfig(serviceName string, cfgs ...*aws.Config) Config
}
// ConfigNoResolveEndpointProvider same as ConfigProvider except it will not
// resolve the endpoint automatically. The service client's endpoint must be
// provided via the aws.Config.Endpoint field.
type ConfigNoResolveEndpointProvider interface {
ClientConfigNoResolveEndpoint(cfgs ...*aws.Config) Config
}
// A Client implements the base client request and response handling
// used by all service clients.
type Client struct {
request.Retryer
metadata.ClientInfo
Config aws.Config
Handlers request.Handlers
}
// New will return a pointer to a new initialized service client.
func New(cfg aws.Config, info metadata.ClientInfo, handlers request.Handlers, options ...func(*Client)) *Client {
svc := &Client{
Config: cfg,
ClientInfo: info,
Handlers: handlers.Copy(),
}
switch retryer, ok := cfg.Retryer.(request.Retryer); {
case ok:
svc.Retryer = retryer
case cfg.Retryer != nil && cfg.Logger != nil:
s := fmt.Sprintf("WARNING: %T does not implement request.Retryer; using DefaultRetryer instead", cfg.Retryer)
cfg.Logger.Log(s)
fallthrough
default:
maxRetries := aws.IntValue(cfg.MaxRetries)
if cfg.MaxRetries == nil || maxRetries == aws.UseServiceDefaultRetries {
maxRetries = 3
}
svc.Retryer = DefaultRetryer{NumMaxRetries: maxRetries}
}
svc.AddDebugHandlers()
for _, option := range options {
option(svc)
}
return svc
}
// NewRequest returns a new Request pointer for the service API
// operation and parameters.
func (c *Client) NewRequest(operation *request.Operation, params interface{}, data interface{}) *request.Request {
return request.New(c.Config, c.ClientInfo, c.Handlers, c.Retryer, operation, params, data)
}
// AddDebugHandlers injects debug logging handlers into the service to log request
// debug information.
func (c *Client) AddDebugHandlers() {
if !c.Config.LogLevel.AtLeast(aws.LogDebug) {
return
}
c.Handlers.Send.PushFrontNamed(request.NamedHandler{Name: "awssdk.client.LogRequest", Fn: logRequest})
c.Handlers.Send.PushBackNamed(request.NamedHandler{Name: "awssdk.client.LogResponse", Fn: logResponse})
}

View file

@ -0,0 +1,96 @@
package client
import (
"math/rand"
"sync"
"time"
"github.com/aws/aws-sdk-go/aws/request"
)
// DefaultRetryer implements basic retry logic using exponential backoff for
// most services. If you want to implement custom retry logic, implement the
// request.Retryer interface or create a structure type that composes this
// struct and override the specific methods. For example, to override only
// the MaxRetries method:
//
// type retryer struct {
// service.DefaultRetryer
// }
//
// // This implementation always has 100 max retries
// func (d retryer) MaxRetries() uint { return 100 }
type DefaultRetryer struct {
NumMaxRetries int
}
// MaxRetries returns the maximum number of retries the service will use to make
// an individual API request.
func (d DefaultRetryer) MaxRetries() int {
return d.NumMaxRetries
}
var seededRand = rand.New(&lockedSource{src: rand.NewSource(time.Now().UnixNano())})
// RetryRules returns the delay duration before retrying this request again
func (d DefaultRetryer) RetryRules(r *request.Request) time.Duration {
// Set the upper limit of delay in retrying at ~five minutes
minTime := 30
throttle := d.shouldThrottle(r)
if throttle {
minTime = 500
}
retryCount := r.RetryCount
if retryCount > 13 {
retryCount = 13
} else if throttle && retryCount > 8 {
retryCount = 8
}
delay := (1 << uint(retryCount)) * (seededRand.Intn(minTime) + minTime)
return time.Duration(delay) * time.Millisecond
}
// ShouldRetry returns true if the request should be retried.
func (d DefaultRetryer) ShouldRetry(r *request.Request) bool {
// If one of the other handlers already set the retry state
// we don't want to override it based on the service's state
if r.Retryable != nil {
return *r.Retryable
}
if r.HTTPResponse.StatusCode >= 500 {
return true
}
return r.IsErrorRetryable() || d.shouldThrottle(r)
}
// ShouldThrottle returns true if the request should be throttled.
func (d DefaultRetryer) shouldThrottle(r *request.Request) bool {
if r.HTTPResponse.StatusCode == 502 ||
r.HTTPResponse.StatusCode == 503 ||
r.HTTPResponse.StatusCode == 504 {
return true
}
return r.IsErrorThrottle()
}
// lockedSource is a thread-safe implementation of rand.Source
type lockedSource struct {
lk sync.Mutex
src rand.Source
}
func (r *lockedSource) Int63() (n int64) {
r.lk.Lock()
n = r.src.Int63()
r.lk.Unlock()
return
}
func (r *lockedSource) Seed(seed int64) {
r.lk.Lock()
r.src.Seed(seed)
r.lk.Unlock()
}
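As the DefaultRetryer doc comment suggests, custom retry behaviour is added by composing the struct and overriding individual methods. A minimal sketch, assuming the aws/client import path and an arbitrary cap of 100 retries (note the method returns int in this version of the package):

```go
package main

import (
	"fmt"

	"github.com/aws/aws-sdk-go/aws/client"
)

// retryer composes DefaultRetryer and only overrides MaxRetries;
// RetryRules and ShouldRetry are inherited unchanged.
type retryer struct {
	client.DefaultRetryer
}

// MaxRetries always allows 100 attempts, regardless of NumMaxRetries.
func (d retryer) MaxRetries() int { return 100 }

func main() {
	r := retryer{client.DefaultRetryer{NumMaxRetries: 3}}
	fmt.Println(r.MaxRetries()) // 100, the override wins
}
```

In practice such a composed retryer would be handed to the SDK via the Config.Retryer field or request.WithRetryer, as described in the Config documentation further down.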

108
vendor/github.com/aws/aws-sdk-go/aws/client/logger.go generated vendored Normal file
View file

@ -0,0 +1,108 @@
package client
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http/httputil"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
)
const logReqMsg = `DEBUG: Request %s/%s Details:
---[ REQUEST POST-SIGN ]-----------------------------
%s
-----------------------------------------------------`
const logReqErrMsg = `DEBUG ERROR: Request %s/%s:
---[ REQUEST DUMP ERROR ]-----------------------------
%s
------------------------------------------------------`
type logWriter struct {
// Logger is what we will use to log the payload of a response.
Logger aws.Logger
// buf stores the contents of what has been read
buf *bytes.Buffer
}
func (logger *logWriter) Write(b []byte) (int, error) {
return logger.buf.Write(b)
}
type teeReaderCloser struct {
// io.Reader will be a tee reader that is used during logging.
// This structure will read from a body and write the contents to a logger.
io.Reader
// Source is used just to close when we are done reading.
Source io.ReadCloser
}
func (reader *teeReaderCloser) Close() error {
return reader.Source.Close()
}
func logRequest(r *request.Request) {
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
dumpedBody, err := httputil.DumpRequestOut(r.HTTPRequest, logBody)
if err != nil {
r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg, r.ClientInfo.ServiceName, r.Operation.Name, err))
return
}
if logBody {
// Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's
// Body as a NoOpCloser and will not be reset after read by the HTTP
// client reader.
r.ResetBody()
}
r.Config.Logger.Log(fmt.Sprintf(logReqMsg, r.ClientInfo.ServiceName, r.Operation.Name, string(dumpedBody)))
}
const logRespMsg = `DEBUG: Response %s/%s Details:
---[ RESPONSE ]--------------------------------------
%s
-----------------------------------------------------`
const logRespErrMsg = `DEBUG ERROR: Response %s/%s:
---[ RESPONSE DUMP ERROR ]-----------------------------
%s
-----------------------------------------------------`
func logResponse(r *request.Request) {
lw := &logWriter{r.Config.Logger, bytes.NewBuffer(nil)}
r.HTTPResponse.Body = &teeReaderCloser{
Reader: io.TeeReader(r.HTTPResponse.Body, lw),
Source: r.HTTPResponse.Body,
}
handlerFn := func(req *request.Request) {
body, err := httputil.DumpResponse(req.HTTPResponse, false)
if err != nil {
lw.Logger.Log(fmt.Sprintf(logRespErrMsg, req.ClientInfo.ServiceName, req.Operation.Name, err))
return
}
b, err := ioutil.ReadAll(lw.buf)
if err != nil {
lw.Logger.Log(fmt.Sprintf(logRespErrMsg, req.ClientInfo.ServiceName, req.Operation.Name, err))
return
}
lw.Logger.Log(fmt.Sprintf(logRespMsg, req.ClientInfo.ServiceName, req.Operation.Name, string(body)))
if req.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody) {
lw.Logger.Log(string(b))
}
}
const handlerName = "awsdk.client.LogResponse.ResponseBody"
r.Handlers.Unmarshal.SetBackNamed(request.NamedHandler{
Name: handlerName, Fn: handlerFn,
})
r.Handlers.UnmarshalError.SetBackNamed(request.NamedHandler{
Name: handlerName, Fn: handlerFn,
})
}

View file

@ -0,0 +1,12 @@
package metadata
// ClientInfo wraps immutable data from the client.Client structure.
type ClientInfo struct {
ServiceName string
APIVersion string
Endpoint string
SigningName string
SigningRegion string
JSONVersion string
TargetPrefix string
}

470
vendor/github.com/aws/aws-sdk-go/aws/config.go generated vendored Normal file
View file

@ -0,0 +1,470 @@
package aws
import (
"net/http"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/endpoints"
)
// UseServiceDefaultRetries instructs the config to use the service's own
// default number of retries. This will be the default action if
// Config.MaxRetries is nil also.
const UseServiceDefaultRetries = -1
// RequestRetryer is an alias for a type that implements the request.Retryer
// interface.
type RequestRetryer interface{}
// A Config provides service configuration for service clients. By default,
// all clients will use the defaults.DefaultConfig structure.
//
// // Create Session with MaxRetry configuration to be shared by multiple
// // service clients.
// sess := session.Must(session.NewSession(&aws.Config{
// MaxRetries: aws.Int(3),
// }))
//
// // Create S3 service client with a specific Region.
// svc := s3.New(sess, &aws.Config{
// Region: aws.String("us-west-2"),
// })
type Config struct {
// Enables verbose error printing of all credential chain errors.
// Should be used when wanting to see all errors while attempting to
// retrieve credentials.
CredentialsChainVerboseErrors *bool
// The credentials object to use when signing requests. Defaults to a
// chain of credential providers to search for credentials in environment
// variables, shared credential file, and EC2 Instance Roles.
Credentials *credentials.Credentials
// An optional endpoint URL (hostname only or fully qualified URI)
// that overrides the default generated endpoint for a client. Set this
// to `""` to use the default generated endpoint.
//
// @note You must still provide a `Region` value when specifying an
// endpoint for a client.
Endpoint *string
// The resolver to use for looking up endpoints for AWS service clients
// to use based on region.
EndpointResolver endpoints.Resolver
// EnforceShouldRetryCheck is used in the AfterRetryHandler to always call
// ShouldRetry regardless of whether or not if request.Retryable is set.
// This will utilize ShouldRetry method of custom retryers. If EnforceShouldRetryCheck
// is not set, then ShouldRetry will only be called if request.Retryable is nil.
// Proper handling of the request.Retryable field is important when setting this field.
EnforceShouldRetryCheck *bool
// The region to send requests to. This parameter is required and must
// be configured globally or on a per-client basis unless otherwise
// noted. A full list of regions is found in the "Regions and Endpoints"
// document.
//
// @see http://docs.aws.amazon.com/general/latest/gr/rande.html
// AWS Regions and Endpoints
Region *string
// Set this to `true` to disable SSL when sending requests. Defaults
// to `false`.
DisableSSL *bool
// The HTTP client to use when sending requests. Defaults to
// `http.DefaultClient`.
HTTPClient *http.Client
// An integer value representing the logging level. The default log level
// is zero (LogOff), which represents no logging. To enable logging set
// to a LogLevel Value.
LogLevel *LogLevelType
// The logger writer interface to write logging messages to. Defaults to
// standard out.
Logger Logger
// The maximum number of times that a request will be retried for failures.
// Defaults to -1, which defers the max retry setting to the service
// specific configuration.
MaxRetries *int
// Retryer guides how HTTP requests should be retried in case of
// recoverable failures.
//
// When nil or the value does not implement the request.Retryer interface,
// the request.DefaultRetryer will be used.
//
// When both Retryer and MaxRetries are non-nil, the former is used and
// the latter ignored.
//
// To set the Retryer field in a type-safe manner and with chaining, use
// the request.WithRetryer helper function:
//
// cfg := request.WithRetryer(aws.NewConfig(), myRetryer)
//
Retryer RequestRetryer
// Disables semantic parameter validation, which validates input for
// missing required fields and/or other semantic request input errors.
DisableParamValidation *bool
// Disables the computation of request and response checksums, e.g.,
// CRC32 checksums in Amazon DynamoDB.
DisableComputeChecksums *bool
// Set this to `true` to force the request to use path-style addressing,
// i.e., `http://s3.amazonaws.com/BUCKET/KEY`. By default, the S3 client
// will use virtual hosted bucket addressing when possible
// (`http://BUCKET.s3.amazonaws.com/KEY`).
//
// @note This configuration option is specific to the Amazon S3 service.
// @see http://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html
// Amazon S3: Virtual Hosting of Buckets
S3ForcePathStyle *bool
// Set this to `true` to disable the SDK adding the `Expect: 100-Continue`
// header to PUT requests over 2MB of content. 100-Continue instructs the
// HTTP client not to send the body until the service responds with a
// `continue` status. This is useful to prevent sending the request body
// until after the request is authenticated, and validated.
//
// http://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectPUT.html
//
// 100-Continue is only enabled for Go 1.6 and above. See `http.Transport`'s
// `ExpectContinueTimeout` for information on adjusting the continue wait
// timeout. https://golang.org/pkg/net/http/#Transport
//
// You should use this flag to disable 100-Continue if you experience issues
// with proxies or third party S3 compatible services.
S3Disable100Continue *bool
// Set this to `true` to enable S3 Accelerate feature. For all operations
// compatible with S3 Accelerate will use the accelerate endpoint for
// requests. Requests not compatible will fall back to normal S3 requests.
//
// The bucket must be enabled for accelerate to be used with an S3 client with
// accelerate enabled. If the bucket is not enabled for accelerate an error
// will be returned. The bucket name must be DNS compatible to also work
// with accelerate.
S3UseAccelerate *bool
// Set this to `true` to disable the EC2Metadata client from overriding the
// default http.Client's Timeout. This is helpful if you do not want the
// EC2Metadata client to create a new http.Client. This option is only
// meaningful if you're not already using a custom HTTP client with the
// SDK. Enabled by default.
//
// Must be set and provided to the session.NewSession() in order to disable
// the EC2Metadata overriding the timeout for default credentials chain.
//
// Example:
// sess := session.Must(session.NewSession(aws.NewConfig()
// .WithEC2MetadataDisableTimeoutOverride(true)))
//
// svc := s3.New(sess)
//
EC2MetadataDisableTimeoutOverride *bool
// Instructs the endpoint generated for a service client to be the dual
// stack endpoint. The dual stack endpoint supports both IPv4 and IPv6
// addressing.
//
// Setting this for a service which does not support dual stack will cause
// requests to fail. It is not recommended to set this value on the session,
// as it will apply to all service clients created with the session, even
// services which do not support dual stack endpoints.
//
// If the Endpoint config value is also provided the UseDualStack flag
// will be ignored.
//
// Example of enabling dual stack for a single service client:
//
// sess := session.Must(session.NewSession())
//
// svc := s3.New(sess, &aws.Config{
// UseDualStack: aws.Bool(true),
// })
UseDualStack *bool
// SleepDelay is an override for the func the SDK will call when sleeping
// during the lifecycle of a request. Specifically this will be used for
// request delays. This value should only be used for testing. To adjust
// the delay of a request see the aws/client.DefaultRetryer and
// aws/request.Retryer.
//
// SleepDelay will prevent any Context from being used for canceling retry
// delay of an API operation. It is recommended to not use SleepDelay at all
// and specify a Retryer instead.
SleepDelay func(time.Duration)
// DisableRestProtocolURICleaning will not clean the URL path when making rest protocol requests.
// Defaults to false. This is only needed when S3 object keys contain empty directory names
// (consecutive slashes), as in the example below.
//
// Example:
// sess := session.Must(session.NewSession(&aws.Config{
// DisableRestProtocolURICleaning: aws.Bool(true),
// }))
//
// svc := s3.New(sess)
// out, err := svc.GetObject(&s3.GetObjectInput {
// Bucket: aws.String("bucketname"),
// Key: aws.String("//foo//bar//moo"),
// })
DisableRestProtocolURICleaning *bool
// EnforceShouldRetryCheck is used in the AfterRetryHandler to always call
// ShouldRetry regardless of whether or not request.Retryable is set. This
// utilizes the ShouldRetry method of custom retryers. If EnforceShouldRetryCheck
// is not set, ShouldRetry will only be called if request.Retryable is nil.
// (Declared here because it is referenced by mergeInConfig below.)
EnforceShouldRetryCheck *bool
}
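// Minimal usage sketch (editor-added, not part of the upstream SDK): because
// every option above is a pointer, a Config is normally populated with the
// Bool / Int / String pointer helpers from this package rather than with
// literals. The region and retry values below are placeholders.
func exampleConfigLiteral() *Config {
	return &Config{
		Region:           String("eu-west-1"), // placeholder region
		MaxRetries:       Int(5),              // retry failed requests up to 5 times
		S3ForcePathStyle: Bool(true),          // use path-style S3 addressing
	}
}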
// NewConfig returns a new Config pointer that can be chained with builder
// methods to set multiple configuration values inline without using pointers.
//
// // Create Session with MaxRetry configuration to be shared by multiple
// // service clients.
// sess := session.Must(session.NewSession(aws.NewConfig().
// WithMaxRetries(3),
// ))
//
// // Create S3 service client with a specific Region.
// svc := s3.New(sess, aws.NewConfig().
// WithRegion("us-west-2"),
// )
func NewConfig() *Config {
return &Config{}
}
// WithCredentialsChainVerboseErrors sets a config CredentialsChainVerboseErrors
// value, returning a Config pointer for chaining.
func (c *Config) WithCredentialsChainVerboseErrors(verboseErrs bool) *Config {
c.CredentialsChainVerboseErrors = &verboseErrs
return c
}
// WithCredentials sets a config Credentials value returning a Config pointer
// for chaining.
func (c *Config) WithCredentials(creds *credentials.Credentials) *Config {
c.Credentials = creds
return c
}
// WithEndpoint sets a config Endpoint value returning a Config pointer for
// chaining.
func (c *Config) WithEndpoint(endpoint string) *Config {
c.Endpoint = &endpoint
return c
}
// WithEndpointResolver sets a config EndpointResolver value returning a
// Config pointer for chaining.
func (c *Config) WithEndpointResolver(resolver endpoints.Resolver) *Config {
c.EndpointResolver = resolver
return c
}
// WithRegion sets a config Region value returning a Config pointer for
// chaining.
func (c *Config) WithRegion(region string) *Config {
c.Region = &region
return c
}
// WithDisableSSL sets a config DisableSSL value returning a Config pointer
// for chaining.
func (c *Config) WithDisableSSL(disable bool) *Config {
c.DisableSSL = &disable
return c
}
// WithHTTPClient sets a config HTTPClient value returning a Config pointer
// for chaining.
func (c *Config) WithHTTPClient(client *http.Client) *Config {
c.HTTPClient = client
return c
}
// WithMaxRetries sets a config MaxRetries value returning a Config pointer
// for chaining.
func (c *Config) WithMaxRetries(max int) *Config {
c.MaxRetries = &max
return c
}
// WithDisableParamValidation sets a config DisableParamValidation value
// returning a Config pointer for chaining.
func (c *Config) WithDisableParamValidation(disable bool) *Config {
c.DisableParamValidation = &disable
return c
}
// WithDisableComputeChecksums sets a config DisableComputeChecksums value
// returning a Config pointer for chaining.
func (c *Config) WithDisableComputeChecksums(disable bool) *Config {
c.DisableComputeChecksums = &disable
return c
}
// WithLogLevel sets a config LogLevel value returning a Config pointer for
// chaining.
func (c *Config) WithLogLevel(level LogLevelType) *Config {
c.LogLevel = &level
return c
}
// WithLogger sets a config Logger value returning a Config pointer for
// chaining.
func (c *Config) WithLogger(logger Logger) *Config {
c.Logger = logger
return c
}
// WithS3ForcePathStyle sets a config S3ForcePathStyle value returning a Config
// pointer for chaining.
func (c *Config) WithS3ForcePathStyle(force bool) *Config {
c.S3ForcePathStyle = &force
return c
}
// WithS3Disable100Continue sets a config S3Disable100Continue value returning
// a Config pointer for chaining.
func (c *Config) WithS3Disable100Continue(disable bool) *Config {
c.S3Disable100Continue = &disable
return c
}
// WithS3UseAccelerate sets a config S3UseAccelerate value returning a Config
// pointer for chaining.
func (c *Config) WithS3UseAccelerate(enable bool) *Config {
c.S3UseAccelerate = &enable
return c
}
// WithUseDualStack sets a config UseDualStack value returning a Config
// pointer for chaining.
func (c *Config) WithUseDualStack(enable bool) *Config {
c.UseDualStack = &enable
return c
}
// WithEC2MetadataDisableTimeoutOverride sets a config EC2MetadataDisableTimeoutOverride value
// returning a Config pointer for chaining.
func (c *Config) WithEC2MetadataDisableTimeoutOverride(enable bool) *Config {
c.EC2MetadataDisableTimeoutOverride = &enable
return c
}
// WithSleepDelay overrides the function used to sleep while waiting for the
// next retry. Defaults to time.Sleep.
func (c *Config) WithSleepDelay(fn func(time.Duration)) *Config {
c.SleepDelay = fn
return c
}
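// Usage sketch (editor-added, not part of the upstream SDK): each With*
// setter above returns the same *Config, so values can be chained from
// NewConfig(). The endpoint and region below are placeholders for an
// S3-compatible service.
func exampleChainedConfig() *Config {
	return NewConfig().
		WithRegion("eu-west-1").
		WithEndpoint("https://s3.example.com").
		WithS3ForcePathStyle(true).
		WithMaxRetries(3)
}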
// MergeIn merges the passed in configs into the existing config object.
func (c *Config) MergeIn(cfgs ...*Config) {
for _, other := range cfgs {
mergeInConfig(c, other)
}
}
func mergeInConfig(dst *Config, other *Config) {
if other == nil {
return
}
if other.CredentialsChainVerboseErrors != nil {
dst.CredentialsChainVerboseErrors = other.CredentialsChainVerboseErrors
}
if other.Credentials != nil {
dst.Credentials = other.Credentials
}
if other.Endpoint != nil {
dst.Endpoint = other.Endpoint
}
if other.EndpointResolver != nil {
dst.EndpointResolver = other.EndpointResolver
}
if other.Region != nil {
dst.Region = other.Region
}
if other.DisableSSL != nil {
dst.DisableSSL = other.DisableSSL
}
if other.HTTPClient != nil {
dst.HTTPClient = other.HTTPClient
}
if other.LogLevel != nil {
dst.LogLevel = other.LogLevel
}
if other.Logger != nil {
dst.Logger = other.Logger
}
if other.MaxRetries != nil {
dst.MaxRetries = other.MaxRetries
}
if other.Retryer != nil {
dst.Retryer = other.Retryer
}
if other.DisableParamValidation != nil {
dst.DisableParamValidation = other.DisableParamValidation
}
if other.DisableComputeChecksums != nil {
dst.DisableComputeChecksums = other.DisableComputeChecksums
}
if other.S3ForcePathStyle != nil {
dst.S3ForcePathStyle = other.S3ForcePathStyle
}
if other.S3Disable100Continue != nil {
dst.S3Disable100Continue = other.S3Disable100Continue
}
if other.S3UseAccelerate != nil {
dst.S3UseAccelerate = other.S3UseAccelerate
}
if other.UseDualStack != nil {
dst.UseDualStack = other.UseDualStack
}
if other.EC2MetadataDisableTimeoutOverride != nil {
dst.EC2MetadataDisableTimeoutOverride = other.EC2MetadataDisableTimeoutOverride
}
if other.SleepDelay != nil {
dst.SleepDelay = other.SleepDelay
}
if other.DisableRestProtocolURICleaning != nil {
dst.DisableRestProtocolURICleaning = other.DisableRestProtocolURICleaning
}
if other.EnforceShouldRetryCheck != nil {
dst.EnforceShouldRetryCheck = other.EnforceShouldRetryCheck
}
}
// Copy will return a shallow copy of the Config object. If any additional
// configurations are provided they will be merged into the new config returned.
func (c *Config) Copy(cfgs ...*Config) *Config {
dst := &Config{}
dst.MergeIn(c)
for _, cfg := range cfgs {
dst.MergeIn(cfg)
}
return dst
}
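// Usage sketch (editor-added, not part of the upstream SDK) illustrating the
// Copy/MergeIn semantics above: Copy returns a shallow copy of the base
// config, and only non-nil fields of the override are merged in, so the base
// config itself is left untouched.
func exampleCopyWithOverride(base *Config) *Config {
	override := NewConfig().WithMaxRetries(10) // only MaxRetries is non-nil here
	return base.Copy(override)                 // base.Region, base.Credentials, ... are preserved
}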

71
vendor/github.com/aws/aws-sdk-go/aws/context.go generated vendored Normal file
View file

@ -0,0 +1,71 @@
package aws
import (
"time"
)
// Context is a copy of the Go v1.7 stdlib's context.Context interface.
// It is represented as an SDK interface to enable you to use the "WithContext"
// API methods with Go v1.6 and a Context type such as golang.org/x/net/context.
//
// See https://golang.org/pkg/context on how to use contexts.
type Context interface {
// Deadline returns the time when work done on behalf of this context
// should be canceled. Deadline returns ok==false when no deadline is
// set. Successive calls to Deadline return the same results.
Deadline() (deadline time.Time, ok bool)
// Done returns a channel that's closed when work done on behalf of this
// context should be canceled. Done may return nil if this context can
// never be canceled. Successive calls to Done return the same value.
Done() <-chan struct{}
// Err returns a non-nil error value after Done is closed. Err returns
// Canceled if the context was canceled or DeadlineExceeded if the
// context's deadline passed. No other values for Err are defined.
// After Done is closed, successive calls to Err return the same value.
Err() error
// Value returns the value associated with this context for key, or nil
// if no value is associated with key. Successive calls to Value with
// the same key returns the same result.
//
// Use context values only for request-scoped data that transits
// processes and API boundaries, not for passing optional parameters to
// functions.
Value(key interface{}) interface{}
}
// BackgroundContext returns a context that will never be canceled, has no
// values, and no deadline. This context is used by the SDK to provide
// backwards compatibility with non-context API operations and functionality.
//
// Go 1.6 and before:
// This context function is equivalent to context.Background in the Go stdlib.
//
// Go 1.7 and later:
// The context returned will be the value returned by context.Background()
//
// See https://golang.org/pkg/context for more information on Contexts.
func BackgroundContext() Context {
return backgroundCtx
}
// SleepWithContext will wait for the timer duration to expire or until the
// context is canceled, whichever happens first. If the context is canceled, the Context's
// error will be returned.
//
// Expects Context to always return a non-nil error if the Done channel is closed.
func SleepWithContext(ctx Context, dur time.Duration) error {
t := time.NewTimer(dur)
defer t.Stop()
select {
case <-t.C:
break
case <-ctx.Done():
return ctx.Err()
}
return nil
}
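// Usage sketch (editor-added, not part of the upstream SDK): with a context
// that can never be canceled (BackgroundContext), SleepWithContext simply
// sleeps for the full duration and returns nil. On Go 1.7+ a stdlib
// context.Context also satisfies the Context interface above, in which case
// cancellation aborts the sleep and the context's error is returned.
func exampleSleepWithContext() error {
	return SleepWithContext(BackgroundContext(), 100*time.Millisecond)
}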

41
vendor/github.com/aws/aws-sdk-go/aws/context_1_6.go generated vendored Normal file
View file

@ -0,0 +1,41 @@
// +build !go1.7
package aws
import "time"
// An emptyCtx is a copy of the Go 1.7 context.emptyCtx type. This
// is copied to provide a 1.6 and 1.5 safe version of context that is compatible
// with Go 1.7's Context.
//
// An emptyCtx is never canceled, has no values, and has no deadline. It is not
// struct{}, since vars of this type must have distinct addresses.
type emptyCtx int
func (*emptyCtx) Deadline() (deadline time.Time, ok bool) {
return
}
func (*emptyCtx) Done() <-chan struct{} {
return nil
}
func (*emptyCtx) Err() error {
return nil
}
func (*emptyCtx) Value(key interface{}) interface{} {
return nil
}
func (e *emptyCtx) String() string {
switch e {
case backgroundCtx:
return "aws.BackgroundContext"
}
return "unknown empty Context"
}
var (
backgroundCtx = new(emptyCtx)
)

Some files were not shown because too many files have changed in this diff